getFromFuture will cancel the future by default on InterruptedException Fixes #11339. Closes #12453. PiperOrigin-RevId: 341993215
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java index 7f3f83d..93ec205 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
@@ -34,6 +34,7 @@ import com.google.devtools.build.lib.remote.common.CacheNotFoundException; import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; +import com.google.devtools.build.lib.remote.util.Utils; import com.google.devtools.build.lib.sandbox.SandboxHelpers; import com.google.devtools.build.lib.vfs.Path; import io.grpc.Context; @@ -42,7 +43,6 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; -import java.util.concurrent.ExecutionException; import javax.annotation.concurrent.GuardedBy; /** @@ -145,14 +145,7 @@ void downloadFile(Path path, FileArtifactValue metadata) throws IOException, InterruptedException { - try { - downloadFileAsync(path, metadata).get(); - } catch (ExecutionException e) { - if (e.getCause() instanceof IOException) { - throw (IOException) e.getCause(); - } - throw new IOException(e.getCause()); - } + Utils.getFromFuture(downloadFileAsync(path, metadata)); } private ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue metadata)
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java index 511c5bd..5fc3b3a 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
@@ -218,7 +218,7 @@ try { if (interruptedException == null) { // Wait for all transfers to finish. - getFromFuture(transfer); + getFromFuture(transfer, cancelRemainingOnInterrupt); } else { transfer.cancel(true); }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java index 9d07d7a..86f21a5 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
@@ -57,9 +57,22 @@ /** * Returns the result of a {@link ListenableFuture} if successful, or throws any checked {@link * Exception} directly if it's an {@link IOException} or else wraps it in an {@link IOException}. + * + * <p>Cancel the future on {@link InterruptedException} */ public static <T> T getFromFuture(ListenableFuture<T> f) throws IOException, InterruptedException { + return getFromFuture(f, /* cancelOnInterrupt */ true); + } + + /** + * Returns the result of a {@link ListenableFuture} if successful, or throws any checked {@link + * Exception} directly if it's an {@link IOException} or else wraps it in an {@link IOException}. + * + * @param cancelOnInterrupt cancel the future on {@link InterruptedException} if {@code true}. + */ + public static <T> T getFromFuture(ListenableFuture<T> f, boolean cancelOnInterrupt) + throws IOException, InterruptedException { try { return f.get(); } catch (ExecutionException e) { @@ -74,6 +87,11 @@ throw (RuntimeException) cause; } throw new IOException(cause); + } catch (InterruptedException e) { + if (cancelOnInterrupt) { + f.cancel(true); + } + throw e; } }
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java index 5f991de..4dded52 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
@@ -15,6 +15,9 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import build.bazel.remote.execution.v2.Digest; import build.bazel.remote.execution.v2.RequestMetadata; @@ -22,6 +25,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.common.hash.HashCode; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.SettableFuture; import com.google.devtools.build.lib.actions.ActionInput; import com.google.devtools.build.lib.actions.Artifact; import com.google.devtools.build.lib.actions.ArtifactRoot; @@ -46,6 +51,8 @@ import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -186,6 +193,54 @@ assertThat(a1.getPath().isWritable()).isTrue(); } + @Test + public void testDownloadFile_onInterrupt_deletePartialDownloadedFile() throws Exception { + Semaphore startSemaphore = new Semaphore(0); + Semaphore endSemaphore = new Semaphore(0); + Map<ActionInput, FileArtifactValue> metadata = new HashMap<>(); + Map<Digest, ByteString> cacheEntries = new HashMap<>(); + Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cacheEntries); + RemoteCache remoteCache = mock(RemoteCache.class); + when(remoteCache.downloadFile(any(), any())) + .thenAnswer( + invocation -> { + Path path = invocation.getArgument(0); + Digest digest = invocation.getArgument(1); + ByteString content = cacheEntries.get(digest); + if (content == null) { + return Futures.immediateFailedFuture(new IOException("Not found")); + } + content.writeTo(path.getOutputStream()); + + startSemaphore.release(); + return SettableFuture + .create(); // A future that never complete so we can interrupt later + }); + RemoteActionInputFetcher actionInputFetcher = + new RemoteActionInputFetcher(remoteCache, execRoot, RequestMetadata.getDefaultInstance()); + + AtomicBoolean interrupted = new AtomicBoolean(false); + Thread t = + new Thread( + () -> { + try { + actionInputFetcher.downloadFile(a1.getPath(), metadata.get(a1)); + } catch (IOException ignored) { + interrupted.set(false); + } catch (InterruptedException e) { + interrupted.set(true); + } + endSemaphore.release(); + }); + t.start(); + startSemaphore.acquire(); + t.interrupt(); + endSemaphore.acquire(); + + assertThat(interrupted.get()).isTrue(); + assertThat(a1.getPath().exists()).isFalse(); + } + private Artifact createRemoteArtifact( String pathFragment, String contents,