[7.6.1] Don't leak in-flight execution in `RemoteSpawnCache` (#25688)

In-flight executions must also be cleaned up on a cache hit or when an
exception occurs while interacting with the cache; otherwise,
`inFlightExecutions` grows over time and retains large `Spawn`s.

Closes #25668.

PiperOrigin-RevId: 740058969
Change-Id: I42045219510df0347e38234d6d3bd1ebb6167d9b 
(cherry picked from commit 0f957c185f671cf9d088cf551818f0442a80c6f9)
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java
index bfc4342..7f0a48c 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java
@@ -1455,9 +1455,11 @@
   }
 
   /** An ongoing local execution of a spawn. */
-  public static final class LocalExecution {
+  public static final class LocalExecution implements SilentCloseable {
     private final RemoteAction action;
     private final SettableFuture<SpawnResult> spawnResultFuture;
+    private final Runnable onClose;
+    private final AtomicBoolean closeManually = new AtomicBoolean(false);
     private final Phaser spawnResultConsumers =
         new Phaser(1) {
           @Override
@@ -1467,9 +1469,10 @@
           }
         };
 
-    private LocalExecution(RemoteAction action) {
+    private LocalExecution(RemoteAction action, Runnable onClose) {
       this.action = action;
       this.spawnResultFuture = SettableFuture.create();
+      this.onClose = onClose;
     }
 
     /**
@@ -1482,11 +1485,11 @@
      * builds and clients.
      */
     @Nullable
-    public static LocalExecution createIfDeduplicatable(RemoteAction action) {
+    public static LocalExecution createIfDeduplicatable(RemoteAction action, Runnable onClose) {
       if (action.getSpawn().getPathMapper().isNoop()) {
         return null;
       }
-      return new LocalExecution(action);
+      return new LocalExecution(action, onClose);
     }
 
     /**
@@ -1522,10 +1525,30 @@
 
     /**
      * Signals to all potential consumers of the {@link #spawnResultFuture} that this execution has
-     * been cancelled and that the result will not be available.
+     * finished or been canceled and that the result will no longer be available.
      */
-    public void cancel() {
+    @Override
+    public void close() {
+      if (!closeManually.get()) {
+        doClose();
+      }
+    }
+
+    /**
+     * Returns a {@link Runnable} that will close this {@link LocalExecution} instance when called.
+     * After this method is called, the {@link LocalExecution} instance will not be closed by the
+     * {@link #close()} method.
+     */
+    public Runnable delayClose() {
+      if (!closeManually.compareAndSet(false, true)) {
+        throw new IllegalStateException("delayClose has already been called");
+      }
+      return this::doClose;
+    }
+
+    private void doClose() {
       spawnResultFuture.cancel(true);
+      onClose.run();
     }
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnCache.java
index ae802b9..b259756 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteSpawnCache.java
@@ -78,6 +78,11 @@
     return remoteExecutionService;
   }
 
+  @VisibleForTesting
+  int getInFlightExecutionsSize() {
+    return inFlightExecutions.size();
+  }
+
   @Override
   public CacheHandle lookup(Spawn spawn, SpawnExecutionContext context)
       throws InterruptedException, IOException, ExecException, ForbiddenActionInputException {
@@ -110,7 +115,9 @@
       // first one.
       LocalExecution previousExecution = null;
       try {
-        thisExecution = LocalExecution.createIfDeduplicatable(action);
+        thisExecution =
+            LocalExecution.createIfDeduplicatable(
+                action, () -> inFlightExecutions.remove(action.getActionKey()));
         if (shouldUploadLocalResults && thisExecution != null) {
           LocalExecution previousOrThisExecution =
               inFlightExecutions.merge(
@@ -126,8 +133,12 @@
                       return thisExecutionArg;
                     }
                   });
-          previousExecution =
-              previousOrThisExecution == thisExecution ? null : previousOrThisExecution;
+          if (previousOrThisExecution != thisExecution) {
+            // The current execution is not the first one to be registered for this action key, so
+            // we need to wait for the previous one to finish before we can reuse its result.
+            previousExecution = previousOrThisExecution;
+            thisExecution = null;
+          }
         }
         try {
           RemoteActionResult result;
@@ -138,6 +149,9 @@
           // In case the remote cache returned a failed action (exit code != 0) we treat it as a
           // cache miss
           if (result != null && result.getExitCode() == 0) {
+            if (thisExecution != null) {
+              thisExecution.close();
+            }
             Stopwatch fetchTime = Stopwatch.createStarted();
             InMemoryOutput inMemoryOutput;
             try (SilentCloseable c = prof.profile(REMOTE_DOWNLOAD, "download outputs")) {
@@ -259,7 +273,9 @@
           // indefinitely is important to avoid excessive memory pressure - Spawns can be very
           // large.
           remoteExecutionService.uploadOutputs(
-              action, result, () -> inFlightExecutions.remove(action.getActionKey()));
+              action,
+              result,
+              thisExecutionFinal != null ? thisExecutionFinal.delayClose() : () -> {});
           if (thisExecutionFinal != null
               && action.getSpawn().getResourceOwner().mayModifySpawnOutputsAfterExecution()) {
             // In this case outputs have been uploaded synchronously and the callback above has run,
@@ -290,7 +306,7 @@
         @Override
         public void close() {
           if (thisExecutionFinal != null) {
-            thisExecutionFinal.cancel();
+            thisExecutionFinal.close();
           }
         }
       };
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnCacheTest.java
index 1b1908a..63ec01c 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnCacheTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnCacheTest.java
@@ -772,7 +772,14 @@
     // Simulate a very slow upload to the remote cache to ensure that the second spawn is
     // deduplicated rather than a cache hit. This is a slight hack, but also avoid introducing
     // concurrency to this test.
-    Mockito.doNothing().when(remoteExecutionService).uploadOutputs(any(), any(), any());
+    AtomicReference<Runnable> onUploadComplete = new AtomicReference<>();
+    Mockito.doAnswer(
+            invocationOnMock -> {
+              onUploadComplete.set(invocationOnMock.getArgument(2));
+              return null;
+            })
+        .when(remoteExecutionService)
+        .uploadOutputs(any(), any(), any());
 
     // act
     try (CacheHandle firstCacheHandle = cache.lookup(firstSpawn, firstPolicy)) {
@@ -795,6 +802,8 @@
                 fs.getPath("/exec/root/bazel-bin/k8-opt/bin/output"), UTF_8))
         .isEqualTo("hello");
     assertThat(secondCacheHandle.willStore()).isFalse();
+    onUploadComplete.get().run();
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
   }
 
   @Test
@@ -840,10 +849,12 @@
     // Simulate a very slow upload to the remote cache to ensure that the second spawn is
     // deduplicated rather than a cache hit. This is a slight hack, but also avoids introducing
     // more concurrency to this test.
+    AtomicReference<Runnable> onUploadComplete = new AtomicReference<>();
     Mockito.doAnswer(
             (Answer<Void>)
                 invocation -> {
                   enteredUploadOutputs.countDown();
+                  onUploadComplete.set(invocation.getArgument(2));
                   return null;
                 })
         .when(remoteExecutionService)
@@ -910,6 +921,8 @@
                 fs.getPath("/exec/root/bazel-bin/k8-opt/bin/output"), UTF_8))
         .isEqualTo("hello");
     assertThat(secondCacheHandle.willStore()).isFalse();
+    onUploadComplete.get().run();
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
   }
 
   @Test
