Add CallCredentialsProvider and refresh credentials in ByteStreamUploader

Users may get authentication errors in the middle of a long remote build because their credentials time out. This PR:
1. Add `CallCredentialsProvider` which can be used by gRPC clients to refresh credentials.
2. Update `ByteStreamUploader.java` to use the provider and refresh credentials on authentication error.

The next step is to update the other places that currently use `CallCredentials` directly (e.g. `RemoteCacheClient.java`) so that they use this provider and refresh credentials when necessary.

Closes #12106.

PiperOrigin-RevId: 332394306
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
index 8bed33c..dcd9975 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -15,6 +15,7 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.fail;
 
 import build.bazel.remote.execution.v2.Digest;
@@ -33,12 +34,14 @@
 import com.google.common.util.concurrent.ListeningScheduledExecutorService;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
 import com.google.devtools.build.lib.remote.util.DigestUtil;
 import com.google.devtools.build.lib.remote.util.TestUtils;
 import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
 import com.google.devtools.build.lib.vfs.DigestHashFunction;
 import com.google.protobuf.ByteString;
 import io.grpc.BindableService;
+import io.grpc.CallCredentials;
 import io.grpc.Context;
 import io.grpc.ManagedChannel;
 import io.grpc.Metadata;
@@ -72,6 +75,7 @@
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nullable;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -146,7 +150,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -223,7 +227,11 @@
         TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
     ByteStreamUploader uploader =
         new ByteStreamUploader(
-            INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channel),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            3,
+            retrier);
 
     byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
     new Random().nextBytes(blob);
@@ -339,7 +347,11 @@
         TestUtils.newRemoteRetrier(() -> new FixedBackoff(1, 0), (e) -> true, retryService);
     ByteStreamUploader uploader =
         new ByteStreamUploader(
-            INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 1, retrier);
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channel),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            1,
+            retrier);
 
     byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
     new Random().nextBytes(blob);
@@ -397,7 +409,11 @@
         TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
     ByteStreamUploader uploader =
         new ByteStreamUploader(
-            INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channel),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            3,
+            retrier);
 
     byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
     new Random().nextBytes(blob);
@@ -467,7 +483,11 @@
         TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
     ByteStreamUploader uploader =
         new ByteStreamUploader(
-            INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channel),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            3,
+            retrier);
 
     byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
     new Random().nextBytes(blob);
@@ -504,7 +524,11 @@
         TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
     ByteStreamUploader uploader =
         new ByteStreamUploader(
-            INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channel),
+            CallCredentialsProvider.NO_CREDENTIALS,
+            3,
+            retrier);
 
     byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
     new Random().nextBytes(blob);
