Propagate response extensions from BlazeRuntime#afterCommand to the CommandService API.

PiperOrigin-RevId: 356763461
diff --git a/src/main/java/com/google/devtools/build/lib/BUILD b/src/main/java/com/google/devtools/build/lib/BUILD
index e4aa445..6b7405f 100644
--- a/src/main/java/com/google/devtools/build/lib/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/BUILD
@@ -215,6 +215,7 @@
         "//src/main/java/com/google/devtools/build/lib/util:crash_failure_details",
         "//src/main/java/com/google/devtools/build/lib/util:detailed_exit_code",
         "//src/main/java/com/google/devtools/build/lib/util:exit_code",
+        "//src/main/protobuf:any_java_proto",
         "//src/main/protobuf:command_server_java_proto",
         "//src/main/protobuf:failure_details_java_proto",
         "//third_party:guava",
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/BlazeCommandResult.java b/src/main/java/com/google/devtools/build/lib/runtime/BlazeCommandResult.java
index f4b7003..77191d7 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/BlazeCommandResult.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/BlazeCommandResult.java
@@ -16,12 +16,14 @@
 
 import com.google.common.base.MoreObjects;
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.bugreport.Crash;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.Immutable;
 import com.google.devtools.build.lib.server.CommandProtos.ExecRequest;
 import com.google.devtools.build.lib.server.FailureDetails.FailureDetail;
 import com.google.devtools.build.lib.util.DetailedExitCode;
 import com.google.devtools.build.lib.util.ExitCode;
+import com.google.protobuf.Any;
 import javax.annotation.Nullable;
 
 /**
@@ -34,13 +36,25 @@
   private final DetailedExitCode detailedExitCode;

   @Nullable private final ExecRequest execDescription;
+  /** Extensions to return to the client in the final response; never null, possibly empty. */
+  private final ImmutableList<Any> responseExtensions;
   private final boolean shutdown;

   private BlazeCommandResult(
-      DetailedExitCode detailedExitCode, @Nullable ExecRequest execDescription, boolean shutdown) {
+      DetailedExitCode detailedExitCode,
+      @Nullable ExecRequest execDescription,
+      boolean shutdown,
+      ImmutableList<Any> responseExtensions) {
     this.detailedExitCode = Preconditions.checkNotNull(detailedExitCode);
     this.execDescription = execDescription;
     this.shutdown = shutdown;
+    this.responseExtensions = Preconditions.checkNotNull(responseExtensions);
+  }
+
+  /** Creates a result that carries no response extensions. */
+  private BlazeCommandResult(
+      DetailedExitCode detailedExitCode, @Nullable ExecRequest execDescription, boolean shutdown) {
+    this(detailedExitCode, execDescription, shutdown, ImmutableList.of());
   }

   public ExitCode getExitCode() {
@@ -69,6 +81,14 @@
     return detailedExitCode.isSuccess();
   }

+  /**
+   * Returns the extensions to be passed back to the client in the command's final response.
+   * Empty unless this result was created by {@link #withResponseExtensions}.
+   */
+  public ImmutableList<Any> getResponseExtensions() {
+    return responseExtensions;
+  }
+
   @Override
   public String toString() {
     return MoreObjects.toStringHelper(this)
@@ -99,6 +115,18 @@
     return new BlazeCommandResult(detailedExitCode, null, false);
   }

