Remote: Add support for compression on gRPC cache

Add support for compressed transfers to and from gRPC remote caches, enabled with the flag --experimental_remote_cache_compression.

Fixes #13344.

Closes #14041.

PiperOrigin-RevId: 409328001
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
index 748a341..15cc335 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -21,6 +21,8 @@
 
 import build.bazel.remote.execution.v2.Digest;
 import build.bazel.remote.execution.v2.RequestMetadata;
+import com.github.luben.zstd.Zstd;
+import com.github.luben.zstd.ZstdInputStream;
 import com.google.bytestream.ByteStreamGrpc;
 import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
 import com.google.bytestream.ByteStreamProto.QueryWriteStatusRequest;
@@ -63,6 +65,7 @@
 import io.grpc.util.MutableHandlerRegistry;
 import io.reactivex.rxjava3.core.Single;
 import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.util.ArrayList;
@@ -342,6 +345,130 @@
   }
 
   @Test
+  public void progressiveCompressedUploadShouldWork() throws Exception {
+    Mockito.when(mockBackoff.getRetryAttempts()).thenReturn(0);
+    RemoteRetrier retrier =
+        TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channelConnectionFactory),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            300,
+            retrier);
+
+    byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
+    new Random().nextBytes(blob);
+
+    Chunker chunker =
+        Chunker.builder().setInput(blob).setCompressed(true).setChunkSize(CHUNK_SIZE).build();
+    HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+    // Drain the chunker once to learn the total compressed size, then rewind it.
+    while (chunker.hasNext()) {
+      chunker.next();
+    }
+    long expectedSize = chunker.getOffset();
+    chunker.reset();
+
+    // Fake server that aborts the first write after one request, forcing a resumed upload.
+    serviceRegistry.addService(
+        new ByteStreamImplBase() {
+          byte[] receivedData = new byte[(int) expectedSize];
+          String receivedResourceName = null;
+          boolean receivedComplete = false;
+          long nextOffset = 0;
+          long initialOffset = 0;
+          boolean mustQueryWriteStatus = false;
+
+          @Override
+          public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+            return new StreamObserver<WriteRequest>() {
+              @Override
+              public void onNext(WriteRequest writeRequest) {
+                assertThat(mustQueryWriteStatus).isFalse();
+
+                String resourceName = writeRequest.getResourceName();
+                if (nextOffset == initialOffset) {
+                  if (initialOffset == 0) {
+                    receivedResourceName = resourceName;
+                  }
+                  assertThat(resourceName).startsWith(INSTANCE_NAME + "/uploads");
+                  assertThat(resourceName).contains("/compressed-blobs/zstd/");
+                  assertThat(resourceName).endsWith(String.valueOf(blob.length));
+                } else {
+                  assertThat(resourceName).isEmpty();
+                }
+
+                assertThat(writeRequest.getWriteOffset()).isEqualTo(nextOffset);
+
+                ByteString data = writeRequest.getData();
+
+                data.copyTo(receivedData, (int) nextOffset);
+
+                nextOffset += data.size();
+                receivedComplete = expectedSize == nextOffset;
+                assertThat(writeRequest.getFinishWrite()).isEqualTo(receivedComplete);
+
+                if (initialOffset == 0) {
+                  streamObserver.onError(Status.DEADLINE_EXCEEDED.asException());
+                  mustQueryWriteStatus = true;
+                  initialOffset = nextOffset;
+                }
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                fail("onError should never be called.");
+              }
+
+              @Override
+              public void onCompleted() {
+                assertThat(nextOffset).isEqualTo(expectedSize);
+                byte[] decompressed = Zstd.decompress(receivedData, blob.length);
+                assertThat(decompressed).isEqualTo(blob);
+
+                WriteResponse response =
+                    WriteResponse.newBuilder().setCommittedSize(nextOffset).build();
+                streamObserver.onNext(response);
+                streamObserver.onCompleted();
+              }
+            };
+          }
+
+          @Override
+          public void queryWriteStatus(
+              QueryWriteStatusRequest request, StreamObserver<QueryWriteStatusResponse> response) {
+            String resourceName = request.getResourceName();
+            final long committedSize;
+            final boolean complete;
+            if (receivedResourceName != null && receivedResourceName.equals(resourceName)) {
+              assertThat(mustQueryWriteStatus).isTrue();
+              mustQueryWriteStatus = false;
+              committedSize = nextOffset;
+              complete = receivedComplete;
+            } else {
+              committedSize = 0;
+              complete = false;
+            }
+            response.onNext(
+                QueryWriteStatusResponse.newBuilder()
+                    .setCommittedSize(committedSize)
+                    .setComplete(complete)
+                    .build());
+            response.onCompleted();
+          }
+        });
+
+    uploader.uploadBlob(context, hash, chunker, true);
+
+    // The interrupted write is resumed from the committed offset without any backoff delay.
+    Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
+    Mockito.verify(mockBackoff, Mockito.times(1)).getRetryAttempts();
+
+    blockUntilInternalStateConsistent(uploader);
+  }
+
+  @Test
   public void concurrentlyCompletedUploadIsNotRetried() throws Exception {
     // Test that after an upload has failed and the QueryWriteStatus call returns
     // that the upload has completed that we'll not retry the upload.
@@ -512,7 +639,7 @@
   }
 
   @Test
