Intercept capabilities and uploader requests and add custom grpc headers
Following #10015. Some requests do not use the custom headers.
Closes #10634.
PiperOrigin-RevId: 298574179
diff --git a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java
index d709cfe..52c81cd 100644
--- a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java
+++ b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java
@@ -54,7 +54,7 @@
String target,
String proxy,
AuthAndTLSOptions options,
- @Nullable ClientInterceptor interceptor)
+ @Nullable List<ClientInterceptor> interceptors)
throws IOException {
Preconditions.checkNotNull(target);
Preconditions.checkNotNull(options);
@@ -69,8 +69,8 @@
newNettyChannelBuilder(targetUrl, proxy)
.negotiationType(
isTlsEnabled(target) ? NegotiationType.TLS : NegotiationType.PLAINTEXT);
- if (interceptor != null) {
- builder.intercept(interceptor);
+ if (interceptors != null) {
+ builder.intercept(interceptors);
}
if (sslContext != null) {
builder.sslContext(sslContext);
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index 1750526..c53186a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -122,7 +122,6 @@
private ContentAddressableStorageFutureStub casFutureStub() {
return ContentAddressableStorageGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}
@@ -130,7 +129,6 @@
private ByteStreamStub bsAsyncStub() {
return ByteStreamGrpc.newStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}
@@ -138,7 +136,6 @@
private ActionCacheBlockingStub acBlockingStub() {
return ActionCacheGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}
@@ -146,7 +143,6 @@
private ActionCacheFutureStub acFutureStub() {
return ActionCacheGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
.withCallCredentials(credentials)
.withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
index 306643b..52fc07f 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
@@ -59,7 +59,6 @@
private ExecutionBlockingStub execBlockingStub() {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withInterceptors(TracingMetadataUtils.newExecHeadersInterceptor(options))
.withCallCredentials(callCredentials);
}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java
index ff00eb6..34b0bfe 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java
@@ -32,6 +32,7 @@
import io.netty.channel.unix.DomainSocketAddress;
import java.io.IOException;
import java.net.URI;
+import java.util.List;
import javax.annotation.Nullable;
/**
@@ -59,10 +60,10 @@
String target,
String proxyUri,
AuthAndTLSOptions authOptions,
- @Nullable ClientInterceptor interceptor)
+ @Nullable List<ClientInterceptor> interceptors)
throws IOException {
return new ReferenceCountedChannel(
- GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptor));
+ GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptors));
}
public static RemoteCacheClient create(
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
index 26e54f3..7a342f1 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
@@ -107,6 +107,37 @@
return !Strings.isNullOrEmpty(options.remoteExecutor);
}
+ private void verifyServerCapabilities(
+ RemoteOptions remoteOptions,
+ ReferenceCountedChannel channel,
+ CallCredentials credentials,
+ RemoteRetrier retrier,
+ CommandEnvironment env,
+ DigestUtil digestUtil)
+ throws AbruptExitException {
+ RemoteServerCapabilities rsc =
+ new RemoteServerCapabilities(
+ remoteOptions.remoteInstanceName,
+ channel,
+ credentials,
+ remoteOptions.remoteTimeout,
+ retrier);
+ ServerCapabilities capabilities = null;
+ try {
+ capabilities = rsc.get(env.getCommandId().toString(), env.getBuildRequestId());
+ } catch (IOException e) {
+ throw new AbruptExitException(
+ "Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e),
+ ExitCode.REMOTE_ERROR,
+ e);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ return;
+ }
+ checkClientServerCompatibility(
+ capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter());
+ }
+
@Override
public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
Preconditions.checkState(actionContextProvider == null, "actionContextProvider must be null");
@@ -178,12 +209,17 @@
ReferenceCountedChannel execChannel = null;
ReferenceCountedChannel cacheChannel = null;
if (enableRemoteExecution) {
+ ImmutableList.Builder<ClientInterceptor> interceptors = ImmutableList.builder();
+ interceptors.add(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions));
+ if (loggingInterceptor != null) {
+ interceptors.add(loggingInterceptor);
+ }
execChannel =
RemoteCacheClientFactory.createGrpcChannel(
remoteOptions.remoteExecutor,
remoteOptions.remoteProxy,
authAndTlsOptions,
- loggingInterceptor);
+ interceptors.build());
// Create a separate channel if --remote_executor and --remote_cache point to different
// endpoints.
if (Strings.isNullOrEmpty(remoteOptions.remoteCache)
@@ -193,12 +229,17 @@
}
if (cacheChannel == null) {
+ ImmutableList.Builder<ClientInterceptor> interceptors = ImmutableList.builder();
+ interceptors.add(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions));
+ if (loggingInterceptor != null) {
+ interceptors.add(loggingInterceptor);
+ }
cacheChannel =
RemoteCacheClientFactory.createGrpcChannel(
remoteOptions.remoteCache,
remoteOptions.remoteProxy,
authAndTlsOptions,
- loggingInterceptor);
+ interceptors.build());
}
CallCredentials credentials = GoogleAuthUtils.newCallCredentials(authAndTlsOptions);
@@ -212,27 +253,13 @@
// We always query the execution server for capabilities, if it is defined. A remote
// execution/cache system should have all its servers to return the capabilities pertaining
// to the system as a whole.
- RemoteServerCapabilities rsc =
- new RemoteServerCapabilities(
- remoteOptions.remoteInstanceName,
- (execChannel != null ? execChannel : cacheChannel),
- credentials,
- remoteOptions.remoteTimeout,
- retrier);
- ServerCapabilities capabilities = null;
- try {
- capabilities = rsc.get(buildRequestId, invocationId);
- } catch (IOException e) {
- throw new AbruptExitException(
- "Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e),
- ExitCode.REMOTE_ERROR,
- e);
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- return;
+ if (execChannel != null) {
+ verifyServerCapabilities(remoteOptions, execChannel, credentials, retrier, env, digestUtil);
}
- checkClientServerCompatibility(
- capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter());
+ if (cacheChannel != execChannel) {
+ verifyServerCapabilities(
+ remoteOptions, cacheChannel, credentials, retrier, env, digestUtil);
+ }
ByteStreamUploader uploader =
new ByteStreamUploader(
@@ -241,6 +268,7 @@
credentials,
remoteOptions.remoteTimeout,
retrier);
+
cacheChannel.release();
RemoteCacheClient cacheClient =
new GrpcCacheClient(
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
index 2b47d94..8592412 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -46,6 +46,7 @@
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
@@ -53,6 +54,7 @@
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.util.MutableHandlerRegistry;
import java.io.ByteArrayInputStream;
@@ -690,6 +692,74 @@
}
@Test
+ public void customHeadersAreAttachedToRequest() throws Exception {
+ RemoteRetrier retrier =
+ TestUtils.newRemoteRetrier(() -> new FixedBackoff(1, 0), (e) -> true, retryService);
+
+ Metadata metadata = new Metadata();
+ metadata.put(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER), "Value1");
+ metadata.put(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER), "Value2");
+
+ ByteStreamUploader uploader =
+ new ByteStreamUploader(
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(
+ InProcessChannelBuilder.forName("Server for " + this.getClass())
+ .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata))
+ .build()),
+ null, /* timeout seconds */
+ 60,
+ retrier);
+
+ byte[] blob = new byte[CHUNK_SIZE];
+ Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build();
+ HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+
+ serviceRegistry.addService(
+ ServerInterceptors.intercept(
+ new ByteStreamImplBase() {
+ @Override
+ public StreamObserver<WriteRequest> write(
+ StreamObserver<WriteResponse> streamObserver) {
+ return new StreamObserver<WriteRequest>() {
+ @Override
+ public void onNext(WriteRequest writeRequest) {}
+
+ @Override
+ public void onError(Throwable throwable) {
+ fail("onError should never be called.");
+ }
+
+ @Override
+ public void onCompleted() {
+ WriteResponse response =
+ WriteResponse.newBuilder().setCommittedSize(blob.length).build();
+ streamObserver.onNext(response);
+ streamObserver.onCompleted();
+ }
+ };
+ }
+ },
+ new ServerInterceptor() {
+ @Override
+ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+ ServerCall<ReqT, RespT> call,
+ Metadata metadata,
+ ServerCallHandler<ReqT, RespT> next) {
+ assertThat(metadata.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER)))
+ .isEqualTo("Value1");
+ assertThat(metadata.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER)))
+ .isEqualTo("Value2");
+ assertThat(metadata.get(Metadata.Key.of("Key3", Metadata.ASCII_STRING_MARSHALLER)))
+ .isEqualTo(null);
+ return next.startCall(call, metadata);
+ }
+ }));
+
+ uploader.uploadBlob(hash, chunker, true);
+ }
+
+ @Test
public void sameBlobShouldNotBeUploadedTwice() throws Exception {
// Test that uploading the same file concurrently triggers only one file upload.
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
index 9c551d8..240947d 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
@@ -226,6 +226,7 @@
InProcessChannelBuilder.forName(fakeServerName)
.directExecutor()
.intercept(new CallCredentialsInterceptor(creds))
+ .intercept(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions))
.build());
ByteStreamUploader uploader =
new ByteStreamUploader(
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
index 43fec08..fab371c 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
@@ -233,7 +233,10 @@
retryService);
ReferenceCountedChannel channel =
new ReferenceCountedChannel(
- InProcessChannelBuilder.forName(fakeServerName).directExecutor().build());
+ InProcessChannelBuilder.forName(fakeServerName)
+ .intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions))
+ .directExecutor()
+ .build());
GrpcRemoteExecutor executor =
new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions);
CallCredentials creds =
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java
index 2c47b44..357446c 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java
@@ -25,6 +25,8 @@
import build.bazel.remote.execution.v2.PriorityCapabilities.PriorityRange;
import build.bazel.remote.execution.v2.RequestMetadata;
import build.bazel.remote.execution.v2.ServerCapabilities;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
@@ -107,6 +109,60 @@
}
}
+ private static class RequestCustomHeadersValidator implements ServerInterceptor {
+ @Override
+ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+ ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
+ assertThat(headers.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER)))
+ .isEqualTo("Value1");
+ assertThat(headers.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER)))
+ .isEqualTo("Value2");
+ return next.startCall(call, headers);
+ }
+ }
+
+ @Test
+ public void testCustomHeadersAreAttached() throws Exception {
+ ServerCapabilities caps =
+ ServerCapabilities.newBuilder()
+ .setExecutionCapabilities(
+ ExecutionCapabilities.newBuilder().setExecEnabled(true).build())
+ .build();
+ serviceRegistry.addService(
+ ServerInterceptors.intercept(
+ new CapabilitiesImplBase() {
+ @Override
+ public void getCapabilities(
+ GetCapabilitiesRequest request,
+ StreamObserver<ServerCapabilities> responseObserver) {
+ responseObserver.onNext(caps);
+ responseObserver.onCompleted();
+ }
+ },
+ new RequestCustomHeadersValidator()));
+
+ RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class);
+ remoteOptions.remoteHeaders =
+ ImmutableList.of(
+ Maps.immutableEntry("Key1", "Value1"), Maps.immutableEntry("Key2", "Value2"));
+
+ RemoteRetrier retrier =
+ TestUtils.newRemoteRetrier(
+ () -> new ExponentialBackoff(remoteOptions),
+ RemoteRetrier.RETRIABLE_GRPC_ERRORS,
+ retryService);
+ ReferenceCountedChannel channel =
+ new ReferenceCountedChannel(
+ InProcessChannelBuilder.forName(fakeServerName)
+ .intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions))
+ .directExecutor()
+ .build());
+ RemoteServerCapabilities client =
+ new RemoteServerCapabilities("instance", channel.retain(), null, 3, retrier);
+
+ assertThat(client.get("build-req-id", "command-id")).isEqualTo(caps);
+ }
+
@Test
public void testGetCapabilitiesWithRetries() throws Exception {
ServerCapabilities caps =