+  /**
+   * Returns a copy of {@code result} that keeps its exit code, exec description, and shutdown
+   * flag, but carries {@code responseExtensions} in place of its own response extensions.
+   *
+   * <p>Any extensions already attached to {@code result} are discarded rather than merged.
+   */
+  public static BlazeCommandResult withResponseExtensions(
+      BlazeCommandResult result, ImmutableList<Any> responseExtensions) {
+    return new BlazeCommandResult(
+        result.detailedExitCode, result.execDescription, result.shutdown, responseExtensions);
+  }
+
   public static BlazeCommandResult execute(ExecRequest execDescription) {
     return new BlazeCommandResult(
         DetailedExitCode.success(), Preconditions.checkNotNull(execDescription), false);
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java b/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java
index 177b964..e8b6b8f 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java
@@ -681,7 +681,10 @@
     actionKeyContext.clear();
     DebugLoggerConfigurator.flushServerLog();
     storedExitCode.set(null);
-    return finalCommandResult;
+    // NOTE(review): this replaces, rather than merges with, any extensions already attached to
+    // finalCommandResult; confirm no command returns a result that carries its own extensions.
+    return BlazeCommandResult.withResponseExtensions(
+        finalCommandResult, env.getResponseExtensions());
   }

   /**
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/CommandEnvironment.java b/src/main/java/com/google/devtools/build/lib/runtime/CommandEnvironment.java
index 088c6d5..c87d85d 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/CommandEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/CommandEnvironment.java
@@ -101,6 +101,7 @@
   private final Duration waitTime;
   private final long commandStartTime;
   private final ImmutableList<Any> commandExtensions;
+  private final ImmutableList.Builder<Any> responseExtensions = ImmutableList.builder();
 
   private OutputService outputService;
   private TopDownActionCache topDownActionCache;
@@ -821,4 +822,25 @@
   public ImmutableList<Any> getCommandExtensions() {
     return commandExtensions;
   }
+
+  /**
+   * Returns the {@linkplain
+   * com.google.devtools.build.lib.server.CommandProtos.RunResponse#getCommandExtensions extensions}
+   * to be passed to the client for this command.
+   *
+   * <p>Extensions are arbitrary messages containing additional execution results.
+   *
+   * <p>Each call returns an immutable snapshot of the extensions added so far.
+   */
+  public ImmutableList<Any> getResponseExtensions() {
+    return responseExtensions.build();
+  }
+
+  /**
+   * Adds {@code extensions} to the list returned by {@link #getResponseExtensions}, after any
+   * extensions added earlier and in iteration order.
+   */
+  public void addResponseExtensions(Iterable<Any> extensions) {
+    responseExtensions.addAll(extensions);
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/server/GrpcServerImpl.java b/src/main/java/com/google/devtools/build/lib/server/GrpcServerImpl.java
index 7362d7c..5a8539e 100644
--- a/src/main/java/com/google/devtools/build/lib/server/GrpcServerImpl.java
+++ b/src/main/java/com/google/devtools/build/lib/server/GrpcServerImpl.java
@@ -589,7 +589,7 @@
     }
 
     try {
-      observer.onNext(response.build());
+      observer.onNext(response.addAllCommandExtensions(result.getResponseExtensions()).build());
       observer.onCompleted();
     } catch (StatusRuntimeException e) {
       logger.atInfo().withCause(e).log(
diff --git a/src/main/protobuf/BUILD b/src/main/protobuf/BUILD
index f64c1d4..0a178ee 100644
--- a/src/main/protobuf/BUILD
+++ b/src/main/protobuf/BUILD
@@ -94,6 +94,11 @@
     deps = ["@com_google_protobuf//:any_proto"],
 )
 
+java_proto_library(
+    name = "wrappers_java_proto",
+    deps = ["@com_google_protobuf//:wrappers_proto"],
+)
+
 proto_library(
     name = "command_server_proto",
     srcs = ["command_server.proto"],
diff --git a/src/test/java/com/google/devtools/build/lib/BUILD b/src/test/java/com/google/devtools/build/lib/BUILD
index 3a7a78e..b19e1d0 100644
--- a/src/test/java/com/google/devtools/build/lib/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/BUILD
@@ -434,10 +434,12 @@
         "//src/main/java/com/google/devtools/common/options",
         "//src/main/java/com/google/devtools/common/options:invocation_policy",
         "//src/main/java/net/starlark/java/syntax",
+        "//src/main/protobuf:any_java_proto",
         "//src/main/protobuf:command_line_java_proto",
         "//src/main/protobuf:failure_details_java_proto",
         "//src/main/protobuf:invocation_policy_java_proto",
         "//src/main/protobuf:test_status_java_proto",
+        "//src/main/protobuf:wrappers_java_proto",
         "//src/test/java/com/google/devtools/build/lib/actions/util",
         "//src/test/java/com/google/devtools/build/lib/events:testutil",
         "//src/test/java/com/google/devtools/build/lib/starlark/util",
@@ -454,6 +456,8 @@
         "//third_party:junit4",
         "//third_party:mockito",
         "//third_party:truth",
+        "//third_party/protobuf",
+        "//third_party/protobuf:protobuf_java",
     ],
 )
 
diff --git a/src/test/java/com/google/devtools/build/lib/runtime/BlazeRuntimeTest.java b/src/test/java/com/google/devtools/build/lib/runtime/BlazeRuntimeTest.java
index a92d70a..8e1040e 100644
--- a/src/test/java/com/google/devtools/build/lib/runtime/BlazeRuntimeTest.java
+++ b/src/test/java/com/google/devtools/build/lib/runtime/BlazeRuntimeTest.java
@@ -32,6 +32,10 @@
 import com.google.devtools.common.options.OptionsBase;
 import com.google.devtools.common.options.OptionsParser;
 import com.google.devtools.common.options.OptionsParsingResult;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.StringValue;
 import java.util.Arrays;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -120,6 +124,51 @@
   }

   @Test
