getFromFuture will cancel the future by default on InterruptedException
Fixes #11339.
Closes #12453.
PiperOrigin-RevId: 341993215
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
index 7f3f83d..93ec205 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
@@ -34,6 +34,7 @@
import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
+import com.google.devtools.build.lib.remote.util.Utils;
import com.google.devtools.build.lib.sandbox.SandboxHelpers;
import com.google.devtools.build.lib.vfs.Path;
import io.grpc.Context;
@@ -42,7 +43,6 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.ExecutionException;
import javax.annotation.concurrent.GuardedBy;
/**
@@ -145,14 +145,7 @@
void downloadFile(Path path, FileArtifactValue metadata)
throws IOException, InterruptedException {
- try {
- downloadFileAsync(path, metadata).get();
- } catch (ExecutionException e) {
- if (e.getCause() instanceof IOException) {
- throw (IOException) e.getCause();
- }
- throw new IOException(e.getCause());
- }
+ Utils.getFromFuture(downloadFileAsync(path, metadata));
}
private ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue metadata)
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
index 511c5bd..5fc3b3a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
@@ -218,7 +218,7 @@
try {
if (interruptedException == null) {
// Wait for all transfers to finish.
- getFromFuture(transfer);
+ getFromFuture(transfer, cancelRemainingOnInterrupt);
} else {
transfer.cancel(true);
}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
index 9d07d7a..86f21a5 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
@@ -57,9 +57,22 @@
/**
* Returns the result of a {@link ListenableFuture} if successful, or throws any checked {@link
* Exception} directly if it's an {@link IOException} or else wraps it in an {@link IOException}.
+ *
+ * <p>Cancel the future on {@link InterruptedException}
*/
public static <T> T getFromFuture(ListenableFuture<T> f)
throws IOException, InterruptedException {
+ return getFromFuture(f, /* cancelOnInterrupt */ true);
+ }
+
+ /**
+ * Returns the result of a {@link ListenableFuture} if successful, or throws any checked {@link
+ * Exception} directly if it's an {@link IOException} or else wraps it in an {@link IOException}.
+ *
+ * @param cancelOnInterrupt cancel the future on {@link InterruptedException} if {@code true}.
+ */
+ public static <T> T getFromFuture(ListenableFuture<T> f, boolean cancelOnInterrupt)
+ throws IOException, InterruptedException {
try {
return f.get();
} catch (ExecutionException e) {
@@ -74,6 +87,11 @@
throw (RuntimeException) cause;
}
throw new IOException(cause);
+ } catch (InterruptedException e) {
+ if (cancelOnInterrupt) {
+ f.cancel(true);
+ }
+ throw e;
}
}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
index 5f991de..4dded52 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
@@ -15,6 +15,9 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
import build.bazel.remote.execution.v2.Digest;
import build.bazel.remote.execution.v2.RequestMetadata;
@@ -22,6 +25,8 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.common.hash.HashCode;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.SettableFuture;
import com.google.devtools.build.lib.actions.ActionInput;
import com.google.devtools.build.lib.actions.Artifact;
import com.google.devtools.build.lib.actions.ArtifactRoot;
@@ -46,6 +51,8 @@
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -186,6 +193,54 @@
assertThat(a1.getPath().isWritable()).isTrue();
}
+ @Test
+ public void testDownloadFile_onInterrupt_deletePartialDownloadedFile() throws Exception {
+ Semaphore startSemaphore = new Semaphore(0);
+ Semaphore endSemaphore = new Semaphore(0);
+ Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+ Map<Digest, ByteString> cacheEntries = new HashMap<>();
+ Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cacheEntries);
+ RemoteCache remoteCache = mock(RemoteCache.class);
+ when(remoteCache.downloadFile(any(), any()))
+ .thenAnswer(
+ invocation -> {
+ Path path = invocation.getArgument(0);
+ Digest digest = invocation.getArgument(1);
+ ByteString content = cacheEntries.get(digest);
+ if (content == null) {
+ return Futures.immediateFailedFuture(new IOException("Not found"));
+ }
+ content.writeTo(path.getOutputStream());
+
+ startSemaphore.release();
+ return SettableFuture
+ .create(); // A future that never complete so we can interrupt later
+ });
+ RemoteActionInputFetcher actionInputFetcher =
+ new RemoteActionInputFetcher(remoteCache, execRoot, RequestMetadata.getDefaultInstance());
+
+ AtomicBoolean interrupted = new AtomicBoolean(false);
+ Thread t =
+ new Thread(
+ () -> {
+ try {
+ actionInputFetcher.downloadFile(a1.getPath(), metadata.get(a1));
+ } catch (IOException ignored) {
+ interrupted.set(false);
+ } catch (InterruptedException e) {
+ interrupted.set(true);
+ }
+ endSemaphore.release();
+ });
+ t.start();
+ startSemaphore.acquire();
+ t.interrupt();
+ endSemaphore.acquire();
+
+ assertThat(interrupted.get()).isTrue();
+ assertThat(a1.getPath().exists()).isFalse();
+ }
+
private Artifact createRemoteArtifact(
String pathFragment,
String contents,