Remote: Fix crashes caused by InterruptedException when dynamic execution is enabled. (#15091)

Fixes #14433.

The root cause is that, inside `RemoteExecutionCache`, the result of the `FindMissingDigests` call is shared with other threads without proper error handling. For example, if two or more threads are uploading the same input and one of them is interrupted while waiting for the result of the `FindMissingDigests` call, the call is cancelled, and the other threads still waiting for the upload wrongly receive an upload error caused by that cancellation.

This PR fixes this by effectively applying reference counting to the result of the `FindMissingDigests` call: if one thread is interrupted, the call is not cancelled as long as other threads still depend on its result, so their uploads can continue.

Closes #15001.

PiperOrigin-RevId: 436180205
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
index 229163e..5474f88 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
@@ -13,30 +13,39 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote;
 
-import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
+import static com.google.devtools.build.lib.remote.util.RxFutures.toCompletable;
+import static com.google.devtools.build.lib.remote.util.RxFutures.toSingle;
+import static com.google.devtools.build.lib.remote.util.RxUtils.mergeBulkTransfer;
+import static com.google.devtools.build.lib.remote.util.RxUtils.toTransferResult;
 import static java.lang.String.format;
 
 import build.bazel.remote.execution.v2.Digest;
 import build.bazel.remote.execution.v2.Directory;
+import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.MoreExecutors;
 import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
 import com.google.devtools.build.lib.remote.common.RemoteCacheClient;
 import com.google.devtools.build.lib.remote.merkletree.MerkleTree;
 import com.google.devtools.build.lib.remote.merkletree.MerkleTree.PathOrBytes;
 import com.google.devtools.build.lib.remote.options.RemoteOptions;
 import com.google.devtools.build.lib.remote.util.DigestUtil;
-import com.google.devtools.build.lib.remote.util.RxFutures;
+import com.google.devtools.build.lib.remote.util.RxUtils.TransferResult;
 import com.google.protobuf.Message;
 import io.reactivex.rxjava3.core.Completable;
+import io.reactivex.rxjava3.core.Flowable;
+import io.reactivex.rxjava3.core.Single;
 import io.reactivex.rxjava3.subjects.AsyncSubject;
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.HashSet;
 import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import javax.annotation.concurrent.GuardedBy;
 
 /** A {@link RemoteCache} with additional functionality needed for remote execution. */
 public class RemoteExecutionCache extends RemoteCache {
@@ -72,62 +81,58 @@
             .addAll(additionalInputs.keySet())
             .build();
 
-    // Collect digests that are not being or already uploaded
-    ConcurrentHashMap<Digest, AsyncSubject<Boolean>> missingDigestSubjects =
-        new ConcurrentHashMap<>();
-
-    List<ListenableFuture<Void>> uploadFutures = new ArrayList<>();
-    for (Digest digest : allDigests) {
-      Completable upload =
-          casUploadCache.execute(
-              digest,
-              Completable.defer(
-                  () -> {
-                    // The digest hasn't been processed, add it to the collection which will be used
-                    // later for findMissingDigests call
-                    AsyncSubject<Boolean> missingDigestSubject = AsyncSubject.create();
-                    missingDigestSubjects.put(digest, missingDigestSubject);
-
-                    return missingDigestSubject.flatMapCompletable(
-                        missing -> {
-                          if (!missing) {
-                            return Completable.complete();
-                          }
-                          return RxFutures.toCompletable(
-                              () -> uploadBlob(context, digest, merkleTree, additionalInputs),
-                              MoreExecutors.directExecutor());
-                        });
-                  }),
-              force);
-      uploadFutures.add(RxFutures.toListenableFuture(upload));
+    if (allDigests.isEmpty()) {
+      return;
     }
 
-    ImmutableSet<Digest> missingDigests;
-    try {
-      missingDigests = getFromFuture(findMissingDigests(context, missingDigestSubjects.keySet()));
-    } catch (IOException | InterruptedException e) {
-      for (Map.Entry<Digest, AsyncSubject<Boolean>> entry : missingDigestSubjects.entrySet()) {
-        entry.getValue().onError(e);
-      }
+    MissingDigestFinder missingDigestFinder = new MissingDigestFinder(context, allDigests.size());
+    Flowable<TransferResult> uploads =
+        Flowable.fromIterable(allDigests)
+            .flatMapSingle(
+                digest ->
+                    uploadBlobIfMissing(
+                        context, merkleTree, additionalInputs, force, missingDigestFinder, digest));
 
-      if (e instanceof InterruptedException) {
-        Thread.currentThread().interrupt();
+    try {
+      mergeBulkTransfer(uploads).blockingAwait();
+    } catch (RuntimeException e) {
+      Throwable cause = e.getCause();
+      if (cause != null) {
+        Throwables.throwIfInstanceOf(cause, InterruptedException.class);
+        Throwables.throwIfInstanceOf(cause, IOException.class);
       }
       throw e;
     }
+  }
 
-    for (Map.Entry<Digest, AsyncSubject<Boolean>> entry : missingDigestSubjects.entrySet()) {
-      AsyncSubject<Boolean> missingSubject = entry.getValue();
-      if (missingDigests.contains(entry.getKey())) {
-        missingSubject.onNext(true);
-      } else {
-        // The digest is already existed in the remote cache, skip the upload.
-        missingSubject.onNext(false);
-      }
-      missingSubject.onComplete();
-    }
-
-    waitForBulkTransfer(uploadFutures, /* cancelRemainingOnInterrupt=*/ false);
+  private Single<TransferResult> uploadBlobIfMissing(
+      RemoteActionExecutionContext context,
+      MerkleTree merkleTree,
+      Map<Digest, Message> additionalInputs,
+      boolean force,
+      MissingDigestFinder missingDigestFinder,
+      Digest digest) {
+    Completable upload =
+        casUploadCache.execute(
+            digest,
+            Completable.defer(
+                () ->
+                    // Only reach here if the digest is missing and is not being uploaded.
+                    missingDigestFinder
+                        .registerAndCount(digest)
+                        .flatMapCompletable(
+                            missingDigests -> {
+                              if (missingDigests.contains(digest)) {
+                                return toCompletable(
+                                    () -> uploadBlob(context, digest, merkleTree, additionalInputs),
+                                    directExecutor());
+                              } else {
+                                return Completable.complete();
+                              }
+                            })),
+            /* onIgnored= */ missingDigestFinder::count,
+            force);
+    return toTransferResult(upload);
   }
 
   private ListenableFuture<Void> uploadBlob(
@@ -159,4 +164,93 @@
                 "findMissingDigests returned a missing digest that has not been requested: %s",
                 digest)));
   }
+
+  /**
+   * A missing digest finder that initiates the request when the internal counter reaches an
+   * expected count.
+   */
+  class MissingDigestFinder {
+    private final int expectedCount;
+
+    private final AsyncSubject<ImmutableSet<Digest>> digestsSubject;
+    private final Single<ImmutableSet<Digest>> resultSingle;
+
+    @GuardedBy("this")
+    private final Set<Digest> digests;
+
+    @GuardedBy("this")
+    private int currentCount = 0;
+
+    MissingDigestFinder(RemoteActionExecutionContext context, int expectedCount) {
+      checkArgument(expectedCount > 0, "expectedCount should be greater than 0");
+      this.expectedCount = expectedCount;
+      this.digestsSubject = AsyncSubject.create();
+      this.digests = new HashSet<>();
+
+      AtomicBoolean findMissingDigestsCalled = new AtomicBoolean(false);
+      this.resultSingle =
+          Single.fromObservable(
+              digestsSubject
+                  .flatMapSingle(
+                      digests -> {
+                        boolean wasCalled = findMissingDigestsCalled.getAndSet(true);
+                        // Make sure we don't have re-subscription caused by refCount() below.
+                        checkState(!wasCalled, "FindMissingDigests is called more than once");
+                        return toSingle(
+                            () -> findMissingDigests(context, digests), directExecutor());
+                      })
+                  // Use replay here because there could be a race condition in which downstream
+                  // hasn't yet been added to the subscription list (to receive the upstream
+                  // result) when upstream completes.
+                  .replay(1)
+                  .refCount());
+    }
+
+    /**
+     * Register the {@code digest} and increase the counter.
+     *
+     * <p>The returned Single must not be subscribed to more than once.
+     *
+     * @return Single that emits the result of the {@code FindMissingDigests} request.
+     */
+    Single<ImmutableSet<Digest>> registerAndCount(Digest digest) {
+      AtomicBoolean subscribed = new AtomicBoolean(false);
+      // count() will potentially trigger the findMissingDigests call. Adding and counting before
+      // returning the Single could introduce a race in which the result of findMissingDigests is
+      // available but the consumer doesn't get it because it hasn't subscribed to the returned
+      // Single. In that case, it subscribes after upstream has completed, resulting in a re-run of
+      // findMissingDigests (due to refCount()).
+      //
+      // Call count() inside doOnSubscribe to ensure the consumer has already subscribed to the
+      // returned Single, avoiding a re-execution of findMissingDigests.
+      return resultSingle.doOnSubscribe(
+          d -> {
+            boolean wasSubscribed = subscribed.getAndSet(true);
+            checkState(!wasSubscribed, "Single is subscribed more than once");
+            synchronized (this) {
+              digests.add(digest);
+            }
+            count();
+          });
+    }
+
+    /** Increase the counter. */
+    void count() {
+      ImmutableSet<Digest> digestsResult = null;
+
+      synchronized (this) {
+        if (currentCount < expectedCount) {
+          currentCount++;
+          if (currentCount == expectedCount) {
+            digestsResult = ImmutableSet.copyOf(digests);
+          }
+        }
+      }
+
+      if (digestsResult != null) {
+        digestsSubject.onNext(digestsResult);
+        digestsSubject.onComplete();
+      }
+    }
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
index 8fb6f4c..31369ef 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
@@ -24,6 +24,7 @@
 import io.reactivex.rxjava3.core.Single;
 import io.reactivex.rxjava3.core.SingleObserver;
 import io.reactivex.rxjava3.disposables.Disposable;
+import io.reactivex.rxjava3.functions.Action;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -256,14 +257,25 @@
   /**
    * Executes a task.
    *
+   * @see #execute(Object, Single, Action, boolean)
+   */
+  public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
+    return execute(key, task, () -> {}, force);
+  }
+
+  /**
+   * Executes a task. If the task has already finished, this execution of the task is ignored unless
+   * `force` is true. If the task is in progress this execution of the task is always ignored.
+   *
    * <p>If the cache is already shutdown, a {@link CancellationException} will be emitted.
    *
    * @param key identifies the task.
+   * @param onIgnored callback called when provided task is ignored.
    * @param force re-execute a finished task if set to {@code true}.
    * @return a {@link Single} which turns to completed once the task is finished or propagates the
    *     error if any.
    */
-  public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
+  public Single<ValueT> execute(KeyT key, Single<ValueT> task, Action onIgnored, boolean force) {
     return Single.create(
         emitter -> {
           synchronized (lock) {
@@ -273,14 +285,20 @@
             }
 
             if (!force && finished.containsKey(key)) {
+              onIgnored.run();
               emitter.onSuccess(finished.get(key));
               return;
             }
 
             finished.remove(key);
 
-            Execution execution =
-                inProgress.computeIfAbsent(key, ignoredKey -> new Execution(key, task));
+            Execution execution = inProgress.get(key);
+            if (execution != null) {
+              onIgnored.run();
+            } else {
+              execution = new Execution(key, task);
+              inProgress.put(key, execution);
+            }
 
             // We must subscribe the execution within the scope of lock to avoid race condition
             // that:
@@ -425,10 +443,15 @@
           cache.executeIfNot(key, task.toSingleDefault(Optional.empty())));
     }
 
-    /** Same as {@link AsyncTaskCache#executeIfNot} but operates on {@link Completable}. */
+    /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */
     public Completable execute(KeyT key, Completable task, boolean force) {
+      return execute(key, task, () -> {}, force);
+    }
+
+    /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */
+    public Completable execute(KeyT key, Completable task, Action onIgnored, boolean force) {
       return Completable.fromSingle(
-          cache.execute(key, task.toSingleDefault(Optional.empty()), force));
+          cache.execute(key, task.toSingleDefault(Optional.empty()), onIgnored, force));
     }
 
     /** Returns a set of keys for tasks which is finished. */
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java
index b885e3b..18b5a68 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java
@@ -51,6 +51,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.eventbus.EventBus;
 import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.SettableFuture;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.ActionInputHelper;
 import com.google.devtools.build.lib.actions.ActionUploadFinishedEvent;
@@ -109,6 +110,7 @@
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Random;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Semaphore;
@@ -1433,19 +1435,18 @@
     ActionInput input = ActionInputHelper.fromPath("inputs/foo");
     Digest inputDigest = fakeFileCache.createScratchInput(input, "input-foo");
     RemoteExecutionService service = newRemoteExecutionService();
+    Spawn spawn =
+        newSpawn(
+            ImmutableMap.of(),
+            ImmutableSet.of(),
+            NestedSetBuilder.create(Order.STABLE_ORDER, input));
+    FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
+    RemoteAction action = service.buildRemoteAction(spawn, context);
 
     for (int i = 0; i < taskCount; ++i) {
       executorService.execute(
           () -> {
             try {
-              Spawn spawn =
-                  newSpawn(
-                      ImmutableMap.of(),
-                      ImmutableSet.of(),
-                      NestedSetBuilder.create(Order.STABLE_ORDER, input));
-              FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
-              RemoteAction action = service.buildRemoteAction(spawn, context);
-
               service.uploadInputsIfNotPresent(action, /*force=*/ false);
             } catch (Throwable e) {
               if (e instanceof InterruptedException) {
@@ -1467,6 +1468,72 @@
   }
 
   @Test
+  public void uploadInputsIfNotPresent_sameInputs_interruptOne_keepOthers() throws Exception {
+    int taskCount = 100;
+    ExecutorService executorService = Executors.newFixedThreadPool(taskCount);
+    AtomicReference<Throwable> error = new AtomicReference<>(null);
+    Semaphore semaphore = new Semaphore(0);
+    ActionInput input = ActionInputHelper.fromPath("inputs/foo");
+    fakeFileCache.createScratchInput(input, "input-foo");
+    RemoteExecutionService service = newRemoteExecutionService();
+    Spawn spawn =
+        newSpawn(
+            ImmutableMap.of(),
+            ImmutableSet.of(),
+            NestedSetBuilder.create(Order.STABLE_ORDER, input));
+    FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
+    RemoteAction action = service.buildRemoteAction(spawn, context);
+    Random random = new Random();
+
+    for (int i = 0; i < taskCount; ++i) {
+      boolean shouldInterrupt = random.nextBoolean();
+      executorService.execute(
+          () -> {
+            try {
+              if (shouldInterrupt) {
+                Thread.currentThread().interrupt();
+              }
+              service.uploadInputsIfNotPresent(action, /*force=*/ false);
+            } catch (Throwable e) {
+              if (!(shouldInterrupt && e instanceof InterruptedException)) {
+                error.set(e);
+              }
+            } finally {
+              semaphore.release();
+            }
+          });
+    }
+    semaphore.acquire(taskCount);
+
+    assertThat(error.get()).isNull();
+  }
+
+  @Test
+  public void uploadInputsIfNotPresent_interrupted_requestCancelled() throws Exception {
+    SettableFuture<ImmutableSet<Digest>> future = SettableFuture.create();
+    doReturn(future).when(cache).findMissingDigests(any(), any());
+    ActionInput input = ActionInputHelper.fromPath("inputs/foo");
+    fakeFileCache.createScratchInput(input, "input-foo");
+    RemoteExecutionService service = newRemoteExecutionService();
+    Spawn spawn =
+        newSpawn(
+            ImmutableMap.of(),
+            ImmutableSet.of(),
+            NestedSetBuilder.create(Order.STABLE_ORDER, input));
+    FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
+    RemoteAction action = service.buildRemoteAction(spawn, context);
+
+    try {
+      Thread.currentThread().interrupt();
+      service.uploadInputsIfNotPresent(action, /*force=*/ false);
+    } catch (InterruptedException ignored) {
+      // Intentionally left empty
+    }
+
+    assertThat(future.isCancelled()).isTrue();
+  }
+
+  @Test
   public void buildMerkleTree_withMemoization_works() throws Exception {
     // Test that Merkle tree building can be memoized.
 
diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java
index c26629f..8925640 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java
@@ -19,6 +19,8 @@
 import com.google.common.io.ByteStreams;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
 import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
 import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
 import com.google.devtools.build.lib.remote.common.RemoteCacheClient;
@@ -31,12 +33,15 @@
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
 /** A {@link RemoteCacheClient} that stores its contents in memory. */
 public final class InMemoryCacheClient implements RemoteCacheClient {
 
+  private final ListeningExecutorService executorService =
+      MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(100));
   private final ConcurrentMap<Digest, Exception> downloadFailures = new ConcurrentHashMap<>();
   private final ConcurrentMap<ActionKey, ActionResult> ac = new ConcurrentHashMap<>();
   private final ConcurrentMap<Digest, byte[]> cas;
@@ -142,16 +147,19 @@
   @Override
   public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(
       RemoteActionExecutionContext context, Iterable<Digest> digests) {
-    ImmutableSet.Builder<Digest> missingBuilder = ImmutableSet.builder();
-    for (Digest digest : digests) {
-      numFindMissingDigests
-          .computeIfAbsent(digest, (key) -> new AtomicInteger(0))
-          .incrementAndGet();
-      if (!cas.containsKey(digest)) {
-        missingBuilder.add(digest);
-      }
-    }
-    return Futures.immediateFuture(missingBuilder.build());
+    return executorService.submit(
+        () -> {
+          ImmutableSet.Builder<Digest> missingBuilder = ImmutableSet.builder();
+          for (Digest digest : digests) {
+            numFindMissingDigests
+                .computeIfAbsent(digest, (key) -> new AtomicInteger(0))
+                .incrementAndGet();
+            if (!cas.containsKey(digest)) {
+              missingBuilder.add(digest);
+            }
+          }
+          return missingBuilder.build();
+        });
   }
 
   @Override