Intercept capabilities and uploader requests and add custom grpc headers

Following #10015. Some requests do not use the custom headers.

Closes #10634.

PiperOrigin-RevId: 298574179
diff --git a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java
index d709cfe..52c81cd 100644
--- a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java
+++ b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java
@@ -54,7 +54,7 @@
       String target,
       String proxy,
       AuthAndTLSOptions options,
-      @Nullable ClientInterceptor interceptor)
+      @Nullable List<ClientInterceptor> interceptors)
       throws IOException {
     Preconditions.checkNotNull(target);
     Preconditions.checkNotNull(options);
@@ -69,8 +69,8 @@
           newNettyChannelBuilder(targetUrl, proxy)
               .negotiationType(
                   isTlsEnabled(target) ? NegotiationType.TLS : NegotiationType.PLAINTEXT);
-      if (interceptor != null) {
-        builder.intercept(interceptor);
+      if (interceptors != null) {
+        builder.intercept(interceptors);
       }
       if (sslContext != null) {
         builder.sslContext(sslContext);
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index 1750526..c53186a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -122,7 +122,6 @@
   private ContentAddressableStorageFutureStub casFutureStub() {
     return ContentAddressableStorageGrpc.newFutureStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
         .withCallCredentials(credentials)
         .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
   }
@@ -130,7 +129,6 @@
   private ByteStreamStub bsAsyncStub() {
     return ByteStreamGrpc.newStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
         .withCallCredentials(credentials)
         .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
   }
@@ -138,7 +136,6 @@
   private ActionCacheBlockingStub acBlockingStub() {
     return ActionCacheGrpc.newBlockingStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
         .withCallCredentials(credentials)
         .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
   }
@@ -146,7 +143,6 @@
   private ActionCacheFutureStub acFutureStub() {
     return ActionCacheGrpc.newFutureStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withInterceptors(TracingMetadataUtils.newCacheHeadersInterceptor(options))
         .withCallCredentials(credentials)
         .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
   }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
index 306643b..52fc07f 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
@@ -59,7 +59,6 @@
   private ExecutionBlockingStub execBlockingStub() {
     return ExecutionGrpc.newBlockingStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withInterceptors(TracingMetadataUtils.newExecHeadersInterceptor(options))
         .withCallCredentials(callCredentials);
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java
index ff00eb6..34b0bfe 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCacheClientFactory.java
@@ -32,6 +32,7 @@
 import io.netty.channel.unix.DomainSocketAddress;
 import java.io.IOException;
 import java.net.URI;
+import java.util.List;
 import javax.annotation.Nullable;
 
 /**
@@ -59,10 +60,10 @@
       String target,
       String proxyUri,
       AuthAndTLSOptions authOptions,
-      @Nullable ClientInterceptor interceptor)
+      @Nullable List<ClientInterceptor> interceptors)
       throws IOException {
     return new ReferenceCountedChannel(
-        GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptor));
+        GoogleAuthUtils.newChannel(target, proxyUri, authOptions, interceptors));
   }
 
   public static RemoteCacheClient create(
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
index 26e54f3..7a342f1 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
@@ -107,6 +107,37 @@
     return !Strings.isNullOrEmpty(options.remoteExecutor);
   }
 
+  private void verifyServerCapabilities(
+      RemoteOptions remoteOptions,
+      ReferenceCountedChannel channel,
+      CallCredentials credentials,
+      RemoteRetrier retrier,
+      CommandEnvironment env,
+      DigestUtil digestUtil)
+      throws AbruptExitException {
+    RemoteServerCapabilities rsc =
+        new RemoteServerCapabilities(
+            remoteOptions.remoteInstanceName,
+            channel,
+            credentials,
+            remoteOptions.remoteTimeout,
+            retrier);
+    ServerCapabilities capabilities = null;
+    try {
+      capabilities = rsc.get(env.getCommandId().toString(), env.getBuildRequestId());
+    } catch (IOException e) {
+      throw new AbruptExitException(
+          "Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e),
+          ExitCode.REMOTE_ERROR,
+          e);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      return;
+    }
+    checkClientServerCompatibility(
+        capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter());
+  }
+
   @Override
   public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
     Preconditions.checkState(actionContextProvider == null, "actionContextProvider must be null");
@@ -178,12 +209,17 @@
       ReferenceCountedChannel execChannel = null;
       ReferenceCountedChannel cacheChannel = null;
       if (enableRemoteExecution) {
+        ImmutableList.Builder<ClientInterceptor> interceptors = ImmutableList.builder();
+        interceptors.add(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions));
+        if (loggingInterceptor != null) {
+          interceptors.add(loggingInterceptor);
+        }
         execChannel =
             RemoteCacheClientFactory.createGrpcChannel(
                 remoteOptions.remoteExecutor,
                 remoteOptions.remoteProxy,
                 authAndTlsOptions,
-                loggingInterceptor);
+                interceptors.build());
         // Create a separate channel if --remote_executor and --remote_cache point to different
         // endpoints.
         if (Strings.isNullOrEmpty(remoteOptions.remoteCache)
@@ -193,12 +229,17 @@
       }
 
       if (cacheChannel == null) {
+        ImmutableList.Builder<ClientInterceptor> interceptors = ImmutableList.builder();
+        interceptors.add(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions));
+        if (loggingInterceptor != null) {
+          interceptors.add(loggingInterceptor);
+        }
         cacheChannel =
             RemoteCacheClientFactory.createGrpcChannel(
                 remoteOptions.remoteCache,
                 remoteOptions.remoteProxy,
                 authAndTlsOptions,
-                loggingInterceptor);
+                interceptors.build());
       }
 
       CallCredentials credentials = GoogleAuthUtils.newCallCredentials(authAndTlsOptions);
@@ -212,27 +253,13 @@
       // We always query the execution server for capabilities, if it is defined. A remote
       // execution/cache system should have all its servers to return the capabilities pertaining
       // to the system as a whole.
-      RemoteServerCapabilities rsc =
-          new RemoteServerCapabilities(
-              remoteOptions.remoteInstanceName,
-              (execChannel != null ? execChannel : cacheChannel),
-              credentials,
-              remoteOptions.remoteTimeout,
-              retrier);
-      ServerCapabilities capabilities = null;
-      try {
-        capabilities = rsc.get(buildRequestId, invocationId);
-      } catch (IOException e) {
-        throw new AbruptExitException(
-            "Failed to query remote execution capabilities: " + Utils.grpcAwareErrorMessage(e),
-            ExitCode.REMOTE_ERROR,
-            e);
-      } catch (InterruptedException e) {
-        Thread.currentThread().interrupt();
-        return;
+      if (execChannel != null) {
+        verifyServerCapabilities(remoteOptions, execChannel, credentials, retrier, env, digestUtil);
       }
-      checkClientServerCompatibility(
-          capabilities, remoteOptions, digestUtil.getDigestFunction(), env.getReporter());
+      if (cacheChannel != execChannel) {
+        verifyServerCapabilities(
+            remoteOptions, cacheChannel, credentials, retrier, env, digestUtil);
+      }
 
       ByteStreamUploader uploader =
           new ByteStreamUploader(
@@ -241,6 +268,7 @@
               credentials,
               remoteOptions.remoteTimeout,
               retrier);
+
       cacheChannel.release();
       RemoteCacheClient cacheClient =
           new GrpcCacheClient(
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
index 2b47d94..8592412 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -46,6 +46,7 @@
 import io.grpc.ServerCall;
 import io.grpc.ServerCall.Listener;
 import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
 import io.grpc.ServerInterceptors;
 import io.grpc.ServerServiceDefinition;
 import io.grpc.Status;
@@ -53,6 +54,7 @@
 import io.grpc.StatusRuntimeException;
 import io.grpc.inprocess.InProcessChannelBuilder;
 import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.stub.MetadataUtils;
 import io.grpc.stub.StreamObserver;
 import io.grpc.util.MutableHandlerRegistry;
 import java.io.ByteArrayInputStream;
@@ -690,6 +692,74 @@
   }
 
   @Test
+  public void customHeadersAreAttachedToRequest() throws Exception {
+    RemoteRetrier retrier =
+        TestUtils.newRemoteRetrier(() -> new FixedBackoff(1, 0), (e) -> true, retryService);
+
+    Metadata metadata = new Metadata();
+    metadata.put(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER), "Value1");
+    metadata.put(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER), "Value2");
+
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(
+                InProcessChannelBuilder.forName("Server for " + this.getClass())
+                    .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata))
+                    .build()),
+            null, /* timeout seconds */
+            60,
+            retrier);
+
+    byte[] blob = new byte[CHUNK_SIZE];
+    Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build();
+    HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+
+    serviceRegistry.addService(
+        ServerInterceptors.intercept(
+            new ByteStreamImplBase() {
+              @Override
+              public StreamObserver<WriteRequest> write(
+                  StreamObserver<WriteResponse> streamObserver) {
+                return new StreamObserver<WriteRequest>() {
+                  @Override
+                  public void onNext(WriteRequest writeRequest) {}
+
+                  @Override
+                  public void onError(Throwable throwable) {
+                    fail("onError should never be called.");
+                  }
+
+                  @Override
+                  public void onCompleted() {
+                    WriteResponse response =
+                        WriteResponse.newBuilder().setCommittedSize(blob.length).build();
+                    streamObserver.onNext(response);
+                    streamObserver.onCompleted();
+                  }
+                };
+              }
+            },
+            new ServerInterceptor() {
+              @Override
+              public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+                  ServerCall<ReqT, RespT> call,
+                  Metadata metadata,
+                  ServerCallHandler<ReqT, RespT> next) {
+                assertThat(metadata.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER)))
+                    .isEqualTo("Value1");
+                assertThat(metadata.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER)))
+                    .isEqualTo("Value2");
+                assertThat(metadata.get(Metadata.Key.of("Key3", Metadata.ASCII_STRING_MARSHALLER)))
+                    .isEqualTo(null);
+                return next.startCall(call, metadata);
+              }
+            }));
+
+    uploader.uploadBlob(hash, chunker, true);
+  }
+
+  @Test
   public void sameBlobShouldNotBeUploadedTwice() throws Exception {
     // Test that uploading the same file concurrently triggers only one file upload.
 
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
index 9c551d8..240947d 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
@@ -226,6 +226,7 @@
             InProcessChannelBuilder.forName(fakeServerName)
                 .directExecutor()
                 .intercept(new CallCredentialsInterceptor(creds))
+                .intercept(TracingMetadataUtils.newCacheHeadersInterceptor(remoteOptions))
                 .build());
     ByteStreamUploader uploader =
         new ByteStreamUploader(
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
index 43fec08..fab371c 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
@@ -233,7 +233,10 @@
             retryService);
     ReferenceCountedChannel channel =
         new ReferenceCountedChannel(
-            InProcessChannelBuilder.forName(fakeServerName).directExecutor().build());
+            InProcessChannelBuilder.forName(fakeServerName)
+                .intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions))
+                .directExecutor()
+                .build());
     GrpcRemoteExecutor executor =
         new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions);
     CallCredentials creds =
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java
index 2c47b44..357446c 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteServerCapabilitiesTest.java
@@ -25,6 +25,8 @@
 import build.bazel.remote.execution.v2.PriorityCapabilities.PriorityRange;
 import build.bazel.remote.execution.v2.RequestMetadata;
 import build.bazel.remote.execution.v2.ServerCapabilities;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
 import com.google.common.util.concurrent.ListeningScheduledExecutorService;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