+  public void resultExtensions() throws Exception {
+    FileSystem fs = new InMemoryFileSystem(DigestHashFunction.SHA256);
+    ServerDirectories serverDirectories =
+        new ServerDirectories(
+            fs.getPath("/install"), fs.getPath("/output"), fs.getPath("/output_user"));
+    BlazeRuntime runtime =
+        new BlazeRuntime.Builder()
+            .addBlazeModule(
+                new BlazeModule() {
+                  @Override
+                  public BuildOptions getDefaultBuildOptions(BlazeRuntime runtime) {
+                    return BuildOptions.builder().build();
+                  }
+                })
+            .setFileSystem(fs)
+            .setProductName("bazel")
+            .setServerDirectories(serverDirectories)
+            .setStartupOptionsProvider(Mockito.mock(OptionsParsingResult.class))
+            .build();
+    BlazeDirectories directories =
+        new BlazeDirectories(
+            serverDirectories, fs.getPath("/workspace"), fs.getPath("/system_javabase"), "blaze");
+    BlazeWorkspace workspace = runtime.initWorkspace(directories, BinTools.empty(directories));
+    CommandEnvironment env =
+        new CommandEnvironment(
+            runtime,
+            workspace,
+            Mockito.mock(EventBus.class),
+            Thread.currentThread(),
+            VersionCommand.class.getAnnotation(Command.class),
+            OptionsParser.builder().optionsClasses(COMMAND_ENV_REQUIRED_OPTIONS).build(),
+            ImmutableList.of(),
+            0L,
+            0L,
+            ImmutableList.of());
+    Any anyFoo = Any.pack(StringValue.of("foo"));
+    Any anyBar = Any.pack(BytesValue.of(ByteString.copyFromUtf8("bar")));
+    env.addResponseExtensions(ImmutableList.of(anyFoo, anyBar));
+    // Extensions added to the environment must reach the final result, in insertion order.
+    assertThat(runtime.afterCommand(env, BlazeCommandResult.success()).getResponseExtensions())
+        .containsExactly(anyFoo, anyBar)
+        .inOrder();
+  }
+
+  @Test
   public void addsCommandsFromModules() throws Exception {
     FileSystem fs = new InMemoryFileSystem(DigestHashFunction.SHA256);
     ServerDirectories serverDirectories =
diff --git a/src/test/java/com/google/devtools/build/lib/server/GrpcServerTest.java b/src/test/java/com/google/devtools/build/lib/server/GrpcServerTest.java
index 6b04700..d0b73de 100644
--- a/src/test/java/com/google/devtools/build/lib/server/GrpcServerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/server/GrpcServerTest.java
@@ -16,6 +16,7 @@
 import static com.google.common.truth.Truth.assertThat;
 import static org.junit.Assert.assertThrows;
 
+import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.clock.JavaClock;
 import com.google.devtools.build.lib.runtime.BlazeCommandResult;
 import com.google.devtools.build.lib.runtime.CommandDispatcher;
@@ -43,6 +44,8 @@
 import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
 import com.google.protobuf.Any;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.StringValue;
 import io.grpc.ManagedChannel;
 import io.grpc.Server;
 import io.grpc.inprocess.InProcessChannelBuilder;
@@ -236,7 +239,11 @@
             } catch (IOException e) {
               throw new IllegalStateException(e);
             }
-            return BlazeCommandResult.success();
+            return BlazeCommandResult.withResponseExtensions(
+                BlazeCommandResult.success(),
+                ImmutableList.of(
+                    Any.pack(StringValue.of("foo")),
+                    Any.pack(BytesValue.of(ByteString.copyFromUtf8("bar")))));
           }
         };
     createServer(dispatcher);
@@ -255,10 +262,17 @@
     for (int i = 1; i < 11; i++) {
       assertThat(responses.get(i).getFinished()).isFalse();
       assertThat(responses.get(i).getStandardOutput().toByteArray()).isEqualTo(new byte[1024]);
+      assertThat(responses.get(i).getCommandExtensionsList()).isEmpty();
     }
     assertThat(responses.get(11).getFinished()).isTrue();
     assertThat(responses.get(11).getExitCode()).isEqualTo(0);
     assertThat(responses.get(11).hasFailureDetail()).isFalse();
+    // Only the final response carries the extensions, in the order the command returned them.
+    assertThat(responses.get(11).getCommandExtensionsList())
+        .containsExactly(
+            Any.pack(StringValue.of("foo")),
+            Any.pack(BytesValue.of(ByteString.copyFromUtf8("bar"))))
+        .inOrder();
   }

   @Test