getFromFuture will cancel the future by default on InterruptedException

Fixes #11339.

Closes #12453.

PiperOrigin-RevId: 341993215
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
index 7f3f83d..93ec205 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
@@ -34,6 +34,7 @@
 import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
 import com.google.devtools.build.lib.remote.util.DigestUtil;
 import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
+import com.google.devtools.build.lib.remote.util.Utils;
 import com.google.devtools.build.lib.sandbox.SandboxHelpers;
 import com.google.devtools.build.lib.vfs.Path;
 import io.grpc.Context;
@@ -42,7 +43,6 @@
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.ExecutionException;
 import javax.annotation.concurrent.GuardedBy;
 
 /**
@@ -145,14 +145,7 @@
 
   void downloadFile(Path path, FileArtifactValue metadata)
       throws IOException, InterruptedException {
-    try {
-      downloadFileAsync(path, metadata).get();
-    } catch (ExecutionException e) {
-      if (e.getCause() instanceof IOException) {
-        throw (IOException) e.getCause();
-      }
-      throw new IOException(e.getCause());
-    }
+    Utils.getFromFuture(downloadFileAsync(path, metadata));
   }
 
   private ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue metadata)
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
index 511c5bd..5fc3b3a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
@@ -218,7 +218,7 @@
       try {
         if (interruptedException == null) {
           // Wait for all transfers to finish.
-          getFromFuture(transfer);
+          getFromFuture(transfer, cancelRemainingOnInterrupt);
         } else {
           transfer.cancel(true);
         }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
index 9d07d7a..86f21a5 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
@@ -57,9 +57,22 @@
   /**
    * Returns the result of a {@link ListenableFuture} if successful, or throws any checked {@link
    * Exception} directly if it's an {@link IOException} or else wraps it in an {@link IOException}.
+   *
+   * <p>Cancel the future on {@link InterruptedException}
    */
   public static <T> T getFromFuture(ListenableFuture<T> f)
       throws IOException, InterruptedException {
+    return getFromFuture(f, /* cancelOnInterrupt */ true);
+  }
+
+  /**
+   * Returns the result of a {@link ListenableFuture} if successful, or throws any checked {@link
+   * Exception} directly if it's an {@link IOException} or else wraps it in an {@link IOException}.
+   *
+   * @param cancelOnInterrupt cancel the future on {@link InterruptedException} if {@code true}.
+   */
+  public static <T> T getFromFuture(ListenableFuture<T> f, boolean cancelOnInterrupt)
+      throws IOException, InterruptedException {
     try {
       return f.get();
     } catch (ExecutionException e) {
@@ -74,6 +87,11 @@
         throw (RuntimeException) cause;
       }
       throw new IOException(cause);
+    } catch (InterruptedException e) {
+      if (cancelOnInterrupt) {
+        f.cancel(true);
+      }
+      throw e;
     }
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
index 5f991de..4dded52 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
@@ -15,6 +15,9 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 import build.bazel.remote.execution.v2.Digest;
 import build.bazel.remote.execution.v2.RequestMetadata;
@@ -22,6 +25,8 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
 import com.google.common.hash.HashCode;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.SettableFuture;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.actions.ArtifactRoot;
@@ -46,6 +51,8 @@
 import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.atomic.AtomicBoolean;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -186,6 +193,54 @@
     assertThat(a1.getPath().isWritable()).isTrue();
   }
 
+  @Test
+  public void testDownloadFile_onInterrupt_deletePartialDownloadedFile() throws Exception {
+    Semaphore startSemaphore = new Semaphore(0);
+    Semaphore endSemaphore = new Semaphore(0);
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<Digest, ByteString> cacheEntries = new HashMap<>();
+    Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cacheEntries);
+    RemoteCache remoteCache = mock(RemoteCache.class);
+    when(remoteCache.downloadFile(any(), any()))
+        .thenAnswer(
+            invocation -> {
+              Path path = invocation.getArgument(0);
+              Digest digest = invocation.getArgument(1);
+              ByteString content = cacheEntries.get(digest);
+              if (content == null) {
+                return Futures.immediateFailedFuture(new IOException("Not found"));
+              }
+              content.writeTo(path.getOutputStream());
+
+              startSemaphore.release();
+              return SettableFuture
+                  .create(); // A future that never complete so we can interrupt later
+            });
+    RemoteActionInputFetcher actionInputFetcher =
+        new RemoteActionInputFetcher(remoteCache, execRoot, RequestMetadata.getDefaultInstance());
+
+    AtomicBoolean interrupted = new AtomicBoolean(false);
+    Thread t =
+        new Thread(
+            () -> {
+              try {
+                actionInputFetcher.downloadFile(a1.getPath(), metadata.get(a1));
+              } catch (IOException ignored) {
+                interrupted.set(false);
+              } catch (InterruptedException e) {
+                interrupted.set(true);
+              }
+              endSemaphore.release();
+            });
+    t.start();
+    startSemaphore.acquire();
+    t.interrupt();
+    endSemaphore.acquire();
+
+    assertThat(interrupted.get()).isTrue();
+    assertThat(a1.getPath().exists()).isFalse();
+  }
+
   private Artifact createRemoteArtifact(
       String pathFragment,
       String contents,