@@ -934,7 +947,14 @@
     // Simulate a very slow upload to the remote cache to ensure that the second spawn is
     // deduplicated rather than a cache hit. This is a slight hack, but also avoid introducing
     // concurrency to this test.
-    Mockito.doNothing().when(remoteExecutionService).uploadOutputs(any(), any(), any());
+    AtomicReference<Runnable> onUploadComplete = new AtomicReference<>();
+    Mockito.doAnswer(
+            invocationOnMock -> {
+              onUploadComplete.set(invocationOnMock.getArgument(2));
+              return null;
+            })
+        .when(remoteExecutionService)
+        .uploadOutputs(any(), any(), any());
 
     // act
     try (CacheHandle firstCacheHandle = cache.lookup(firstSpawn, firstPolicy)) {
@@ -957,6 +977,8 @@
         .isEqualTo("in-memory");
     assertThat(execRoot.getRelative(inMemoryOutput.getExecPath()).exists()).isFalse();
     assertThat(secondCacheHandle.willStore()).isFalse();
+    onUploadComplete.get().run();
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
   }
 
   @Test
@@ -978,10 +1000,6 @@
 
     RemoteExecutionService remoteExecutionService = cache.getRemoteExecutionService();
     Mockito.doCallRealMethod().when(remoteExecutionService).waitForAndReuseOutputs(any(), any());
-    // Simulate a very slow upload to the remote cache to ensure that the second spawn is
-    // deduplicated rather than a cache hit. This is a slight hack, but also avoid introducing
-    // concurrency to this test.
-    Mockito.doNothing().when(remoteExecutionService).uploadOutputs(any(), any(), any());
 
     // act
     try (CacheHandle firstCacheHandle = cache.lookup(firstSpawn, firstPolicy)) {
@@ -1001,11 +1019,14 @@
               .setRunnerName("test")
               .build());
     }
+    Mockito.verify(remoteExecutionService, never()).uploadOutputs(any(), any(), any());
     CacheHandle secondCacheHandle = cache.lookup(secondSpawn, secondPolicy);
 
     // assert
     assertThat(secondCacheHandle.hasResult()).isFalse();
     assertThat(secondCacheHandle.willStore()).isTrue();
+    secondCacheHandle.close();
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
   }
 
   @Test
@@ -1030,7 +1051,14 @@
     // Simulate a very slow upload to the remote cache to ensure that the second spawn is
     // deduplicated rather than a cache hit. This is a slight hack, but also avoid introducing
     // concurrency to this test.