-  public void incorrectCommittedSizeFailsUpload() throws Exception {
+  public void incorrectCommittedSizeFailsCompletedUpload() throws Exception {
     RemoteRetrier retrier =
         TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
     ByteStreamUploader uploader =
@@ -533,10 +660,23 @@
         new ByteStreamImplBase() {
           @Override
           public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
-            streamObserver.onNext(
-                WriteResponse.newBuilder().setCommittedSize(blob.length + 1).build());
-            streamObserver.onCompleted();
-            return new NoopStreamObserver();
+            return new StreamObserver<WriteRequest>() {
+              @Override
+              public void onNext(WriteRequest writeRequest) {}
+
+              @Override
+              public void onError(Throwable throwable) {
+                fail("onError should never be called.");
+              }
+
+              @Override
+              public void onCompleted() {
+                WriteResponse response =
+                    WriteResponse.newBuilder().setCommittedSize(blob.length + 1).build();
+                streamObserver.onNext(response);
+                streamObserver.onCompleted();
+              }
+            };
           }
         });
 
@@ -554,6 +694,38 @@
   }
 
   @Test
+  public void incorrectCommittedSizeDoesNotFailsIncompleteUpload() throws Exception {
+    // The server reports only CHUNK_SIZE bytes committed; uploadBlob must not throw.
+    RemoteRetrier backoffRetrier =
+        TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
+    ByteStreamUploader byteStreamUploader =
+        new ByteStreamUploader(
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channelConnectionFactory),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            /* callTimeoutSecs= */ 300,
+            backoffRetrier);
+
+    byte[] payload = new byte[CHUNK_SIZE * 2 + 1];
+    new Random().nextBytes(payload);
+    Chunker payloadChunker = Chunker.builder().setInput(payload).setChunkSize(CHUNK_SIZE).build();
+    HashCode payloadHash = HashCode.fromString(DIGEST_UTIL.compute(payload).getHash());
+
+    serviceRegistry.addService(
+        new ByteStreamImplBase() {
+          @Override
+          public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+            WriteResponse ack = WriteResponse.newBuilder().setCommittedSize(CHUNK_SIZE).build();
+            streamObserver.onNext(ack);
+            streamObserver.onCompleted();
+            return new NoopStreamObserver();
+          }
+        });
+    byteStreamUploader.uploadBlob(context, payloadHash, payloadChunker, true);
+    blockUntilInternalStateConsistent(byteStreamUploader);
+  }
+
+  @Test
   public void multipleBlobsUploadShouldWork() throws Exception {
     RemoteRetrier retrier =
         TestUtils.newRemoteRetrier(() -> new FixedBackoff(1, 0), (e) -> true, retryService);
@@ -1345,6 +1517,99 @@
     blockUntilInternalStateConsistent(uploader);
   }
 
+  @Test
+  public void testCompressedUploads() throws Exception {
+    RemoteRetrier retrier =
+        TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channelConnectionFactory),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            /* callTimeoutSecs= */ 60,
+            retrier);
+
+    byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
+    new Random().nextBytes(blob);
+
+    AtomicInteger numUploads = new AtomicInteger();
+
+    serviceRegistry.addService(
+        new ByteStreamImplBase() {
+          @Override
+          public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+            return new StreamObserver<WriteRequest>() {
+              ByteArrayOutputStream baos = new ByteArrayOutputStream();
+              String resourceName = null;
+
+              @Override
+              public void onNext(WriteRequest writeRequest) {
+                if (!writeRequest.getResourceName().isEmpty()) {
+                  if (resourceName != null) {
+                    assertThat(resourceName).isEqualTo(writeRequest.getResourceName());
+                  } else {
+                    resourceName = writeRequest.getResourceName();
+                    assertThat(resourceName).contains("/compressed-blobs/zstd/");
+                  }
+                }
+                try {
+                  writeRequest.getData().writeTo(baos);
+                  if (writeRequest.getFinishWrite()) {
+                    baos.close();
+                  }
+                } catch (IOException e) {
+                  throw new AssertionError("I/O error on ByteArrayOutputStream.", e);
+                }
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                fail("onError should never be called.");
+              }
+
+              @Override
+              public void onCompleted() {
+                byte[] data = baos.toByteArray();
+                // Decompress the received stream; closing it is handled by try-with-resources.
+                try (ZstdInputStream zis = new ZstdInputStream(new ByteArrayInputStream(data))) {
+                  byte[] decompressed = ByteString.readFrom(zis).toByteArray();
+                  Digest digest = DIGEST_UTIL.compute(decompressed);
+
+                  // Compare full contents, not just length, so corrupted payloads are caught.
+                  assertThat(decompressed).isEqualTo(blob);
+                  assertThat(resourceName).isNotNull();
+                  assertThat(resourceName)
+                      .endsWith(String.format("/%s/%s", digest.getHash(), digest.getSizeBytes()));
+
+                  numUploads.incrementAndGet();
+                } catch (IOException e) {
+                  throw new AssertionError("Failed decompressing data.", e);
+                } finally {
+                  WriteResponse response =
+                      WriteResponse.newBuilder().setCommittedSize(data.length).build();
+
+                  streamObserver.onNext(response);
+                  streamObserver.onCompleted();
+                }
+              }
+            };
+          }
+        });
+
+    Chunker chunker =
+        Chunker.builder().setInput(blob).setCompressed(true).setChunkSize(CHUNK_SIZE).build();
+    HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+
+    uploader.uploadBlob(context, hash, chunker, true);
+
+    // This test should not have triggered any retries.
+    Mockito.verifyNoInteractions(mockBackoff);
+
+    blockUntilInternalStateConsistent(uploader);
+
+    assertThat(numUploads.get()).isEqualTo(1);
+  }
+
   private static class NoopStreamObserver implements StreamObserver<WriteRequest> {
     @Override
     public void onNext(WriteRequest writeRequest) {}