[bes] Attach `RequestMetadata` to RPC calls

This allows servers to trace the requests similarly to other RPCs.

Closes #16359.

PiperOrigin-RevId: 479292105
Change-Id: Ic6598175171577c6fce23a3bfd637b1b12b6a916
diff --git a/src/main/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModule.java b/src/main/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModule.java
index 67ecd82..ac876b0 100644
--- a/src/main/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModule.java
+++ b/src/main/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModule.java
@@ -102,7 +102,9 @@
           new BuildEventServiceGrpcClient(
               newGrpcChannel(config),
               credentials != null ? MoreCallCredentials.from(credentials) : null,
-              makeGrpcInterceptor(config));
+              makeGrpcInterceptor(config),
+              env.getBuildRequestId(),
+              env.getCommandId());
     }
     return client;
   }
diff --git a/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BUILD b/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BUILD
index 4235d80..58d0e49 100644
--- a/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BUILD
@@ -16,6 +16,7 @@
         "//third_party:netty_tcnative",
     ],
     deps = [
+        "//src/main/java/com/google/devtools/build/lib/remote/util",
         "//third_party:guava",
         "//third_party:jsr305",
         "//third_party/grpc-java:grpc-jar",
diff --git a/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BuildEventServiceGrpcClient.java b/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BuildEventServiceGrpcClient.java
index a71964f..f4293a3 100644
--- a/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BuildEventServiceGrpcClient.java
+++ b/src/main/java/com/google/devtools/build/lib/buildeventservice/client/BuildEventServiceGrpcClient.java
@@ -17,10 +17,12 @@
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
+import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
 import com.google.devtools.build.v1.PublishBuildEventGrpc;
 import com.google.devtools.build.v1.PublishBuildEventGrpc.PublishBuildEventBlockingStub;
 import com.google.devtools.build.v1.PublishBuildEventGrpc.PublishBuildEventStub;
@@ -36,6 +38,7 @@
 import io.grpc.stub.AbstractStub;
 import io.grpc.stub.StreamObserver;
 import java.time.Duration;
+import java.util.UUID;
 import javax.annotation.Nullable;
 
 /** Implementation of BuildEventServiceClient that uploads data using gRPC. */
@@ -48,15 +51,22 @@
   private final PublishBuildEventStub besAsync;
   private final PublishBuildEventBlockingStub besBlocking;
 
+  private final String buildRequestId;
+  private final UUID commandId;
+
   public BuildEventServiceGrpcClient(
       ManagedChannel channel,
       @Nullable CallCredentials callCredentials,
-      ClientInterceptor interceptor) {
+      ClientInterceptor interceptor,
+      String buildRequestId,
+      UUID commandId) {
     this.besAsync =
         configureStub(PublishBuildEventGrpc.newStub(channel), callCredentials, interceptor);
     this.besBlocking =
         configureStub(PublishBuildEventGrpc.newBlockingStub(channel), callCredentials, interceptor);
     this.channel = channel;
+    this.buildRequestId = Preconditions.checkNotNull(buildRequestId);
+    this.commandId = Preconditions.checkNotNull(commandId);
   }
 
   @VisibleForTesting
@@ -67,6 +77,8 @@
     this.besAsync = besAsync;
     this.besBlocking = besBlocking;
     this.channel = channel;
+    this.buildRequestId = "testing/" + UUID.randomUUID();
+    this.commandId = UUID.randomUUID();
   }
 
   private static <T extends AbstractStub<T>> T configureStub(
@@ -83,6 +95,13 @@
     try {
       besBlocking
           .withDeadlineAfter(RPC_TIMEOUT.toMillis(), MILLISECONDS)
+          .withInterceptors(
+              TracingMetadataUtils.attachMetadataInterceptor(
+                  TracingMetadataUtils.buildMetadata(
+                      buildRequestId,
+                      commandId.toString(),
+                      "publish_lifecycle_event",
+                      /* actionMetadata= */ null)))
           .publishLifecycleEvent(lifecycleEvent);
     } catch (StatusRuntimeException e) {
       Throwables.throwIfInstanceOf(Throwables.getRootCause(e), InterruptedException.class);
@@ -94,36 +113,49 @@
     private final StreamObserver<PublishBuildToolEventStreamRequest> stream;
     private final SettableFuture<Status> streamStatus;
 
-    public BESGrpcStreamContext(PublishBuildEventStub besAsync, AckCallback ackCallback) {
+    public BESGrpcStreamContext(
+        PublishBuildEventStub besAsync,
+        AckCallback ackCallback,
+        String buildRequestId,
+        UUID commandId) {
       this.streamStatus = SettableFuture.create();
       this.stream =
-          besAsync.publishBuildToolEventStream(
-              new StreamObserver<PublishBuildToolEventStreamResponse>() {
-                @Override
-                public void onNext(PublishBuildToolEventStreamResponse response) {
-                  ackCallback.apply(response);
-                }
+          besAsync
+              .withInterceptors(
+                  TracingMetadataUtils.attachMetadataInterceptor(
+                      TracingMetadataUtils.buildMetadata(
+                          buildRequestId,
+                          commandId.toString(),
+                          "publish_build_tool_event_stream",
+                          /* actionMetadata= */ null)))
+              .publishBuildToolEventStream(
+                  new StreamObserver<PublishBuildToolEventStreamResponse>() {
+                    @Override
+                    public void onNext(PublishBuildToolEventStreamResponse response) {
+                      ackCallback.apply(response);
+                    }
 
-                @Override
-                public void onError(Throwable t) {
-                  Status error = Status.fromThrowable(t);
-                  if (error.getCode() == Status.CANCELLED.getCode()
-                      && error.getCause() != null
-                      && Status.fromThrowable(error.getCause()).getCode()
-                          != Status.UNKNOWN.getCode()) {
-                    // gRPC likes to wrap Status(Runtime)Exceptions in StatusRuntimeExceptions. If
-                    // the status is cancelled and has a Status(Runtime)Exception as a cause it
-                    // means the error was generated client side.
-                    error = Status.fromThrowable(error.getCause());
-                  }
-                  streamStatus.set(error);
-                }
+                    @Override
+                    public void onError(Throwable t) {
+                      Status error = Status.fromThrowable(t);
+                      if (error.getCode() == Status.CANCELLED.getCode()
+                          && error.getCause() != null
+                          && Status.fromThrowable(error.getCause()).getCode()
+                              != Status.UNKNOWN.getCode()) {
+                        // gRPC likes to wrap Status(Runtime)Exceptions in StatusRuntimeExceptions.
+                        // If
+                        // the status is cancelled and has a Status(Runtime)Exception as a cause it
+                        // means the error was generated client side.
+                        error = Status.fromThrowable(error.getCause());
+                      }
+                      streamStatus.set(error);
+                    }
 
-                @Override
-                public void onCompleted() {
-                  streamStatus.set(Status.OK);
-                }
-              });
+                    @Override
+                    public void onCompleted() {
+                      streamStatus.set(Status.OK);
+                    }
+                  });
     }
 
     @Override
@@ -157,7 +189,7 @@
   @Override
   public StreamContext openStream(AckCallback ackCallback) throws InterruptedException {
     try {
-      return new BESGrpcStreamContext(besAsync, ackCallback);
+      return new BESGrpcStreamContext(besAsync, ackCallback, buildRequestId, commandId);
     } catch (StatusRuntimeException e) {
       Throwables.throwIfInstanceOf(Throwables.getRootCause(e), InterruptedException.class);
       ListenableFuture<Status> status = Futures.immediateFuture(Status.fromThrowable(e));
diff --git a/src/test/java/com/google/devtools/build/lib/buildeventservice/BUILD b/src/test/java/com/google/devtools/build/lib/buildeventservice/BUILD
index ab183f2..793238e 100644
--- a/src/test/java/com/google/devtools/build/lib/buildeventservice/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/buildeventservice/BUILD
@@ -56,6 +56,7 @@
         "//src/main/java/com/google/devtools/build/lib/buildeventstream/transports",
         "//src/main/java/com/google/devtools/build/lib/network:connectivity_status",
         "//src/main/java/com/google/devtools/build/lib/network:noop_connectivity",
+        "//src/main/java/com/google/devtools/build/lib/remote/util",
         "//src/main/java/com/google/devtools/build/lib/util:abrupt_exit_exception",
         "//src/main/java/com/google/devtools/build/lib/util:exit_code",
         "//src/test/java/com/google/devtools/build/lib/analysis/util",
@@ -71,5 +72,6 @@
         "@googleapis//:google_devtools_build_v1_build_events_java_proto",
         "@googleapis//:google_devtools_build_v1_publish_build_event_java_grpc",
         "@googleapis//:google_devtools_build_v1_publish_build_event_java_proto",
+        "@remoteapis//:build_bazel_remote_execution_v2_remote_execution_java_proto",
     ],
 )
diff --git a/src/test/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModuleTest.java b/src/test/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModuleTest.java
index ecd6483..8063ae0 100644
--- a/src/test/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModuleTest.java
+++ b/src/test/java/com/google/devtools/build/lib/buildeventservice/BazelBuildEventServiceModuleTest.java
@@ -20,6 +20,7 @@
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assume.assumeFalse;
 
+import build.bazel.remote.execution.v2.RequestMetadata;
 import com.google.common.base.MoreObjects;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
@@ -56,6 +57,7 @@
 import com.google.devtools.build.lib.network.ConnectivityStatus;
 import com.google.devtools.build.lib.network.ConnectivityStatusProvider;
 import com.google.devtools.build.lib.network.NoOpConnectivityModule;
+import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
 import com.google.devtools.build.lib.runtime.BlazeModule;
 import com.google.devtools.build.lib.runtime.BlazeRuntime;
 import com.google.devtools.build.lib.runtime.CommandEnvironment;
@@ -74,6 +76,7 @@
 import io.grpc.ManagedChannel;
 import io.grpc.Metadata;
 import io.grpc.Server;
+import io.grpc.ServerInterceptors;
 import io.grpc.Status;
 import io.grpc.StatusRuntimeException;
 import io.grpc.inprocess.InProcessChannelBuilder;
@@ -180,7 +183,9 @@
 
   @Before
   public void setUp() throws Exception {
-    serviceRegistry.addService(buildEventService);
+    serviceRegistry.addService(
+        ServerInterceptors.intercept(
+            buildEventService, new TracingMetadataUtils.ServerHeadersInterceptor()));
     fakeServer =
         InProcessServerBuilder.forName(fakeServerName)
             .fallbackHandlerRegistry(serviceRegistry)
@@ -960,6 +965,11 @@
     @Override
     public void publishLifecycleEvent(
         PublishLifecycleEventRequest request, StreamObserver<Empty> responseObserver) {
+      RequestMetadata metadata = TracingMetadataUtils.fromCurrentContext();
+      assertThat(metadata.getToolInvocationId()).isNotEmpty();
+      assertThat(metadata.getCorrelatedInvocationsId()).isNotEmpty();
+      assertThat(metadata.getActionId()).isEqualTo("publish_lifecycle_event");
+
       responseObserver.onNext(Empty.getDefaultInstance());
       responseObserver.onCompleted();
     }
@@ -968,6 +978,11 @@
     public synchronized StreamObserver<PublishBuildToolEventStreamRequest>
         publishBuildToolEventStream(
             StreamObserver<PublishBuildToolEventStreamResponse> responseObserver) {
+      RequestMetadata metadata = TracingMetadataUtils.fromCurrentContext();
+      assertThat(metadata.getToolInvocationId()).isNotEmpty();
+      assertThat(metadata.getCorrelatedInvocationsId()).isNotEmpty();
+      assertThat(metadata.getActionId()).isEqualTo("publish_build_tool_event_stream");
+
       if (errorMessage != null) {
         return new ErroringPublishBuildStreamObserver(responseObserver, errorMessage);
       }
diff --git a/src/test/java/com/google/devtools/build/lib/buildeventservice/BuildEventServiceGrpcClientTest.java b/src/test/java/com/google/devtools/build/lib/buildeventservice/BuildEventServiceGrpcClientTest.java
index 80b8d03..12820d5 100644
--- a/src/test/java/com/google/devtools/build/lib/buildeventservice/BuildEventServiceGrpcClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/buildeventservice/BuildEventServiceGrpcClientTest.java
@@ -120,7 +120,12 @@
       extraHeaders.put(Metadata.Key.of("metadata-foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
       ClientInterceptor interceptor = MetadataUtils.newAttachHeadersInterceptor(extraHeaders);
       BuildEventServiceGrpcClient grpcClient =
-          new BuildEventServiceGrpcClient(server.getChannel(), null, interceptor);
+          new BuildEventServiceGrpcClient(
+              server.getChannel(),
+              null,
+              interceptor,
+              "testing/" + UUID.randomUUID(),
+              UUID.randomUUID());
       assertThat(grpcClient.openStream(ack -> {}).getStatus().get()).isEqualTo(Status.OK);
       assertThat(seenHeaders).hasSize(1);
       Metadata headers = seenHeaders.get(0);
@@ -133,7 +138,12 @@
   public void immediateSuccess() throws Exception {
     try (TestServer server = startTestServer(NOOP_SERVER.bindService())) {
       assertThat(
-              new BuildEventServiceGrpcClient(server.getChannel(), null, null)
+              new BuildEventServiceGrpcClient(
+                      server.getChannel(),
+                      null,
+                      null,
+                      "testing/" + UUID.randomUUID(),
+                      UUID.randomUUID())
                   .openStream(ack -> {})
                   .getStatus()
                   .get())
@@ -154,7 +164,12 @@
               }
             }.bindService())) {
       assertThat(
-              new BuildEventServiceGrpcClient(server.getChannel(), null, null)
+              new BuildEventServiceGrpcClient(
+                      server.getChannel(),
+                      null,
+                      null,
+                      "testing/" + UUID.randomUUID(),
+                      UUID.randomUUID())
                   .openStream(ack -> {})
                   .getStatus()
                   .get())