@@ -547,7 +571,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -585,7 +609,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -702,7 +726,7 @@
                 InProcessChannelBuilder.forName("Server for " + this.getClass())
                     .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata))
                     .build()),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -765,7 +789,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -833,7 +857,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -868,7 +892,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -935,7 +959,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -975,7 +999,7 @@
         new ByteStreamUploader(
             /* instanceName */ null,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -1022,7 +1046,7 @@
         new ByteStreamUploader(
             /* instanceName */ null,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -1060,7 +1084,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -1141,7 +1165,7 @@
         new ByteStreamUploader(
             INSTANCE_NAME,
             new ReferenceCountedChannel(channel),
-            null, /* timeout seconds */
+            CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
             60,
             retrier);
 
@@ -1199,6 +1223,158 @@
     withEmptyMetadata.detach(prevContext);
   }
 
+  @Test
+  public void unauthenticatedErrorShouldNotBeRetried() throws Exception {
+    Context prevContext = withEmptyMetadata.attach();
+    RemoteRetrier retrier =
+        TestUtils.newRemoteRetrier(
+            () -> mockBackoff, RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryService);
+    // Fake provider that supplies no credentials and only counts refresh() calls.
+    AtomicInteger refreshTimes = new AtomicInteger();
+    CallCredentialsProvider callCredentialsProvider =
+        new CallCredentialsProvider() {
+          @Nullable
+          @Override
+          public CallCredentials getCallCredentials() {
+            return null;
+          }
+
+          @Override
+          public void refresh() throws IOException {
+            refreshTimes.incrementAndGet();
+          }
+        };
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channel),
+            callCredentialsProvider,
+            /* timeout seconds */ 60,
+            retrier);
+
+    byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
+    new Random().nextBytes(blob);
+
+    Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build();
+    HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+    // Server that rejects every Write call with UNAUTHENTICATED and counts the calls.
+    AtomicInteger numUploads = new AtomicInteger();
+    serviceRegistry.addService(
+        new ByteStreamImplBase() {
+          @Override
+          public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+            numUploads.incrementAndGet();
+
+            streamObserver.onError(Status.UNAUTHENTICATED.asException());
+            return new NoopStreamObserver();
+          }
+        });
+    // Expected: exactly one refresh and one re-upload, after which the upload still fails.
+    assertThrows(
+        IOException.class,
+        () -> {
+          uploader.uploadBlob(hash, chunker, true);
+        });
+
+    assertThat(refreshTimes.get()).isEqualTo(1);
+    assertThat(numUploads.get()).isEqualTo(2);
+
+    // The second attempt must come from the credential refresh, not the retrier's backoff.
+    Mockito.verifyZeroInteractions(mockBackoff);
+
+    blockUntilInternalStateConsistent(uploader);
+
+    withEmptyMetadata.detach(prevContext);
+  }
+
+  @Test
+  public void shouldRefreshCredentialsOnAuthenticationError() throws Exception {
+    Context prevContext = withEmptyMetadata.attach();
+    RemoteRetrier retrier =
+        TestUtils.newRemoteRetrier(
+            () -> mockBackoff, RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryService);
+    // Fake provider that supplies no credentials and only counts refresh() calls.
+    AtomicInteger refreshTimes = new AtomicInteger();
+    CallCredentialsProvider callCredentialsProvider =
+        new CallCredentialsProvider() {
+          @Nullable
+          @Override
+          public CallCredentials getCallCredentials() {
+            return null;
+          }
+
+          @Override
+          public void refresh() throws IOException {
+            refreshTimes.incrementAndGet();
+          }
+        };
+    ByteStreamUploader uploader =
+        new ByteStreamUploader(
+            INSTANCE_NAME,
+            new ReferenceCountedChannel(channel),
+            callCredentialsProvider,
+            /* timeout seconds */ 60,
+            retrier);
+
+    byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
+    new Random().nextBytes(blob);
+
+    Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build();
+    HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+    // Server: UNAUTHENTICATED until refresh() has been called, then accepts the blob.
+    AtomicInteger numUploads = new AtomicInteger();
+    serviceRegistry.addService(
+        new ByteStreamImplBase() {
+          @Override
+          public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+            numUploads.incrementAndGet();
+
+            if (refreshTimes.get() == 0) {
+              streamObserver.onError(Status.UNAUTHENTICATED.asException());
+              return new NoopStreamObserver();
+            }
+            // After a refresh, verify the streamed chunks and ack the committed size.
+            return new StreamObserver<WriteRequest>() {
+              long nextOffset = 0;
+
+              @Override
+              public void onNext(WriteRequest writeRequest) {
+                nextOffset += writeRequest.getData().size();
+                boolean lastWrite = blob.length == nextOffset;
+                assertThat(writeRequest.getFinishWrite()).isEqualTo(lastWrite);
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                fail("onError should never be called.");
+              }
+
+              @Override
+              public void onCompleted() {
+                assertThat(nextOffset).isEqualTo(blob.length);
+
+                WriteResponse response =
+                    WriteResponse.newBuilder().setCommittedSize(nextOffset).build();
+                streamObserver.onNext(response);
+                streamObserver.onCompleted();
+              }
+            };
+          }
+        });
+    // Expected to succeed after exactly one refresh and one re-upload.
+    uploader.uploadBlob(hash, chunker, true);
+
+    assertThat(refreshTimes.get()).isEqualTo(1);
+    assertThat(numUploads.get()).isEqualTo(2);
+
+    // The second attempt must come from the credential refresh, not the retrier's backoff.
+    Mockito.verifyZeroInteractions(mockBackoff);
+
+    blockUntilInternalStateConsistent(uploader);
+
+    withEmptyMetadata.detach(prevContext);
+  }
+
   private static class NoopStreamObserver implements StreamObserver<WriteRequest> {
     @Override
     public void onNext(WriteRequest writeRequest) {