@@ -107,6 +109,60 @@
     }
   }
 
+  private static class RequestCustomHeadersValidator implements ServerInterceptor {
+    @Override
+    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+        ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
+      assertThat(headers.get(Metadata.Key.of("Key1", Metadata.ASCII_STRING_MARSHALLER)))
+          .isEqualTo("Value1");
+      assertThat(headers.get(Metadata.Key.of("Key2", Metadata.ASCII_STRING_MARSHALLER)))
+          .isEqualTo("Value2");
+      return next.startCall(call, headers);
+    }
+  }
+
+  @Test
+  public void testCustomHeadersAreAttached() throws Exception {
+    ServerCapabilities caps =
+        ServerCapabilities.newBuilder()
+            .setExecutionCapabilities(
+                ExecutionCapabilities.newBuilder().setExecEnabled(true).build())
+            .build();
+    serviceRegistry.addService(
+        ServerInterceptors.intercept(
+            new CapabilitiesImplBase() {
+              @Override
+              public void getCapabilities(
+                  GetCapabilitiesRequest request,
+                  StreamObserver<ServerCapabilities> responseObserver) {
+                responseObserver.onNext(caps);
+                responseObserver.onCompleted();
+              }
+            },
+            new RequestCustomHeadersValidator()));
+
+    RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class);
+    remoteOptions.remoteHeaders =
+        ImmutableList.of(
+            Maps.immutableEntry("Key1", "Value1"), Maps.immutableEntry("Key2", "Value2"));
+
+    RemoteRetrier retrier =
+        TestUtils.newRemoteRetrier(
+            () -> new ExponentialBackoff(remoteOptions),
+            RemoteRetrier.RETRIABLE_GRPC_ERRORS,
+            retryService);
+    ReferenceCountedChannel channel =
+        new ReferenceCountedChannel(
+            InProcessChannelBuilder.forName(fakeServerName)
+                .intercept(TracingMetadataUtils.newExecHeadersInterceptor(remoteOptions))
+                .directExecutor()
+                .build());
+    RemoteServerCapabilities client =
+        new RemoteServerCapabilities("instance", channel.retain(), null, 3, retrier);
+
+    assertThat(client.get("build-req-id", "command-id")).isEqualTo(caps);
+  }
+
   @Test
   public void testGetCapabilitiesWithRetries() throws Exception {
     ServerCapabilities caps =