-    Mockito.doNothing().when(remoteExecutionService).uploadOutputs(any(), any(), any());
+    AtomicReference<Runnable> onUploadComplete = new AtomicReference<>();
+    Mockito.doAnswer(
+            invocationOnMock -> {
+              onUploadComplete.set(invocationOnMock.getArgument(2));
+              return null;
+            })
+        .when(remoteExecutionService)
+        .uploadOutputs(any(), any(), any());
 
     // act
     try (CacheHandle firstCacheHandle = cache.lookup(firstSpawn, firstPolicy)) {
@@ -1047,5 +1075,100 @@
     // assert
     assertThat(secondCacheHandle.hasResult()).isFalse();
     assertThat(secondCacheHandle.willStore()).isTrue();
+    onUploadComplete.get().run();
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
+  }
+
+  @Test
+  public void pathMappedActionWithCacheHitRemovesInFlightExecution() throws Exception {
+    // arrange
+    RemoteSpawnCache cache = createRemoteSpawnCache();
+
+    SimpleSpawn spawn = simplePathMappedSpawn("k8-fastbuild");
+    FakeActionInputFileCache fakeFileCache = new FakeActionInputFileCache(execRoot);
+    fakeFileCache.createScratchInput(spawn.getInputFiles().getSingleton(), "xyz");
+    SpawnExecutionContext policy =
+        createSpawnExecutionContext(spawn, execRoot, fakeFileCache, outErr);
+
+    RemoteExecutionService remoteExecutionService = cache.getRemoteExecutionService();
+    Mockito.doReturn(
+            RemoteActionResult.createFromCache(
+                CachedActionResult.remote(ActionResult.getDefaultInstance())))
+        .when(remoteExecutionService)
+        .lookupCache(any());
+    Mockito.doReturn(null).when(remoteExecutionService).downloadOutputs(any(), any());
+
+    // act
+    try (CacheHandle cacheHandle = cache.lookup(spawn, policy)) {
+      checkState(cacheHandle.hasResult());
+    }
+
+    // assert
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
+  }
+
+  @Test
+  public void pathMappedActionNotUploadedRemovesInFlightExecution() throws Exception {
+    // arrange
+    RemoteSpawnCache cache = createRemoteSpawnCache();
+
+    SimpleSpawn spawn = simplePathMappedSpawn("k8-fastbuild");
+    FakeActionInputFileCache fakeFileCache = new FakeActionInputFileCache(execRoot);
+    fakeFileCache.createScratchInput(spawn.getInputFiles().getSingleton(), "xyz");
+    SpawnExecutionContext policy =
+        createSpawnExecutionContext(spawn, execRoot, fakeFileCache, outErr);
+
+    RemoteExecutionService remoteExecutionService = cache.getRemoteExecutionService();
+    Mockito.doCallRealMethod()
+        .when(remoteExecutionService)
+        .commitResultAndDecideWhetherToUpload(any(), any());
+
+    // act
+    try (CacheHandle cacheHandle = cache.lookup(spawn, policy)) {
+      cacheHandle.store(
+          new SpawnResult.Builder()
+              .setExitCode(1)
+              .setStatus(Status.NON_ZERO_EXIT)
+              .setFailureDetail(
+                  FailureDetail.newBuilder()
+                      .setMessage("test spawn failed")
+                      .setSpawn(
+                          FailureDetails.Spawn.newBuilder()
+                              .setCode(FailureDetails.Spawn.Code.NON_ZERO_EXIT))
+                      .build())
+              .setRunnerName("test")
+              .build());
+    }
+
+    // assert
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
+  }
+
+  @Test
+  public void pathMappedActionWithCacheIoExceptionRemovesInFlightExecution() throws Exception {
+    // arrange
+    RemoteSpawnCache cache = createRemoteSpawnCache();
+
+    SimpleSpawn spawn = simplePathMappedSpawn("k8-fastbuild");
+    FakeActionInputFileCache fakeFileCache = new FakeActionInputFileCache(execRoot);
+    fakeFileCache.createScratchInput(spawn.getInputFiles().getSingleton(), "xyz");
+    SpawnExecutionContext policy =
+        createSpawnExecutionContext(spawn, execRoot, fakeFileCache, outErr);
+
+    RemoteExecutionService remoteExecutionService = cache.getRemoteExecutionService();
+    Mockito.doReturn(
+            RemoteActionResult.createFromCache(
+                CachedActionResult.remote(ActionResult.getDefaultInstance())))
+        .when(remoteExecutionService)
+        .lookupCache(any());
+    Mockito.doThrow(new IOException()).when(remoteExecutionService).downloadOutputs(any(), any());
+
+    // act
+    try (CacheHandle cacheHandle = cache.lookup(spawn, policy)) {
+      checkState(!cacheHandle.hasResult());
+    }
+
+    // assert
+    assertThat(cache.getInFlightExecutionsSize()).isEqualTo(0);
   }
 }