Restructure blob upload code.

The previous blob upload code had a lot of nesting and state shared between Future callbacks, which made it hard to follow and modify. There is some fundamental complexity in the problem this code solves: a retried upload can resume from the last committed offset rather than always restarting from the beginning of the blob. I hope this rewrite makes the code easier to understand and paves the way for some future behavior changes I want to make in this area.
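
At a glance, each retry attempt is now a single AsyncCallable that queries the server for the committed offset, seeks the chunker there, and uploads the remainder. A simplified sketch of that shape (queryCommittedOffset and uploadRemaining are hypothetical stand-ins for the real ByteStream RPCs):

    // Sketch only; mirrors the shape of AsyncUpload.call() but elides error handling.
    ListenableFuture<Long> attempt() {
      boolean firstAttempt = lastCommittedOffset == -1;
      return Futures.transformAsync(
          firstAttempt ? Futures.immediateFuture(0L) : queryCommittedOffset(),
          offset -> {
            chunker.seek(offset);         // resume mid-blob instead of restarting
            lastCommittedOffset = offset;
            return uploadRemaining();     // resolves to the final committed size
          },
          MoreExecutors.directExecutor());
    }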

Closes #15514.

PiperOrigin-RevId: 454588429
Change-Id: I9bc9c6942534053f19368fe4240dfe074ce8e60d
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
index e5b817a..34d1213 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
@@ -23,6 +23,7 @@
 import build.bazel.remote.execution.v2.Digest;
 import com.google.bytestream.ByteStreamGrpc;
 import com.google.bytestream.ByteStreamGrpc.ByteStreamFutureStub;
+import com.google.bytestream.ByteStreamGrpc.ByteStreamStub;
 import com.google.bytestream.ByteStreamProto.QueryWriteStatusRequest;
 import com.google.bytestream.ByteStreamProto.QueryWriteStatusResponse;
 import com.google.bytestream.ByteStreamProto.WriteRequest;
@@ -30,7 +31,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Ascii;
 import com.google.common.base.Strings;
-import com.google.common.flogger.GoogleLogger;
+import com.google.common.util.concurrent.AsyncCallable;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
@@ -40,13 +41,14 @@
 import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
 import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
 import com.google.devtools.build.lib.remote.util.Utils;
-import io.grpc.CallOptions;
 import io.grpc.Channel;
-import io.grpc.ClientCall;
-import io.grpc.Metadata;
+import io.grpc.Context;
+import io.grpc.Context.CancellableContext;
 import io.grpc.Status;
 import io.grpc.Status.Code;
 import io.grpc.StatusRuntimeException;
+import io.grpc.stub.ClientCallStreamObserver;
+import io.grpc.stub.ClientResponseObserver;
 import io.netty.util.ReferenceCounted;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -54,7 +56,6 @@
 import java.util.Map;
 import java.util.UUID;
 import java.util.concurrent.Semaphore;
-import java.util.concurrent.atomic.AtomicLong;
 import javax.annotation.Nullable;
 
 /**
@@ -66,9 +67,6 @@
  * <p>See {@link ReferenceCounted} for more information on reference counting.
  */
 class ByteStreamUploader {
-
-  private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
-
   private final String instanceName;
   private final ReferenceCountedChannel channel;
   private final CallCredentialsProvider callCredentialsProvider;
@@ -211,15 +209,6 @@
     UUID uploadId = UUID.randomUUID();
     String resourceName =
         buildUploadResourceName(instanceName, uploadId, digest, chunker.isCompressed());
-    AsyncUpload newUpload =
-        new AsyncUpload(
-            context,
-            channel,
-            callCredentialsProvider,
-            callTimeoutSecs,
-            retrier,
-            resourceName,
-            chunker);
     if (openedFilePermits != null) {
       try {
         openedFilePermits.acquire();
@@ -230,19 +219,28 @@
                     + e.getMessage()));
       }
     }
+    AsyncUpload newUpload =
+        new AsyncUpload(
+            context,
+            channel,
+            callCredentialsProvider,
+            callTimeoutSecs,
+            retrier,
+            resourceName,
+            chunker);
     ListenableFuture<Void> currUpload = newUpload.start();
     currUpload.addListener(
         () -> {
-          if (currUpload.isCancelled()) {
-            newUpload.cancel();
+          newUpload.cancel();
+          if (openedFilePermits != null) {
+            openedFilePermits.release();
           }
         },
         MoreExecutors.directExecutor());
     return currUpload;
   }
 
-  private class AsyncUpload {
-
+  private static final class AsyncUpload implements AsyncCallable<Long> {
     private final RemoteActionExecutionContext context;
     private final ReferenceCountedChannel channel;
     private final CallCredentialsProvider callCredentialsProvider;
@@ -250,8 +248,10 @@
     private final Retrier retrier;
     private final String resourceName;
     private final Chunker chunker;
+    private final ProgressiveBackoff progressiveBackoff;
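+    // Cancelling this context (see cancel()) aborts any query or write RPC running under it.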
+    private final CancellableContext grpcContext;
 
-    private ClientCall<WriteRequest, WriteResponse> call;
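+    // Committed offset reported by the server on the last attempt; -1 before the first attempt.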
+    private long lastCommittedOffset = -1;
 
     AsyncUpload(
         RemoteActionExecutionContext context,
@@ -266,95 +266,110 @@
       this.callCredentialsProvider = callCredentialsProvider;
       this.callTimeoutSecs = callTimeoutSecs;
       this.retrier = retrier;
+      this.progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
       this.resourceName = resourceName;
       this.chunker = chunker;
+      this.grpcContext = Context.current().withCancellation();
     }
 
     ListenableFuture<Void> start() {
-      ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
-      AtomicLong committedOffset = new AtomicLong(0);
-
-      ListenableFuture<Void> callFuture =
-          Utils.refreshIfUnauthenticatedAsync(
-              () ->
-                  retrier.executeAsync(
-                      () -> {
-                        if (chunker.getSize() == committedOffset.get()) {
-                          return immediateVoidFuture();
-                        }
-                        try {
-                          chunker.seek(committedOffset.get());
-                        } catch (IOException e) {
-                          try {
-                            chunker.reset();
-                          } catch (IOException resetException) {
-                            e.addSuppressed(resetException);
-                          }
-                          String tooManyOpenFilesError = "Too many open files";
-                          if (Ascii.toLowerCase(e.getMessage())
-                              .contains(Ascii.toLowerCase(tooManyOpenFilesError))) {
-                            String newMessage =
-                                "An IOException was thrown because the process opened too many"
-                                    + " files. We recommend setting"
-                                    + " --bep_maximum_open_remote_upload_files flag to a number"
-                                    + " lower than your system default (run 'ulimit -a' for"
-                                    + " *nix-based operating systems). Original error message: "
-                                    + e.getMessage();
-                            return Futures.immediateFailedFuture(new IOException(newMessage, e));
-                          }
-                          return Futures.immediateFailedFuture(e);
-                        }
-                        if (chunker.hasNext()) {
-                          return callAndQueryOnFailure(committedOffset, progressiveBackoff);
-                        }
-                        return immediateVoidFuture();
-                      },
-                      progressiveBackoff),
-              callCredentialsProvider);
-      if (openedFilePermits != null) {
-        callFuture.addListener(openedFilePermits::release, MoreExecutors.directExecutor());
-      }
       return Futures.transformAsync(
-          callFuture,
-          (result) -> {
-            if (!chunker.hasNext()) {
-              // Only check for matching committed size if we have completed the upload.
-              // If another client did, they might have used a different compression
-              // level/algorithm, so we cannot know the expected committed offset
-              long committedSize = committedOffset.get();
-              long expected = chunker.getOffset();
-
-              if (committedSize == expected) {
-                // Both compressed and uncompressed uploads can succeed
-                // with this result.
-                return immediateVoidFuture();
-              }
-
-              if (chunker.isCompressed()) {
-                if (committedSize == -1) {
-                  // Returned early, blob already available.
-                  return immediateVoidFuture();
-                }
-
-                String message =
-                    format(
-                        "compressed write incomplete: committed_size %d is neither -1 nor total %d",
-                        committedSize, expected);
-                return Futures.immediateFailedFuture(new IOException(message));
-              }
-
-              // Uncompressed upload failed.
-              String message =
-                  format(
-                      "write incomplete: committed_size %d for %d total", committedSize, expected);
-              return Futures.immediateFailedFuture(new IOException(message));
+          Utils.refreshIfUnauthenticatedAsync(
+              () -> retrier.executeAsync(this, progressiveBackoff), callCredentialsProvider),
+          committedSize -> {
+            try {
+              checkCommittedSize(committedSize);
+            } catch (IOException e) {
+              return Futures.immediateFailedFuture(e);
             }
-
             return immediateVoidFuture();
           },
           MoreExecutors.directExecutor());
     }
 
+    private void checkCommittedSize(long committedSize) throws IOException {
+      // Only check for matching committed size if we have completed the upload. If another client
+      // did, they might have used a different compression level/algorithm, so we cannot know the
+      // expected committed offset.
+      if (chunker.hasNext()) {
+        return;
+      }
+
+      long expected = chunker.getOffset();
+
+      if (committedSize == expected) {
+        // Both compressed and uncompressed uploads can succeed with this result.
+        return;
+      }
+
+      if (chunker.isCompressed()) {
+        if (committedSize == -1) {
+          // Returned early, blob already available.
+          return;
+        }
+
+        throw new IOException(
+            format(
+                "compressed write incomplete: committed_size %d is" + " neither -1 nor total %d",
+                committedSize, expected));
+      }
+
+      // Uncompressed upload failed.
+      throw new IOException(
+          format("write incomplete: committed_size %d for %d total", committedSize, expected));
+    }
+
+    /**
+     * Make one attempt to upload. If this is the first attempt, uploading starts from the beginning
+     * of the blob. On later attempts, the server is queried to see at which offset upload should
+     * resume. The final committed size from the server is returned on success.
+     */
+    @Override
+    public ListenableFuture<Long> call() {
+      boolean firstAttempt = lastCommittedOffset == -1;
+      return Futures.transformAsync(
+          firstAttempt ? Futures.immediateFuture(0L) : query(),
+          committedSize -> {
+            if (!firstAttempt) {
+              if (chunker.getSize() == committedSize) {
+                return Futures.immediateFuture(committedSize);
+              }
+              if (committedSize > lastCommittedOffset) {
+                // We have made progress on this upload in the last request. Reset the backoff so
+                // that this request has a full deck of retries.
+                progressiveBackoff.reset();
+              }
+            }
+            lastCommittedOffset = committedSize;
+            try {
+              chunker.seek(committedSize);
+            } catch (IOException e) {
+              try {
+                chunker.reset();
+              } catch (IOException resetException) {
+                e.addSuppressed(resetException);
+              }
+              String tooManyOpenFilesError = "Too many open files";
+              if (Ascii.toLowerCase(e.getMessage())
+                  .contains(Ascii.toLowerCase(tooManyOpenFilesError))) {
+                String newMessage =
+                    "An IOException was thrown because the process opened too"
+                        + " many files. We recommend setting"
+                        + " --bep_maximum_open_remote_upload_files flag to a"
+                        + " number lower than your system default (run 'ulimit"
+                        + " -a' for *nix-based operating systems). Original"
+                        + " error message: "
+                        + e.getMessage();
+                return Futures.immediateFailedFuture(new IOException(newMessage, e));
+              }
+              return Futures.immediateFailedFuture(e);
+            }
+            return upload();
+          },
+          MoreExecutors.directExecutor());
+    }
+
     private ByteStreamFutureStub bsFutureStub(Channel channel) {
       return ByteStreamGrpc.newFutureStub(channel)
           .withInterceptors(
@@ -363,183 +378,126 @@
           .withDeadlineAfter(callTimeoutSecs, SECONDS);
     }
 
-    private ListenableFuture<Void> callAndQueryOnFailure(
-        AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
-      return Futures.catchingAsync(
-          Futures.transform(
-              channel.withChannelFuture(channel -> call(committedOffset, channel)),
-              written -> null,
-              MoreExecutors.directExecutor()),
-          Exception.class,
-          (e) -> guardQueryWithSuppression(e, committedOffset, progressiveBackoff),
-          MoreExecutors.directExecutor());
+    private ByteStreamStub bsAsyncStub(Channel channel) {
+      return ByteStreamGrpc.newStub(channel)
+          .withInterceptors(
+              TracingMetadataUtils.attachMetadataInterceptor(context.getRequestMetadata()))
+          .withCallCredentials(callCredentialsProvider.getCallCredentials())
+          .withDeadlineAfter(callTimeoutSecs, SECONDS);
     }
 
-    private ListenableFuture<Void> guardQueryWithSuppression(
-        Exception e, AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
-      // we are destined to return this, avoid recreating it
-      ListenableFuture<Void> exceptionFuture = Futures.immediateFailedFuture(e);
-
-      // TODO(buchgr): we should also return immediately without the query if
-      // we were out of retry attempts for the underlying backoff. This
-      // is meant to be an only in-between-retries query request.
-      if (!retrier.isRetriable(e)) {
-        return exceptionFuture;
-      }
-
-      ListenableFuture<Void> suppressedQueryFuture =
-          Futures.catchingAsync(
-              query(committedOffset, progressiveBackoff),
-              Exception.class,
-              (queryException) -> {
-                // if the query threw an exception, add it to the suppressions
-                // for the destined exception
-                e.addSuppressed(queryException);
-                return exceptionFuture;
-              },
-              MoreExecutors.directExecutor());
-      return Futures.transformAsync(
-          suppressedQueryFuture, (result) -> exceptionFuture, MoreExecutors.directExecutor());
-    }
-
-    private ListenableFuture<Void> query(
-        AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
+    private ListenableFuture<Long> query() {
       ListenableFuture<Long> committedSizeFuture =
           Futures.transform(
               channel.withChannelFuture(
                   channel ->
-                      bsFutureStub(channel)
-                          .queryWriteStatus(
-                              QueryWriteStatusRequest.newBuilder()
-                                  .setResourceName(resourceName)
-                                  .build())),
+                      grpcContext.call(
+                          () ->
+                              bsFutureStub(channel)
+                                  .queryWriteStatus(
+                                      QueryWriteStatusRequest.newBuilder()
+                                          .setResourceName(resourceName)
+                                          .build()))),
               QueryWriteStatusResponse::getCommittedSize,
               MoreExecutors.directExecutor());
-      ListenableFuture<Long> guardedCommittedSizeFuture =
-          Futures.catchingAsync(
-              committedSizeFuture,
-              Exception.class,
-              (e) -> {
-                Status status = Status.fromThrowable(e);
-                if (status.getCode() == Code.UNIMPLEMENTED) {
-                  // if the bytestream server does not implement the query, insist
-                  // that we should reset the upload
-                  return Futures.immediateFuture(0L);
-                }
-                return Futures.immediateFailedFuture(e);
-              },
-              MoreExecutors.directExecutor());
-      return Futures.transformAsync(
-          guardedCommittedSizeFuture,
-          (committedSize) -> {
-            if (committedSize > committedOffset.get()) {
-              // we have made progress on this upload in the last request,
-              // reset the backoff so that this request has a full deck of retries
-              progressiveBackoff.reset();
+      return Futures.catchingAsync(
+          committedSizeFuture,
+          Exception.class,
+          (e) -> {
+            Status status = Status.fromThrowable(e);
+            if (status.getCode() == Code.UNIMPLEMENTED) {
+              // If the bytestream server does not implement the query, insist
+              // that we should reset the upload.
+              return Futures.immediateFuture(0L);
             }
-            committedOffset.set(committedSize);
-            return immediateVoidFuture();
+            return Futures.immediateFailedFuture(e);
           },
           MoreExecutors.directExecutor());
     }
 
-    private ListenableFuture<Long> call(AtomicLong committedOffset, Channel channel) {
-      CallOptions callOptions =
-          CallOptions.DEFAULT
-              .withCallCredentials(callCredentialsProvider.getCallCredentials())
-              .withDeadlineAfter(callTimeoutSecs, SECONDS);
-      call = channel.newCall(ByteStreamGrpc.getWriteMethod(), callOptions);
-
-      SettableFuture<Long> uploadResult = SettableFuture.create();
-      ClientCall.Listener<WriteResponse> callListener =
-          new ClientCall.Listener<WriteResponse>() {
-
-            private final WriteRequest.Builder requestBuilder = WriteRequest.newBuilder();
-            private boolean callHalfClosed = false;
-
-            void halfClose() {
-              // call.halfClose() may only be called once. Guard against it being called more
-              // often.
-              // See: https://github.com/grpc/grpc-java/issues/3201
-              if (!callHalfClosed) {
-                callHalfClosed = true;
-                // Every chunk has been written. No more work to do.
-                call.halfClose();
-              }
-            }
-
-            @Override
-            public void onMessage(WriteResponse response) {
-              // upload was completed either by us or someone else
-              committedOffset.set(response.getCommittedSize());
-              halfClose();
-            }
-
-            @Override
-            public void onClose(Status status, Metadata trailers) {
-              if (status.isOk()) {
-                uploadResult.set(committedOffset.get());
-              } else {
-                uploadResult.setException(status.asRuntimeException());
-              }
-            }
-
-            @Override
-            public void onReady() {
-              while (call.isReady()) {
-                if (!chunker.hasNext()) {
-                  halfClose();
-                  return;
-                }
-
-                if (callHalfClosed) {
-                  return;
-                }
-
-                try {
-                  requestBuilder.clear();
-                  Chunker.Chunk chunk = chunker.next();
-
-                  if (chunk.getOffset() == committedOffset.get()) {
-                    // Resource name only needs to be set on the first write for each file.
-                    requestBuilder.setResourceName(resourceName);
-                  }
-
-                  boolean isLastChunk = !chunker.hasNext();
-                  WriteRequest request =
-                      requestBuilder
-                          .setData(chunk.getData())
-                          .setWriteOffset(chunk.getOffset())
-                          .setFinishWrite(isLastChunk)
-                          .build();
-
-                  call.sendMessage(request);
-                } catch (IOException e) {
-                  try {
-                    chunker.reset();
-                  } catch (IOException e1) {
-                    // This exception indicates that closing the underlying input stream failed.
-                    // We don't expect this to ever happen, but don't want to swallow the exception
-                    // completely.
-                    logger.atWarning().withCause(e1).log("Chunker failed closing data source.");
-                  } finally {
-                    call.cancel("Failed to read next chunk.", e);
-                  }
-                }
-              }
-            }
-          };
-      call.start(
-          callListener,
-          TracingMetadataUtils.headersFromRequestMetadata(context.getRequestMetadata()));
-      call.request(1);
-      return uploadResult;
+    private ListenableFuture<Long> upload() {
+      return channel.withChannelFuture(
+          channel -> {
+            SettableFuture<Long> uploadResult = SettableFuture.create();
+            grpcContext.run(
+                () -> bsAsyncStub(channel).write(new Writer(resourceName, chunker, uploadResult)));
+            return uploadResult;
+          });
     }
 
     void cancel() {
-      if (call != null) {
-        call.cancel("Cancelled by user.", null);
+      grpcContext.cancel(
+          Status.CANCELLED.withDescription("Cancelled by user").asRuntimeException());
+    }
+  }
+
+  private static final class Writer
+      implements ClientResponseObserver<WriteRequest, WriteResponse>, Runnable {
+    private final Chunker chunker;
+    private final String resourceName;
+    private final SettableFuture<Long> uploadResult;
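+    // Committed size from the server's WriteResponse; -1 until a response arrives.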
+    private long committedSize = -1;
+    private ClientCallStreamObserver<WriteRequest> requestObserver;
+    private boolean first = true;
+
+    private Writer(String resourceName, Chunker chunker, SettableFuture<Long> uploadResult) {
+      this.resourceName = resourceName;
+      this.chunker = chunker;
+      this.uploadResult = uploadResult;
+    }
+
+    @Override
+    public void beforeStart(ClientCallStreamObserver<WriteRequest> requestObserver) {
+      this.requestObserver = requestObserver;
+      requestObserver.setOnReadyHandler(this);
+    }
+
+    @Override
+    public void run() {
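+      // committedSize != -1 means a WriteResponse already arrived: the server has the complete
+      // blob, so there is nothing left to send.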
+      if (committedSize != -1) {
+        requestObserver.cancel("server has returned early", null);
+        return;
       }
+      while (requestObserver.isReady()) {
+        Chunker.Chunk chunk;
+        try {
+          chunk = chunker.next();
+        } catch (IOException e) {
+          requestObserver.cancel("Failed to read next chunk.", e);
+          return;
+        }
+        boolean isLastChunk = !chunker.hasNext();
+        WriteRequest.Builder request =
+            WriteRequest.newBuilder()
+                .setData(chunk.getData())
+                .setWriteOffset(chunk.getOffset())
+                .setFinishWrite(isLastChunk);
+        if (first) {
+          first = false;
+          // Resource name only needs to be set on the first write for each file.
+          request.setResourceName(resourceName);
+        }
+        requestObserver.onNext(request.build());
+        if (isLastChunk) {
+          requestObserver.onCompleted();
+          return;
+        }
+      }
+    }
+
+    @Override
+    public void onNext(WriteResponse response) {
+      committedSize = response.getCommittedSize();
+    }
+
+    @Override
+    public void onCompleted() {
+      uploadResult.set(committedSize);
+    }
+
+    @Override
+    public void onError(Throwable t) {
+      uploadResult.setException(t);
     }
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
index ac1421d..de2ff4d 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -353,8 +353,9 @@
 
     uploader.uploadBlob(context, digest, chunker);
 
-    // This test should not have triggered any retries.
-    Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
+    // This test triggers one retry.
+    Mockito.verify(mockBackoff, Mockito.times(1))
+        .nextDelayMillis(any(StatusRuntimeException.class));
     Mockito.verify(mockBackoff, Mockito.times(1)).getRetryAttempts();
   }
 
@@ -476,8 +477,8 @@
 
     uploader.uploadBlob(context, digest, chunker);
 
-    // This test should not have triggered any retries.
-    Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
+    // This test triggers one retry.
+    Mockito.verify(mockBackoff, Mockito.times(1)).nextDelayMillis(any(Exception.class));
     Mockito.verify(mockBackoff, Mockito.times(1)).getRetryAttempts();
   }
 
@@ -703,7 +704,7 @@
   }
 
   @Test
-  public void incorrectCommittedSizeDoesNotFailsIncompleteUpload() throws Exception {
+  public void incorrectCommittedSizeDoesNotFailIncompleteUpload() throws Exception {
     RemoteRetrier retrier =
         TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
     ByteStreamUploader uploader =