Remote: Fix the issue that partial downloaded inputs are not deleted if the request is cancelled.
Following https://github.com/bazelbuild/bazel/commit/280ef6915b0f507218a073974825d6aa7effddee, this change fixes the issue that partial downloaded files are not deleted. The root cause is, even we cancel the download futures inside AbstractActionInputPrefetcher and then delete the files, the actual downloads inside GrpcCacheClient are not cancelled so it is still writing to the path.
PiperOrigin-RevId: 447970542
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index 5dd7dc0..6f7c442 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -14,7 +14,9 @@
package com.google.devtools.build.lib.remote;
+import static com.google.bytestream.ByteStreamGrpc.getReadMethod;
import static com.google.common.base.Strings.isNullOrEmpty;
+import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
import build.bazel.remote.execution.v2.ActionCacheGrpc;
import build.bazel.remote.execution.v2.ActionCacheGrpc.ActionCacheFutureStub;
@@ -41,7 +43,6 @@
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
-import com.google.common.util.concurrent.SettableFuture;
import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.remote.RemoteRetrier.ProgressiveBackoff;
@@ -50,6 +51,7 @@
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
import com.google.devtools.build.lib.remote.common.RemoteCacheClient;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
+import com.google.devtools.build.lib.remote.util.CompletableFuture;
import com.google.devtools.build.lib.remote.util.DigestOutputStream;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
@@ -58,6 +60,7 @@
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
+import io.grpc.ClientCall;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
@@ -363,81 +366,84 @@
Channel channel) {
String resourceName =
getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
- SettableFuture<Long> future = SettableFuture.create();
+ CompletableFuture<Long> future = CompletableFuture.create();
OutputStream out;
try {
out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut;
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
- bsAsyncStub(context, channel)
- .read(
- ReadRequest.newBuilder()
- .setResourceName(resourceName)
- .setReadOffset(rawOut.getCount())
- .build(),
- new StreamObserver<ReadResponse>() {
+ ByteStreamStub stub = bsAsyncStub(context, channel);
+ ClientCall<ReadRequest, ReadResponse> clientCall =
+ stub.getChannel().newCall(getReadMethod(), stub.getCallOptions());
+ future.setCancelCallback(() -> clientCall.cancel("Cancelled", /* cause= */ null));
+ asyncServerStreamingCall(
+ clientCall,
+ ReadRequest.newBuilder()
+ .setResourceName(resourceName)
+ .setReadOffset(rawOut.getCount())
+ .build(),
+ new StreamObserver<ReadResponse>() {
- @Override
- public void onNext(ReadResponse readResponse) {
- ByteString data = readResponse.getData();
- try {
- data.writeTo(out);
- } catch (IOException e) {
- // Cancel the call.
- throw new RuntimeException(e);
- }
- // reset the stall backoff because we've made progress or been kept alive
- progressiveBackoff.reset();
- }
+ @Override
+ public void onNext(ReadResponse readResponse) {
+ ByteString data = readResponse.getData();
+ try {
+ data.writeTo(out);
+ } catch (IOException e) {
+ // Cancel the call.
+ throw new RuntimeException(e);
+ }
+ // reset the stall backoff because we've made progress or been kept alive
+ progressiveBackoff.reset();
+ }
- @Override
- public void onError(Throwable t) {
- if (rawOut.getCount() == digest.getSizeBytes()) {
- // If the file was fully downloaded, it doesn't matter if there was an error at
- // the end of the stream.
- logger.atInfo().withCause(t).log(
- "ignoring error because file was fully received");
- onCompleted();
- return;
- }
- releaseOut();
- Status status = Status.fromThrowable(t);
- if (status.getCode() == Status.Code.NOT_FOUND) {
- future.setException(new CacheNotFoundException(digest));
- } else {
- future.setException(t);
- }
- }
+ @Override
+ public void onError(Throwable t) {
+ if (rawOut.getCount() == digest.getSizeBytes()) {
+ // If the file was fully downloaded, it doesn't matter if there was an error at
+ // the end of the stream.
+ logger.atInfo().withCause(t).log("ignoring error because file was fully received");
+ onCompleted();
+ return;
+ }
+ releaseOut();
+ Status status = Status.fromThrowable(t);
+ if (status.getCode() == Status.Code.NOT_FOUND) {
+ future.setException(new CacheNotFoundException(digest));
+ } else {
+ future.setException(t);
+ }
+ }
- @Override
- public void onCompleted() {
- try {
- if (digestSupplier != null) {
- Utils.verifyBlobContents(digest, digestSupplier.get());
- }
- out.flush();
- future.set(rawOut.getCount());
- } catch (IOException e) {
- future.setException(e);
- } catch (RuntimeException e) {
- logger.atWarning().withCause(e).log("Unexpected exception");
- future.setException(e);
- } finally {
- releaseOut();
- }
+ @Override
+ public void onCompleted() {
+ try {
+ if (digestSupplier != null) {
+ Utils.verifyBlobContents(digest, digestSupplier.get());
}
+ out.flush();
+ future.set(rawOut.getCount());
+ } catch (IOException e) {
+ future.setException(e);
+ } catch (RuntimeException e) {
+ logger.atWarning().withCause(e).log("Unexpected exception");
+ future.setException(e);
+ } finally {
+ releaseOut();
+ }
+ }
- private void releaseOut() {
- if (out instanceof ZstdDecompressingOutputStream) {
- try {
- ((ZstdDecompressingOutputStream) out).closeShallow();
- } catch (IOException e) {
- logger.atWarning().withCause(e).log("failed to cleanly close output stream");
- }
- }
+ private void releaseOut() {
+ if (out instanceof ZstdDecompressingOutputStream) {
+ try {
+ ((ZstdDecompressingOutputStream) out).closeShallow();
+ } catch (IOException e) {
+ logger.atWarning().withCause(e).log("failed to cleanly close output stream");
}
- });
+ }
+ }
+ });
return future;
}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
index ef97a4bc..504c186 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteCache.java
@@ -41,6 +41,7 @@
import com.google.devtools.build.lib.remote.common.RemoteCacheClient.CachedActionResult;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.build.lib.remote.util.AsyncTaskCache;
+import com.google.devtools.build.lib.remote.util.CompletableFuture;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.RxFutures;
import com.google.devtools.build.lib.server.FailureDetails.FailureDetail;
@@ -87,9 +88,7 @@
protected final DigestUtil digestUtil;
public RemoteCache(
- RemoteCacheClient cacheProtocol,
- RemoteOptions options,
- DigestUtil digestUtil) {
+ RemoteCacheClient cacheProtocol, RemoteOptions options, DigestUtil digestUtil) {
this.cacheProtocol = cacheProtocol;
this.options = options;
this.digestUtil = digestUtil;
@@ -332,8 +331,9 @@
reporter.started();
OutputStream out = new ReportingOutputStream(new LazyFileOutputStream(path), reporter);
- SettableFuture<Void> outerF = SettableFuture.create();
+ CompletableFuture<Void> outerF = CompletableFuture.create();
ListenableFuture<Void> f = cacheProtocol.downloadBlob(context, digest, out);
+ outerF.setCancelCallback(() -> f.cancel(/* mayInterruptIfRunning= */ true));
Futures.addCallback(
f,
new FutureCallback<Void>() {
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/CompletableFuture.java b/src/main/java/com/google/devtools/build/lib/remote/util/CompletableFuture.java
new file mode 100644
index 0000000..6a6566d
--- /dev/null
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/CompletableFuture.java
@@ -0,0 +1,73 @@
+// Copyright 2022 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote.util;
+
+import com.google.common.util.concurrent.AbstractFuture;
+import io.reactivex.rxjava3.disposables.Disposable;
+import io.reactivex.rxjava3.functions.Action;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.annotation.Nullable;
+
+/**
+ * A {@link com.google.common.util.concurrent.ListenableFuture} whose result can be set by a {@link
+ * #set(Object)} or {@link #setException(Throwable)}.
+ *
+ * <p>It differs from {@link com.google.common.util.concurrent.SettableFuture} that it provides
+ * {@link #setCancelCallback(Disposable)} for callers to register a callback which is called when
+ * the future is cancelled.
+ */
+public final class CompletableFuture<T> extends AbstractFuture<T> {
+
+ public static <T> CompletableFuture<T> create() {
+ return new CompletableFuture<>();
+ }
+
+ private final AtomicReference<Disposable> cancelCallback = new AtomicReference<>();
+
+ public void setCancelCallback(Action action) {
+ setCancelCallback(Disposable.fromAction(action));
+ }
+
+ public void setCancelCallback(Disposable cancelCallback) {
+ this.cancelCallback.set(cancelCallback);
+ // Just in case it was already canceled before we set the callback.
+ doCancelIfCancelled();
+ }
+
+ private void doCancelIfCancelled() {
+ if (isCancelled()) {
+ Disposable callback = cancelCallback.getAndSet(null);
+ if (callback != null) {
+ callback.dispose();
+ }
+ }
+ }
+
+ @Override
+ protected void afterDone() {
+ doCancelIfCancelled();
+ }
+
+ // Allow set to be called by other members.
+ @Override
+ public boolean set(@Nullable T t) {
+ return super.set(t);
+ }
+
+ // Allow setException to be called by other members.
+ @Override
+ public boolean setException(Throwable throwable) {
+ return super.setException(throwable);
+ }
+}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java b/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java
index 01af498..efd3c77 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/RxFutures.java
@@ -15,7 +15,6 @@
import static com.google.common.base.Preconditions.checkState;
-import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
@@ -34,7 +33,6 @@
import java.util.concurrent.CancellationException;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
/** Methods for interoperating between Rx and ListenableFuture. */
@@ -249,39 +247,4 @@
return future;
}
- private static final class CompletableFuture<T> extends AbstractFuture<T> {
- private final AtomicReference<Disposable> cancelCallback = new AtomicReference<>();
-
- private void setCancelCallback(Disposable cancelCallback) {
- this.cancelCallback.set(cancelCallback);
- // Just in case it was already canceled before we set the callback.
- doCancelIfCancelled();
- }
-
- private void doCancelIfCancelled() {
- if (isCancelled()) {
- Disposable callback = cancelCallback.getAndSet(null);
- if (callback != null) {
- callback.dispose();
- }
- }
- }
-
- @Override
- protected void afterDone() {
- doCancelIfCancelled();
- }
-
- // Allow set to be called by other members.
- @Override
- protected boolean set(@Nullable T t) {
- return super.set(t);
- }
-
- // Allow setException to be called by other members.
- @Override
- protected boolean setException(Throwable throwable) {
- return super.setException(throwable);
- }
- }
}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
index d61716c..735d24e 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
@@ -48,6 +48,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.ListenableFuture;
import com.google.devtools.build.lib.actions.ActionInputHelper;
import com.google.devtools.build.lib.actions.cache.VirtualActionInput;
import com.google.devtools.build.lib.actions.util.ActionsTestUtil;
@@ -67,7 +68,9 @@
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
+import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
@@ -152,6 +155,34 @@
}
@Test
+ public void downloadBlob_cancelled_cancelRequest() throws IOException {
+ // Test that if the download future is cancelled, the download itself is also cancelled.
+
+ // arrange
+ Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
+ AtomicBoolean cancelled = new AtomicBoolean();
+ // Mock a byte stream whose read method never finish.
+ serviceRegistry.addService(
+ new ByteStreamImplBase() {
+ @Override
+ public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
+ ((ServerCallStreamObserver<ReadResponse>) responseObserver)
+ .setOnCancelHandler(() -> cancelled.set(true));
+ }
+ });
+ GrpcCacheClient cacheClient = newClient();
+
+ // act
+ try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
+ ListenableFuture<Void> download = cacheClient.downloadBlob(context, digest, out);
+ download.cancel(/* mayInterruptIfRunning= */ true);
+ }
+
+ // assert
+ assertThat(cancelled.get()).isTrue();
+ }
+
+ @Test
public void testDownloadEmptyBlob() throws Exception {
GrpcCacheClient client = newClient();
Digest emptyDigest = DIGEST_UTIL.compute(new byte[0]);
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java
index 39c0aed..55b5e7c 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java
@@ -142,6 +142,27 @@
}
@Test
+ public void downloadFile_cancelled_cancelDownload() throws Exception {
+ // Test that if a download future is cancelled, the download itself is also cancelled.
+
+ // arrange
+ RemoteCacheClient remoteCacheClient = mock(RemoteCacheClient.class);
+ SettableFuture<Void> future = SettableFuture.create();
+ // Return a future that never completes
+ doAnswer(invocationOnMock -> future).when(remoteCacheClient).downloadBlob(any(), any(), any());
+ RemoteCache remoteCache = newRemoteCache(remoteCacheClient);
+ Digest digest = fakeFileCache.createScratchInput(ActionInputHelper.fromPath("file"), "content");
+ Path file = execRoot.getRelative("file");
+
+ // act
+ ListenableFuture<Void> download = remoteCache.downloadFile(context, file, digest);
+ download.cancel(/* mayInterruptIfRunning= */ true);
+
+ // assert
+ assertThat(future.isCancelled()).isTrue();
+ }
+
+ @Test
public void downloadOutErr_empty_doNotPerformDownload() throws Exception {
// Test that downloading empty stdout/stderr does not try to perform a download.
@@ -219,8 +240,7 @@
})
.when(remoteCacheClient)
.uploadFile(any(), any(), any());
- RemoteCache remoteCache =
- new RemoteCache(remoteCacheClient, Options.getDefaults(RemoteOptions.class), digestUtil);
+ RemoteCache remoteCache = newRemoteCache(remoteCacheClient);
Digest digest = fakeFileCache.createScratchInput(ActionInputHelper.fromPath("file"), "content");
Path file = execRoot.getRelative("file");
@@ -253,8 +273,7 @@
invocationOnMock.getArgument(0), invocationOnMock.getArgument(1)))
.when(remoteCacheClient)
.findMissingDigests(any(), any());
- RemoteCache remoteCache =
- new RemoteCache(remoteCacheClient, Options.getDefaults(RemoteOptions.class), digestUtil);
+ RemoteCache remoteCache = newRemoteCache(remoteCacheClient);
Digest digest = fakeFileCache.createScratchInput(ActionInputHelper.fromPath("file"), "content");
Path file = execRoot.getRelative("file");
assertThat(getFromFuture(remoteCache.findMissingDigests(context, ImmutableList.of(digest))))
@@ -326,8 +345,7 @@
doAnswer(invocationOnMock -> SettableFuture.create())
.when(remoteCacheClient)
.uploadFile(any(), any(), any());
- RemoteCache remoteCache =
- new RemoteCache(remoteCacheClient, Options.getDefaults(RemoteOptions.class), digestUtil);
+ RemoteCache remoteCache = newRemoteCache(remoteCacheClient);
Digest digest = fakeFileCache.createScratchInput(ActionInputHelper.fromPath("file"), "content");
Path file = execRoot.getRelative("file");
@@ -342,4 +360,8 @@
RemoteOptions options = Options.getDefaults(RemoteOptions.class);
return new InMemoryRemoteCache(options, digestUtil);
}
+
+ private RemoteCache newRemoteCache(RemoteCacheClient remoteCacheClient) {
+ return new RemoteCache(remoteCacheClient, Options.getDefaults(RemoteOptions.class), digestUtil);
+ }
}