remote: Rewrite the ByteStream upload.

The current ByteStream upload implementation has no support for application-level
flow control, which can result in excessive buffering and out-of-memory (OOM) errors.

The new implementation respects gRPC's flow control.

Additionally, this change deduplicates concurrent uploads of the same
digest. That is, if a digest (i.e. file) is uploaded several times
concurrently, only one upload is performed.
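
A minimal usage sketch (illustrative only; the channel, the Retrier, and
the blob bytes are assumed to exist):

  ListeningScheduledExecutorService retryService =
      MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(1));
  ByteStreamUploader uploader = new ByteStreamUploader(
      "my-instance", channel, /*callCredentials=*/ null,
      /*callTimeoutSecs=*/ 60, retrier, retryService);
  uploader.uploadBlob(new Chunker.SingleSourceBuilder().input(blob));
  uploader.shutdown();
  retryService.shutdownNow();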

RELNOTES: None.
PiperOrigin-RevId: 161287337
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
new file mode 100644
index 0000000..19b8929
--- /dev/null
+++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
@@ -0,0 +1,420 @@
+// Copyright 2017 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote;
+
+import static com.google.devtools.build.lib.util.Preconditions.checkArgument;
+import static com.google.devtools.build.lib.util.Preconditions.checkNotNull;
+import static com.google.devtools.build.lib.util.Preconditions.checkState;
+import static java.lang.String.format;
+import static java.util.Collections.singletonList;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
+
+import com.google.bytestream.ByteStreamGrpc;
+import com.google.bytestream.ByteStreamProto.WriteRequest;
+import com.google.bytestream.ByteStreamProto.WriteResponse;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Strings;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListenableScheduledFuture;
+import com.google.common.util.concurrent.ListeningScheduledExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.SettableFuture;
+import com.google.devtools.remoteexecution.v1test.Digest;
+import com.google.protobuf.ByteString;
+import io.grpc.CallCredentials;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.StatusException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.RejectedExecutionException;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+
+/**
+ * A client implementing the {@code Write} method of the {@code ByteStream} gRPC service.
+ *
+ * <p>Users must call {@link #shutdown()} before exiting.
+ */
+final class ByteStreamUploader {
+
+  private final String instanceName;
+  private final Channel channel;
+  private final CallCredentials callCredentials;
+  private final long callTimeoutSecs;
+  private final Retrier retrier;
+  private final ListeningScheduledExecutorService retryService;
+
+  private final Object lock = new Object();
+
+  @GuardedBy("lock")
+  private final Map<Digest, ListenableFuture<Void>> uploadsInProgress = new HashMap<>();
+
+  @GuardedBy("lock")
+  private boolean isShutdown;
+
+  /**
+   * Creates a new instance.
+   *
+   * @param instanceName the instance name to be prepended to the resource name of the {@code
+   *     Write} call. See the {@code ByteStream} service definition for details
+   * @param channel the {@link io.grpc.Channel} to use for calls
+   * @param callCredentials the credentials to use for authentication. May be {@code null}, in which
+   *     case no authentication is performed
+   * @param callTimeoutSecs the timeout in seconds after which a {@code Write} gRPC call must be
+   *     complete. The timeout resets between retries
+   * @param retrier the {@link Retrier} whose backoff strategy determines the retry timings
+   * @param retryService the executor service to schedule retries on. It's the responsibility of the
+   *     caller to properly shut down the service after use. Users should avoid shutting down the
+   *     service before {@link #shutdown()} has been called
+   */
+  public ByteStreamUploader(
+      @Nullable String instanceName,
+      Channel channel,
+      @Nullable CallCredentials callCredentials,
+      long callTimeoutSecs,
+      Retrier retrier,
+      ListeningScheduledExecutorService retryService) {
+    checkArgument(callTimeoutSecs > 0, "callTimeoutSecs must be greater than 0.");
+
+    this.instanceName = instanceName;
+    this.channel = channel;
+    this.callCredentials = callCredentials;
+    this.callTimeoutSecs = callTimeoutSecs;
+    this.retrier = retrier;
+    this.retryService = retryService;
+  }
+
+  /**
+   * Uploads a BLOB, as provided by the {@link Chunker.SingleSourceBuilder}, to the remote {@code
+   * ByteStream} service. The call blocks until the upload is complete, or throws an {@link
+   * Exception} in case of error.
+   *
+   * <p>Uploads are retried according to the specified {@link Retrier}. Retrying is transparent to
+   * the user of this API.
+   *
+   * <p>Trying to upload the same BLOB multiple times concurrently results in only one upload being
+   * performed. This is transparent to the user of this API.
+   *
+   * @throws IOException when reading of the {@link Chunker}'s input source fails
+   * @throws RetryException when the upload failed after exhausting retries, or on a non-retriable
+   *     error
+   */
+  public void uploadBlob(Chunker.SingleSourceBuilder chunkerBuilder)
+      throws IOException, InterruptedException {
+    uploadBlobs(singletonList(chunkerBuilder));
+  }
+
+  /**
+   * Uploads a list of BLOBs concurrently to the remote {@code ByteStream} service. The call blocks
+   * until the upload of all BLOBs is complete, or throws an {@link Exception} after the first
+   * upload fails. Any other uploads continue in the background until they complete or
+   * {@link #shutdown()} is called. Errors encountered by these background uploads are swallowed.
+   *
+   * <p>Uploads are retried according to the specified {@link Retrier}. Retrying is transparent to
+   * the user of this API.
+   *
+   * <p>Trying to upload the same BLOB multiple times concurrently results in only one upload being
+   * performed. This is transparent to the user of this API.
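+   *
+   * <p>Example usage (a sketch; {@code uploader}, {@code blob1}, and {@code blob2} are assumed):
+   *
+   * <pre>{@code
+   * uploader.uploadBlobs(
+   *     Arrays.asList(
+   *         new Chunker.SingleSourceBuilder().input(blob1),
+   *         new Chunker.SingleSourceBuilder().input(blob2)));
+   * }</pre>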
+   *
+   * @throws IOException when reading of the {@link Chunker}'s input source fails
+   * @throws RetryException when the upload failed after exhausting retries, or on a non-retriable
+   *     error
+   */
+  public void uploadBlobs(Iterable<Chunker.SingleSourceBuilder> chunkerBuilders)
+      throws IOException, InterruptedException {
+    List<ListenableFuture<Void>> uploads = new ArrayList<>();
+
+    for (Chunker.SingleSourceBuilder chunkerBuilder : chunkerBuilders) {
+      uploads.add(uploadBlobAsync(chunkerBuilder));
+    }
+
+    try {
+      for (ListenableFuture<Void> upload : uploads) {
+        upload.get();
+      }
+    } catch (ExecutionException e) {
+      Throwable cause = e.getCause();
+      if (cause instanceof RetryException) {
+        throw (RetryException) cause;
+      } else {
+        throw Throwables.propagate(cause);
+      }
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw e;
+    }
+  }
+
+  /**
+   * Cancels all running uploads. The method returns immediately and does NOT wait for the uploads
+   * to be cancelled.
+   *
+   * <p>This method must be the last method called.
+   */
+  public void shutdown() {
+    synchronized (lock) {
+      if (isShutdown) {
+        return;
+      }
+      isShutdown = true;
+      // Before cancelling, copy the futures to a separate list in order to avoid concurrently
+      // iterating over and modifying the map (cancel() triggers a listener that removes the
+      // entry from the map; the listener is executed in the same thread).
+      List<Future<Void>> uploadsToCancel = new ArrayList<>(uploadsInProgress.values());
+      for (Future<Void> upload : uploadsToCancel) {
+        upload.cancel(true);
+      }
+    }
+  }
+
+  @VisibleForTesting
+  ListenableFuture<Void> uploadBlobAsync(Chunker.SingleSourceBuilder chunkerBuilder)
+      throws IOException {
+    Digest digest = checkNotNull(chunkerBuilder.getDigest());
+
+    synchronized (lock) {
+      checkState(!isShutdown, "Must not call uploadBlobs after shutdown.");
+
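+      // Deduplicate concurrent uploads: if an upload of this digest is already in progress,
+      // return its future instead of starting a second upload.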
+      ListenableFuture<Void> uploadResult = uploadsInProgress.get(digest);
+      if (uploadResult == null) {
+        uploadResult = SettableFuture.create();
+        uploadResult.addListener(
+            () -> {
+              synchronized (lock) {
+                uploadsInProgress.remove(digest);
+              }
+            },
+            MoreExecutors.directExecutor());
+        startAsyncUploadWithRetry(
+            chunkerBuilder, retrier.newBackoff(), (SettableFuture<Void>) uploadResult);
+        uploadsInProgress.put(digest, uploadResult);
+      }
+      return uploadResult;
+    }
+  }
+
+  @VisibleForTesting
+  boolean uploadsInProgress() {
+    synchronized (lock) {
+      return !uploadsInProgress.isEmpty();
+    }
+  }
+
+  private void startAsyncUploadWithRetry(
+      Chunker.SingleSourceBuilder chunkerBuilder,
+      Retrier.Backoff backoffTimes,
+      SettableFuture<Void> overallUploadResult) {
+
+    AsyncUpload.Listener listener =
+        new AsyncUpload.Listener() {
+          @Override
+          public void success() {
+            overallUploadResult.set(null);
+          }
+
+          @Override
+          public void failure(Status status) {
+            StatusException cause = status.asException();
+            long nextDelayMillis = backoffTimes.nextDelayMillis();
+            if (nextDelayMillis < 0 || !retrier.isRetriable(status)) {
+              // Out of retries or status not retriable.
+              RetryException error = new RetryException(cause, backoffTimes.getRetryAttempts());
+              overallUploadResult.setException(error);
+            } else {
+              retryAsyncUpload(nextDelayMillis, chunkerBuilder, backoffTimes, overallUploadResult);
+            }
+          }
+
+          private void retryAsyncUpload(
+              long nextDelayMillis,
+              Chunker.SingleSourceBuilder chunkerBuilder,
+              Retrier.Backoff backoffTimes,
+              SettableFuture<Void> overallUploadResult) {
+            try {
+              ListenableScheduledFuture<?> schedulingResult =
+                  retryService.schedule(
+                      () ->
+                          startAsyncUploadWithRetry(
+                              chunkerBuilder, backoffTimes, overallUploadResult),
+                      nextDelayMillis,
+                      MILLISECONDS);
+              // In case the scheduled execution errors, we need to notify the overallUploadResult.
+              schedulingResult.addListener(
+                  () -> {
+                    try {
+                      schedulingResult.get();
+                    } catch (Exception e) {
+                      overallUploadResult.setException(
+                          new RetryException(e, backoffTimes.getRetryAttempts()));
+                    }
+                  },
+                  MoreExecutors.directExecutor());
+            } catch (RejectedExecutionException e) {
+              // May be thrown by schedule(...), e.g. when the executor has been shut down.
+              overallUploadResult.setException(
+                  new RetryException(e, backoffTimes.getRetryAttempts()));
+            }
+          }
+        };
+
+    Chunker chunker;
+    try {
+      chunker = chunkerBuilder.build();
+    } catch (IOException e) {
+      overallUploadResult.setException(e);
+      return;
+    }
+
+    AsyncUpload newUpload =
+        new AsyncUpload(channel, callCredentials, callTimeoutSecs, instanceName, chunker, listener);
+    overallUploadResult.addListener(
+        () -> {
+          if (overallUploadResult.isCancelled()) {
+            newUpload.cancel();
+          }
+        },
+        MoreExecutors.directExecutor());
+    newUpload.start();
+  }
+
+  private static class AsyncUpload {
+
+    interface Listener {
+      void success();
+
+      void failure(Status status);
+    }
+
+    private final Channel channel;
+    private final CallCredentials callCredentials;
+    private final long callTimeoutSecs;
+    private final String instanceName;
+    private final Chunker chunker;
+    private final Listener listener;
+
+    private ClientCall<WriteRequest, WriteResponse> call;
+
+    AsyncUpload(
+        Channel channel,
+        CallCredentials callCredentials,
+        long callTimeoutSecs,
+        String instanceName,
+        Chunker chunker,
+        Listener listener) {
+      this.channel = channel;
+      this.callCredentials = callCredentials;
+      this.callTimeoutSecs = callTimeoutSecs;
+      this.instanceName = instanceName;
+      this.chunker = chunker;
+      this.listener = listener;
+    }
+
+    void start() {
+      CallOptions callOptions =
+          CallOptions.DEFAULT
+              .withCallCredentials(callCredentials)
+              .withDeadlineAfter(callTimeoutSecs, SECONDS);
+      call = channel.newCall(ByteStreamGrpc.METHOD_WRITE, callOptions);
+
+      ClientCall.Listener<WriteResponse> callListener =
+          new ClientCall.Listener<WriteResponse>() {
+
+            private final WriteRequest.Builder requestBuilder = WriteRequest.newBuilder();
+            private boolean callHalfClosed = false;
+
+            @Override
+            public void onMessage(WriteResponse response) {
+              // TODO(buchgr): The ByteStream API allows resuming an upload at committedSize.
+            }
+
+            @Override
+            public void onClose(Status status, Metadata trailers) {
+              if (!status.isOk()) {
+                listener.failure(status);
+              } else {
+                listener.success();
+              }
+            }
+
+            @Override
+            public void onReady() {
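+              // onReady() fires when gRPC's transport has room in its send buffer. isReady()
+              // turns false once that buffer is full, so this loop produces chunks only as
+              // fast as the receiver consumes them, which is what bounds memory use.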
+              while (call.isReady()) {
+                if (!chunker.hasNext()) {
+                  // call.halfClose() may only be called once. Guard against it being called more
+                  // than once.
+                  // See: https://github.com/grpc/grpc-java/issues/3201
+                  if (!callHalfClosed) {
+                    callHalfClosed = true;
+                    // Every chunk has been written. No more work to do.
+                    call.halfClose();
+                  }
+                  return;
+                }
+
+                try {
+                  requestBuilder.clear();
+                  Chunker.Chunk chunk = chunker.next();
+
+                  if (chunk.getOffset() == 0) {
+                    // Resource name only needs to be set on the first write for each file.
+                    requestBuilder.setResourceName(newResourceName(chunk.getDigest()));
+                  }
+
+                  boolean isLastChunk = !chunker.hasNext();
+                  WriteRequest request =
+                      requestBuilder
+                          .setData(ByteString.copyFrom(chunk.getData()))
+                          .setWriteOffset(chunk.getOffset())
+                          .setFinishWrite(isLastChunk)
+                          .build();
+
+                  call.sendMessage(request);
+                } catch (IOException e) {
+                  call.cancel("Failed to read next chunk.", e);
+                }
+              }
+            }
+
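+            // Upload resource names follow the ByteStream convention
+            // [{instanceName}/]uploads/{uuid}/blobs/{hash}/{size}.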
+            private String newResourceName(Digest digest) {
+              String resourceName =
+                  format(
+                      "uploads/%s/blobs/%s/%d",
+                      UUID.randomUUID(), digest.getHash(), digest.getSizeBytes());
+              if (!Strings.isNullOrEmpty(instanceName)) {
+                resourceName = instanceName + "/" + resourceName;
+              }
+              return resourceName;
+            }
+          };
+      call.start(callListener, new Metadata());
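+      // A Write call yields at most one WriteResponse, so requesting a single message suffices.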
+      call.request(1);
+    }
+
+    void cancel() {
+      if (call != null) {
+        call.cancel("Cancelled by user.", null);
+      }
+    }
+  }
+}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java
index 4446c6b..615e9f8 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java
@@ -16,18 +16,18 @@
 
 import com.google.bytestream.ByteStreamGrpc;
 import com.google.bytestream.ByteStreamGrpc.ByteStreamBlockingStub;
-import com.google.bytestream.ByteStreamGrpc.ByteStreamStub;
 import com.google.bytestream.ByteStreamProto.ReadRequest;
 import com.google.bytestream.ByteStreamProto.ReadResponse;
-import com.google.bytestream.ByteStreamProto.WriteRequest;
-import com.google.bytestream.ByteStreamProto.WriteResponse;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.ListeningScheduledExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.ActionInputFileCache;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
+import com.google.devtools.build.lib.remote.Chunker.SingleSourceBuilder;
 import com.google.devtools.build.lib.remote.Digests.ActionKey;
 import com.google.devtools.build.lib.remote.TreeNodeRepository.TreeNode;
 import com.google.devtools.build.lib.util.Preconditions;
@@ -53,23 +53,16 @@
 import com.google.protobuf.ByteString;
 import io.grpc.Channel;
 import io.grpc.Status;
-import io.grpc.StatusException;
 import io.grpc.StatusRuntimeException;
 import io.grpc.protobuf.StatusProto;
-import io.grpc.stub.StreamObserver;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Set;
-import java.util.UUID;
-import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
 
 /** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
 @ThreadSafe
@@ -79,12 +72,20 @@
   private final Channel channel;
   private final Retrier retrier;
 
+  private final ByteStreamUploader uploader;
+
+  private final ListeningScheduledExecutorService retryScheduler =
+      MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(1));
+
   @VisibleForTesting
   public GrpcRemoteCache(Channel channel, ChannelOptions channelOptions, RemoteOptions options) {
     this.options = options;
     this.channelOptions = channelOptions;
     this.channel = channel;
     this.retrier = new Retrier(options);
+
+    uploader = new ByteStreamUploader(options.remoteInstanceName, channel,
+        channelOptions.getCallCredentials(), options.remoteTimeout, retrier, retryScheduler);
   }
 
   private ContentAddressableStorageBlockingStub casBlockingStub() {
@@ -99,12 +100,6 @@
         .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
   }
 
-  private ByteStreamStub bsStub() {
-    return ByteStreamGrpc.newStub(channel)
-        .withCallCredentials(channelOptions.getCallCredentials())
-        .withDeadlineAfter(options.remoteTimeout, TimeUnit.SECONDS);
-  }
-
   private ActionCacheBlockingStub acBlockingStub() {
     return ActionCacheGrpc.newBlockingStub(channel)
         .withCallCredentials(channelOptions.getCallCredentials())
@@ -112,7 +107,10 @@
   }
 
   @Override
-  public void close() {}
+  public void close() {
+    retryScheduler.shutdownNow();
+    uploader.shutdown();
+  }
 
   public static boolean isRemoteCacheOptions(RemoteOptions options) {
     return options.remoteCache != null;
@@ -174,11 +172,18 @@
     }
     uploadBlob(command.toByteArray());
     if (!actionInputs.isEmpty()) {
-      uploadChunks(
-          actionInputs.size(),
-          new Chunker.Builder()
-              .addAllInputs(actionInputs, repository.getInputFileCache(), execRoot)
-              .onlyUseDigests(missingDigests));
+      List<Chunker.SingleSourceBuilder> inputsToUpload = new ArrayList<>();
+      ActionInputFileCache inputFileCache = repository.getInputFileCache();
+      for (ActionInput actionInput : actionInputs) {
+        Digest digest = Digests.getDigestFromInputCache(actionInput, inputFileCache);
+        if (missingDigests.contains(digest)) {
+          Chunker.SingleSourceBuilder builder =
+              new Chunker.SingleSourceBuilder().input(actionInput, inputFileCache, execRoot);
+          inputsToUpload.add(builder);
+        }
+      }
+
+      uploader.uploadBlobs(inputsToUpload);
     }
   }
 
@@ -317,7 +322,8 @@
   void upload(Path execRoot, Collection<Path> files, FileOutErr outErr, ActionResult.Builder result)
       throws IOException, InterruptedException {
     ArrayList<Digest> digests = new ArrayList<>();
-    Chunker.Builder b = new Chunker.Builder();
+    List<Path> existingFiles = new ArrayList<>();
     for (Path file : files) {
       if (!file.exists()) {
         // We ignore requested results that have not been generated by the action.
@@ -328,13 +334,20 @@
         // TreeNodeRepository to call uploadTree.
         throw new UnsupportedOperationException("Storing a directory is not yet supported.");
       }
-      digests.add(Digests.computeDigest(file));
-      b.addInput(file);
+
+      Digest digest = Digests.computeDigest(file);
+      digests.add(digest);
+      existingFiles.add(file);
     }
-    ImmutableSet<Digest> missing = getMissingDigests(digests);
-    if (!missing.isEmpty()) {
-      uploadChunks(missing.size(), b.onlyUseDigests(missing));
+
+    // Query the missing digests only after all digests have been computed.
+    ImmutableSet<Digest> digestsToUpload = getMissingDigests(digests);
+    List<Chunker.SingleSourceBuilder> filesToUpload = new ArrayList<>(digestsToUpload.size());
+    for (int i = 0; i < digests.size(); i++) {
+      if (digestsToUpload.contains(digests.get(i))) {
+        filesToUpload.add(new SingleSourceBuilder().input(existingFiles.get(i)));
+      }
+    }
+
+    if (!filesToUpload.isEmpty()) {
+      uploader.uploadBlobs(filesToUpload);
     }
+
     int index = 0;
     for (Path file : files) {
       // Add to protobuf.
@@ -366,7 +379,7 @@
     Digest digest = Digests.computeDigest(file);
     ImmutableSet<Digest> missing = getMissingDigests(ImmutableList.of(digest));
     if (!missing.isEmpty()) {
-      uploadChunks(1, new Chunker.Builder().addInput(file));
+      uploader.uploadBlob(new Chunker.SingleSourceBuilder().input(file));
     }
     return digest;
   }
@@ -382,108 +395,16 @@
     Digest digest = Digests.getDigestFromInputCache(input, inputCache);
     ImmutableSet<Digest> missing = getMissingDigests(ImmutableList.of(digest));
     if (!missing.isEmpty()) {
-      uploadChunks(1, new Chunker.Builder().addInput(input, inputCache, execRoot));
+      uploader.uploadBlob(new Chunker.SingleSourceBuilder().input(input, inputCache, execRoot));
     }
     return digest;
   }
 
-  private void uploadChunks(int numItems, Chunker.Builder chunkerBuilder)
-      throws InterruptedException, IOException {
-    String resourceName = "";
-    if (!options.remoteInstanceName.isEmpty()) {
-      resourceName += options.remoteInstanceName + "/";
-    }
-    Retrier.Backoff backoff = retrier.newBackoff();
-    Chunker chunker = chunkerBuilder.build();
-    while (true) { // Retry until either uploaded everything or raised an exception.
-      CountDownLatch finishLatch = new CountDownLatch(numItems);
-      AtomicReference<IOException> crashException = new AtomicReference<>(null);
-      List<Status> errors = Collections.synchronizedList(new ArrayList<Status>());
-      Set<Digest> failedDigests = Collections.synchronizedSet(new HashSet<Digest>());
-      StreamObserver<WriteRequest> requestObserver = null;
-      while (chunker.hasNext()) {
-        Chunker.Chunk chunk = chunker.next();
-        Digest digest = chunk.getDigest();
-        long offset = chunk.getOffset();
-        WriteRequest.Builder request = WriteRequest.newBuilder();
-        if (offset == 0) { // Beginning of new upload.
-          numItems--;
-          request.setResourceName(
-              String.format(
-                  "%s/uploads/%s/blobs/%s/%d",
-                  resourceName, UUID.randomUUID(), digest.getHash(), digest.getSizeBytes()));
-          // The batches execute simultaneously.
-          requestObserver =
-              bsStub()
-                  .write(
-                      new StreamObserver<WriteResponse>() {
-                        private long bytesLeft = digest.getSizeBytes();
-
-                        @Override
-                        public void onNext(WriteResponse reply) {
-                          bytesLeft -= reply.getCommittedSize();
-                        }
-
-                        @Override
-                        public void onError(Throwable t) {
-                          // In theory, this can be any error, even though it's supposed to usually
-                          // be only StatusException or StatusRuntimeException. We have to check
-                          // for other errors, in order to not accidentally retry them!
-                          if (!(t instanceof StatusRuntimeException
-                              || t instanceof StatusException)) {
-                            crashException.compareAndSet(null, new IOException(t));
-                          }
-
-                          failedDigests.add(digest);
-                          errors.add(Status.fromThrowable(t));
-                          finishLatch.countDown();
-                        }
-
-                        @Override
-                        public void onCompleted() {
-                          // This can actually happen even if we did not send all the bytes,
-                          // if the server has and is able to reuse parts of the uploaded blob.
-                          finishLatch.countDown();
-                        }
-                      });
-        }
-        byte[] data = chunk.getData();
-        boolean finishWrite = offset + data.length == digest.getSizeBytes();
-        request
-            .setData(ByteString.copyFrom(data))
-            .setWriteOffset(offset)
-            .setFinishWrite(finishWrite);
-        requestObserver.onNext(request.build());
-        if (finishWrite) {
-          requestObserver.onCompleted();
-        }
-        if (finishLatch.getCount() <= numItems) {
-          // Current RPC errored before we finished sending.
-          if (!finishWrite) {
-            chunker.advanceInput();
-          }
-        }
-      }
-      finishLatch.await(options.remoteTimeout, TimeUnit.SECONDS);
-      if (crashException.get() != null) {
-        throw crashException.get(); // Re-throw the exception that is supposed to never happen.
-      }
-      if (failedDigests.isEmpty()) {
-        return; // Successfully sent everything.
-      }
-      retrier.onFailures(backoff, errors); // This will throw when out of retries.
-      // We don't have to synchronize on failedDigests now, because after finishLatch.await we're
-      // back to single threaded execution.
-      chunker = chunkerBuilder.onlyUseDigests(failedDigests).build();
-      numItems = failedDigests.size();
-    }
-  }
-
   Digest uploadBlob(byte[] blob) throws IOException, InterruptedException {
     Digest digest = Digests.computeDigest(blob);
     ImmutableSet<Digest> missing = getMissingDigests(ImmutableList.of(digest));
     if (!missing.isEmpty()) {
-      uploadChunks(1, new Chunker.Builder().addInput(blob));
+      uploader.uploadBlob(new Chunker.SingleSourceBuilder().input(blob));
     }
     return digest;
   }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/Retrier.java b/src/main/java/com/google/devtools/build/lib/remote/Retrier.java
index cd69ad4..33d2422 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/Retrier.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/Retrier.java
@@ -277,6 +277,13 @@
   }
 
   /**
+   * Returns {@code true} if the {@link Status} is retriable.
+   */
+  public boolean isRetriable(Status s) {
+    return isRetriable.apply(s);
+  }
+
+  /**
    * Executes the given callable in a loop, retrying on retryable errors, as defined by the current
    * backoff/retry policy. Will raise the last encountered retriable error, or the first
    * non-retriable error.
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
new file mode 100644
index 0000000..d9b5d3e
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -0,0 +1,524 @@
+// Copyright 2017 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import com.google.bytestream.ByteStreamGrpc;
+import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
+import com.google.bytestream.ByteStreamProto.WriteRequest;
+import com.google.bytestream.ByteStreamProto.WriteResponse;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningScheduledExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.devtools.build.lib.remote.Chunker.SingleSourceBuilder;
+import com.google.protobuf.ByteString;
+import io.grpc.Channel;
+import io.grpc.Metadata;
+import io.grpc.Server;
+import io.grpc.ServerCall;
+import io.grpc.ServerCall.Listener;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerServiceDefinition;
+import io.grpc.Status;
+import io.grpc.Status.Code;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.stub.StreamObserver;
+import io.grpc.util.MutableHandlerRegistry;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+/**
+ * Tests for {@link ByteStreamUploader}.
+ */
+@RunWith(JUnit4.class)
+public class ByteStreamUploaderTest {
+
+  private static final int CHUNK_SIZE = 10;
+  private static final String INSTANCE_NAME = "foo";
+
+  private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+  private final ListeningScheduledExecutorService retryService =
+      MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(1));
+
+  private Server server;
+  private Channel channel;
+
+  @Mock
+  private Retrier.Backoff mockBackoff;
+
+  @Before
+  public void init() throws Exception {
+    MockitoAnnotations.initMocks(this);
+
+    String serverName = "Server for " + this.getClass();
+    server = InProcessServerBuilder.forName(serverName).fallbackHandlerRegistry(serviceRegistry)
+        .build().start();
+    channel = InProcessChannelBuilder.forName(serverName).build();
+  }
+
+  @After
+  public void shutdown() {
+    server.shutdownNow();
+    retryService.shutdownNow();
+  }
+
+  @Test(timeout = 10000)
+  public void singleBlobUploadShouldWork() throws Exception {
+    Retrier retrier = new Retrier(() -> mockBackoff, (Status s) -> true);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier, retryService);
+
+    byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
+    new Random().nextBytes(blob);
+
+    Chunker.SingleSourceBuilder builder =
+        new SingleSourceBuilder().chunkSize(CHUNK_SIZE).input(blob);
+
+    serviceRegistry.addService(new ByteStreamImplBase() {
+          @Override
+          public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+            return new StreamObserver<WriteRequest>() {
+
+              byte[] receivedData = new byte[blob.length];
+              long nextOffset = 0;
+
+              @Override
+              public void onNext(WriteRequest writeRequest) {
+                if (nextOffset == 0) {
+                  assertThat(writeRequest.getResourceName()).isNotEmpty();
+                  assertThat(writeRequest.getResourceName()).startsWith(INSTANCE_NAME + "/uploads");
+                  assertThat(writeRequest.getResourceName()).endsWith(String.valueOf(blob.length));
+                } else {
+                  assertThat(writeRequest.getResourceName()).isEmpty();
+                }
+
+                assertThat(writeRequest.getWriteOffset()).isEqualTo(nextOffset);
+
+                ByteString data = writeRequest.getData();
+
+                System.arraycopy(data.toByteArray(), 0, receivedData, (int) nextOffset,
+                    data.size());
+
+                nextOffset += data.size();
+                boolean lastWrite = blob.length == nextOffset;
+                assertThat(writeRequest.getFinishWrite()).isEqualTo(lastWrite);
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                fail("onError should never be called.");
+              }
+
+              @Override
+              public void onCompleted() {
+                assertThat(nextOffset).isEqualTo(blob.length);
+                assertThat(receivedData).isEqualTo(blob);
+
+                WriteResponse response =
+                    WriteResponse.newBuilder().setCommittedSize(nextOffset).build();
+                streamObserver.onNext(response);
+                streamObserver.onCompleted();
+              }
+            };
+          }
+        });
+
+    uploader.uploadBlob(builder);
+
+    // This test should not have triggered any retries.
+    Mockito.verifyZeroInteractions(mockBackoff);
+
+    assertThat(uploader.uploadsInProgress()).isFalse();
+  }
+
+  @Test(timeout = 20000)
+  public void multipleBlobsUploadShouldWork() throws Exception {
+    Retrier retrier = new Retrier(() -> new FixedBackoff(1, 0), (Status s) -> true);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier, retryService);
+
+    int numUploads = 100;
+    Map<String, byte[]> blobsByHash = new HashMap<>();
+    List<Chunker.SingleSourceBuilder> builders = new ArrayList<>(numUploads);
+    Random rand = new Random();
+    for (int i = 0; i < numUploads; i++) {
+      int blobSize = rand.nextInt(CHUNK_SIZE * 10) + CHUNK_SIZE;
+      byte[] blob = new byte[blobSize];
+      rand.nextBytes(blob);
+      Chunker.SingleSourceBuilder builder =
+          new Chunker.SingleSourceBuilder().chunkSize(CHUNK_SIZE).input(blob);
+      builders.add(builder);
+      blobsByHash.put(builder.getDigest().getHash(), blob);
+    }
+
+    Set<String> uploadsFailedOnce = Collections.synchronizedSet(new HashSet<>());
+
+    serviceRegistry.addService(new ByteStreamImplBase() {
+      @Override
+      public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) {
+        return new StreamObserver<WriteRequest>() {
+
+          private String digestHash;
+          private byte[] receivedData;
+          private long nextOffset;
+
+          @Override
+          public void onNext(WriteRequest writeRequest) {
+            if (nextOffset == 0) {
+              String resourceName = writeRequest.getResourceName();
+              assertThat(resourceName).isNotEmpty();
+
+              String[] components = resourceName.split("/");
+              assertThat(components).hasLength(6);
+              digestHash = components[4];
+              assertThat(blobsByHash).containsKey(digestHash);
+              receivedData = new byte[Integer.parseInt(components[5])];
+            }
+            assertThat(digestHash).isNotNull();
+            // An upload for a given blob has a 10% chance to fail once during its lifetime.
+            // This is to exercise the retry mechanism a bit.
+            boolean shouldFail =
+                rand.nextInt(10) == 0 && !uploadsFailedOnce.contains(digestHash);
+            if (shouldFail) {
+              uploadsFailedOnce.add(digestHash);
+              response.onError(Status.INTERNAL.asException());
+              return;
+            }
+
+            ByteString data = writeRequest.getData();
+            System.arraycopy(
+                data.toByteArray(), 0, receivedData, (int) nextOffset, data.size());
+            nextOffset += data.size();
+
+            boolean lastWrite = nextOffset == receivedData.length;
+            assertThat(writeRequest.getFinishWrite()).isEqualTo(lastWrite);
+          }
+
+          @Override
+          public void onError(Throwable throwable) {
+            fail("onError should never be called.");
+          }
+
+          @Override
+          public void onCompleted() {
+            byte[] expectedBlob = blobsByHash.get(digestHash);
+            assertThat(receivedData).isEqualTo(expectedBlob);
+
+            WriteResponse writeResponse =
+                WriteResponse.newBuilder().setCommittedSize(receivedData.length).build();
+
+            response.onNext(writeResponse);
+            response.onCompleted();
+          }
+        };
+      }
+    });
+
+    uploader.uploadBlobs(builders);
+
+    assertThat(uploader.uploadsInProgress()).isFalse();
+  }
+
+  @Test(timeout = 10000)
+  public void sameBlobShouldNotBeUploadedTwice() throws Exception {
+    // Test that uploading the same file concurrently triggers only one file upload.
+
+    Retrier retrier = new Retrier(() -> mockBackoff, (Status s) -> true);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier, retryService);
+
+    byte[] blob = new byte[CHUNK_SIZE * 10];
+    Chunker.SingleSourceBuilder builder =
+        new Chunker.SingleSourceBuilder().chunkSize(CHUNK_SIZE).input(blob);
+
+    AtomicInteger numWriteCalls = new AtomicInteger();
+
+    serviceRegistry.addService(new ByteStreamImplBase() {
+      @Override
+      public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) {
+        numWriteCalls.incrementAndGet();
+
+        return new StreamObserver<WriteRequest>() {
+
+          private long bytesReceived;
+
+          @Override
+          public void onNext(WriteRequest writeRequest) {
+            bytesReceived += writeRequest.getData().size();
+          }
+
+          @Override
+          public void onError(Throwable throwable) {
+            fail("onError should never be called.");
+          }
+
+          @Override
+          public void onCompleted() {
+            response.onNext(WriteResponse.newBuilder().setCommittedSize(bytesReceived).build());
+            response.onCompleted();
+          }
+        };
+      }
+    });
+
+    Future<?> upload1 = uploader.uploadBlobAsync(builder);
+    Future<?> upload2 = uploader.uploadBlobAsync(builder);
+
+    assertThat(upload1).isSameAs(upload2);
+
+    upload1.get();
+
+    assertThat(numWriteCalls.get()).isEqualTo(1);
+  }
+
+  @Test(timeout = 10000)
+  public void errorsShouldBeReported() throws IOException, InterruptedException {
+    Retrier retrier = new Retrier(() -> new FixedBackoff(1, 10), (Status s) -> true);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier, retryService);
+
+    byte[] blob = new byte[CHUNK_SIZE];
+    Chunker.SingleSourceBuilder builder =
+        new Chunker.SingleSourceBuilder().chunkSize(CHUNK_SIZE).input(blob);
+
+    serviceRegistry.addService(new ByteStreamImplBase() {
+      @Override
+      public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) {
+        response.onError(Status.INTERNAL.asException());
+        return new NoopStreamObserver();
+      }
+    });
+
+    try {
+      uploader.uploadBlob(builder);
+      fail("Should have thrown an exception.");
+    } catch (RetryException e) {
+      assertThat(e.getAttempts()).isEqualTo(2);
+      assertThat(e.causedByStatusCode(Code.INTERNAL)).isTrue();
+    }
+  }
+
+  @Test(timeout = 10000)
+  public void shutdownShouldCancelOngoingUploads() throws Exception {
+    Retrier retrier = new Retrier(() -> new FixedBackoff(1, 10), (Status s) -> true);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier, retryService);
+
+    CountDownLatch cancellations = new CountDownLatch(2);
+
+    ServerServiceDefinition service =
+        ServerServiceDefinition.builder(ByteStreamGrpc.SERVICE_NAME)
+        .addMethod(ByteStreamGrpc.METHOD_WRITE,
+            new ServerCallHandler<WriteRequest, WriteResponse>() {
+              @Override
+              public Listener<WriteRequest> startCall(ServerCall<WriteRequest, WriteResponse> call,
+                  Metadata headers) {
+                // Don't request() any messages from the client, so that the client will be blocked
+                // on flow control and thus the call will sit there idle long enough to receive the
+                // cancellation.
+                return new Listener<WriteRequest>() {
+                  @Override
+                  public void onCancel() {
+                    cancellations.countDown();
+                  }
+                };
+              }
+            })
+        .build();
+
+    serviceRegistry.addService(service);
+
+    byte[] blob1 = new byte[CHUNK_SIZE];
+    Chunker.SingleSourceBuilder builder1 =
+        new Chunker.SingleSourceBuilder().chunkSize(CHUNK_SIZE).input(blob1);
+
+    byte[] blob2 = new byte[CHUNK_SIZE + 1];
+    Chunker.SingleSourceBuilder builder2 =
+        new Chunker.SingleSourceBuilder().chunkSize(CHUNK_SIZE).input(blob2);
+
+    ListenableFuture<Void> f1 = uploader.uploadBlobAsync(builder1);
+    ListenableFuture<Void> f2 = uploader.uploadBlobAsync(builder2);
+
+    assertThat(uploader.uploadsInProgress()).isTrue();
+
+    uploader.shutdown();
+
+    cancellations.await();
+
+    assertThat(f1.isCancelled()).isTrue();
+    assertThat(f2.isCancelled()).isTrue();
+
+    assertThat(uploader.uploadsInProgress()).isFalse();
+  }
+
+  @Test(timeout = 10000)
+  public void failureInRetryExecutorShouldBeHandled() throws Exception {
+    Retrier retrier = new Retrier(() -> new FixedBackoff(1, 10), (Status s) -> true);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier, retryService);
+
+    serviceRegistry.addService(new ByteStreamImplBase() {
+      @Override
+      public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) {
+        // Immediately fail the call, so that it is retried.
+        response.onError(Status.ABORTED.asException());
+        return new NoopStreamObserver();
+      }
+    });
+
+    retryService.shutdownNow();
+    // Arbitrarily high timeout, as the test will time out by itself.
+    retryService.awaitTermination(1, TimeUnit.DAYS);
+    assertThat(retryService.isShutdown()).isTrue();
+
+    byte[] blob = new byte[1];
+    Chunker.SingleSourceBuilder builder = new Chunker.SingleSourceBuilder().input(blob);
+    try {
+      uploader.uploadBlob(builder);
+      fail("Should have thrown an exception.");
+    } catch (RetryException e) {
+      assertThat(e).hasCauseThat().isInstanceOf(RejectedExecutionException.class);
+    }
+  }
+
+  @Test(timeout = 10000)
+  public void resourceNameWithoutInstanceName() throws Exception {
+    Retrier retrier = new Retrier(() -> mockBackoff, (Status s) -> true);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(/* instanceName */ null, channel, null, 3, retrier, retryService);
+
+    serviceRegistry.addService(new ByteStreamImplBase() {
+      @Override
+      public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) {
+        return new StreamObserver<WriteRequest>() {
+          @Override
+          public void onNext(WriteRequest writeRequest) {
+            // Test that the resource name doesn't start with an instance name.
+            assertThat(writeRequest.getResourceName()).startsWith("uploads/");
+          }
+
+          @Override
+          public void onError(Throwable throwable) {}
+
+          @Override
+          public void onCompleted() {
+            response.onNext(WriteResponse.newBuilder().setCommittedSize(1).build());
+            response.onCompleted();
+          }
+        };
+      }
+    });
+
+    byte[] blob = new byte[1];
+    Chunker.SingleSourceBuilder builder = new Chunker.SingleSourceBuilder().input(blob);
+
+    uploader.uploadBlob(builder);
+  }
+
+  @Test(timeout = 10000)
+  public void nonRetryableStatusShouldNotBeRetried() throws Exception {
+    Retrier retrier = new Retrier(() -> new FixedBackoff(1, 0),
+        /* No Status is retriable. */ (Status s) -> false);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(/* instanceName */ null, channel, null, 3, retrier, retryService);
+
+    AtomicInteger numCalls = new AtomicInteger();
+
+    serviceRegistry.addService(new ByteStreamImplBase() {
+      @Override
+      public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) {
+        numCalls.incrementAndGet();
+        response.onError(Status.INTERNAL.asException());
+        return new NoopStreamObserver();
+      }
+    });
+
+    byte[] blob = new byte[1];
+    Chunker.SingleSourceBuilder builder = new Chunker.SingleSourceBuilder().input(blob);
+
+    try {
+      uploader.uploadBlob(builder);
+      fail("Should have thrown an exception.");
+    } catch (RetryException e) {
+      assertThat(numCalls.get()).isEqualTo(1);
+    }
+  }
+
+  private static class NoopStreamObserver implements StreamObserver<WriteRequest> {
+    @Override
+    public void onNext(WriteRequest writeRequest) {
+    }
+
+    @Override
+    public void onError(Throwable throwable) {
+    }
+
+    @Override
+    public void onCompleted() {
+    }
+  }
+
+  private static class FixedBackoff implements Retrier.Backoff {
+
+    private final int maxRetries;
+    private final int delayMillis;
+
+    private int retries;
+
+    public FixedBackoff(int maxRetries, int delayMillis) {
+      this.maxRetries = maxRetries;
+      this.delayMillis = delayMillis;
+    }
+
+    @Override
+    public long nextDelayMillis() {
+      if (retries < maxRetries) {
+        retries++;
+        return delayMillis;
+      }
+      return -1;
+    }
+
+    @Override
+    public int getRetryAttempts() {
+      return retries;
+    }
+  }
+}