blob: 571456c70d8d0b71381e12cdecf24f3bb3f47c4c [file] [log] [blame]
// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote;
import build.bazel.remote.execution.v2.ExecutionGrpc;
import com.google.devtools.build.lib.remote.common.NetworkTime;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import java.util.function.Supplier;
/**
 * The {@link ClientInterceptor} used to track network time.
 *
 * <p>For every call except {@code Execute} and {@code WaitExecution}, the {@link NetworkTime}
 * obtained from the supplier is started when the first request message is sent and stopped when
 * the call closes. If the supplier returns {@code null}, the call is passed through unwrapped.
 */
public class NetworkTimeInterceptor implements ClientInterceptor {
  private final Supplier<NetworkTime> networkTimeSupplier;

  /**
   * Creates the interceptor.
   *
   * @param networkTimeSupplier queried once per intercepted (non-execution) call for the {@link
   *     NetworkTime} to update; may return {@code null} to skip tracking for that call
   */
  public NetworkTimeInterceptor(Supplier<NetworkTime> networkTimeSupplier) {
    this.networkTimeSupplier = networkTimeSupplier;
  }

  @Override
  public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
      MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
    ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
    // Prevent accounting for execution wait time: Execute and WaitExecution are never wrapped.
    if (method != ExecutionGrpc.getExecuteMethod()
        && method != ExecutionGrpc.getWaitExecutionMethod()) {
      NetworkTime networkTime = networkTimeSupplier.get();
      if (networkTime != null) {
        call = new NetworkTimeCall<>(call, networkTime);
      }
    }
    return call;
  }

  /**
   * Forwarding call that starts {@link NetworkTime} on the first sent message and stops it when
   * the call closes.
   *
   * <p>{@code start()} and {@code stop()} are kept balanced. {@code stop()} is only issued if
   * {@code start()} was, so a call that closes before any message is sent (e.g. cancelled or past
   * its deadline) does not stop a timer it never started. {@code start()} is skipped once the call
   * has closed, so a late {@code sendMessage} cannot leave a timer running. Because {@code onClose}
   * may be delivered on a different thread than {@code sendMessage}, this state is guarded by
   * {@code lock}.
   */
  private static class NetworkTimeCall<ReqT, RespT>
      extends ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT> {
    private final NetworkTime networkTime;
    private final Object lock = new Object();

    // Guarded by lock: whether networkTime.start() has been called for this call.
    private boolean started = false;

    // Guarded by lock: whether onClose has been delivered for this call.
    private boolean closed = false;

    protected NetworkTimeCall(ClientCall<ReqT, RespT> delegate, NetworkTime networkTime) {
      super(delegate);
      this.networkTime = networkTime;
    }

    @Override
    public void start(Listener<RespT> responseListener, Metadata headers) {
      super.start(
          new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(
              responseListener) {
            @Override
            public void onClose(Status status, Metadata trailers) {
              try {
                // Only stop what was actually started; see the class comment.
                if (markClosed()) {
                  networkTime.stop();
                }
              } catch (RuntimeException e) {
                // An unchecked exception means we have bugs in the above try block, force crash
                // Bazel so we can have a chance to look into.
                throw new AssertionError(
                    "networkTime.stop() must not throw unchecked exception: " + networkTime, e);
              } finally {
                // Make sure to call super.onClose, otherwise gRPC will silently hang indefinitely.
                // See https://github.com/grpc/grpc-java/pull/6107.
                super.onClose(status, trailers);
              }
            }
          },
          headers);
    }

    @Override
    public void sendMessage(ReqT message) {
      synchronized (lock) {
        // start() is called under the lock so that onClose cannot observe started == true
        // before start() has actually run, which would let stop() precede start().
        if (!started && !closed) {
          networkTime.start();
          started = true;
        }
      }
      super.sendMessage(message);
    }

    /**
     * Records that the call has closed.
     *
     * @return {@code true} if {@code networkTime.start()} was called and its matching {@code
     *     stop()} has not been issued yet, i.e. the caller must now call {@code stop()}
     */
    private boolean markClosed() {
      synchronized (lock) {
        boolean needsStop = started && !closed;
        closed = true;
        return needsStop;
      }
    }
  }
}