Use CallCredentialsProvider and refresh credentials when needed for GrpcCacheClient and GrpcRemoteExecutor

Follow up on #12106, this PR make changes to `GrpcCacheClient` and `GrpcRemoteExecutor` to use `CallCredentialsProvider` and refresh credentials when needed to avoid unauthenticated error during a long remote build.

Closes #12156.

PiperOrigin-RevId: 336619591
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
index a9d5808..a69a4df 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
@@ -36,6 +36,7 @@
 import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
 import com.google.devtools.build.lib.remote.RemoteRetrier.ProgressiveBackoff;
 import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
+import com.google.devtools.build.lib.remote.util.Utils;
 import io.grpc.CallOptions;
 import io.grpc.Channel;
 import io.grpc.ClientCall;
@@ -352,29 +353,18 @@
       AtomicLong committedOffset = new AtomicLong(0);
 
       ListenableFuture<Void> callFuture =
-          callAndQueryOnFailureWithRetrier(ctx, committedOffset, progressiveBackoff);
-
-      callFuture =
-          Futures.catchingAsync(
-              callFuture,
-              Exception.class,
-              (e) -> {
-                Status status = Status.fromThrowable(e);
-                if (status != null
-                    && (status.getCode() == Code.UNAUTHENTICATED
-                        || status.getCode() == Code.PERMISSION_DENIED)) {
-                  try {
-                    callCredentialsProvider.refresh();
-                  } catch (IOException ioe) {
-                    e.addSuppressed(ioe);
-                    return Futures.immediateFailedFuture(e);
-                  }
-                  return callAndQueryOnFailureWithRetrier(ctx, committedOffset, progressiveBackoff);
-                } else {
-                  return Futures.immediateFailedFuture(e);
-                }
-              },
-              MoreExecutors.directExecutor());
+          Utils.refreshIfUnauthenticatedAsync(
+              () ->
+                  retrier.executeAsync(
+                      () -> {
+                        if (committedOffset.get() < chunker.getSize()) {
+                          return ctx.call(
+                              () -> callAndQueryOnFailure(committedOffset, progressiveBackoff));
+                        }
+                        return Futures.immediateFuture(null);
+                      },
+                      progressiveBackoff),
+              callCredentialsProvider);
 
       return Futures.transformAsync(
           callFuture,
@@ -399,18 +389,6 @@
           .withDeadlineAfter(callTimeoutSecs, SECONDS);
     }
 
-    private ListenableFuture<Void> callAndQueryOnFailureWithRetrier(
-        Context ctx, AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
-      return retrier.executeAsync(
-          () -> {
-            if (committedOffset.get() < chunker.getSize()) {
-              return ctx.call(() -> callAndQueryOnFailure(committedOffset, progressiveBackoff));
-            }
-            return Futures.immediateFuture(null);
-          },
-          progressiveBackoff);
-    }
-
     private ListenableFuture<Void> callAndQueryOnFailure(
         AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
       return Futures.catchingAsync(
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index ac8e48a..13c3a45 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -42,6 +42,7 @@
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.common.util.concurrent.SettableFuture;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
 import com.google.devtools.build.lib.remote.RemoteRetrier.ProgressiveBackoff;
 import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
@@ -53,7 +54,6 @@
 import com.google.devtools.build.lib.remote.util.Utils;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.protobuf.ByteString;
-import io.grpc.CallCredentials;
 import io.grpc.Context;
 import io.grpc.Status;
 import io.grpc.Status.Code;
@@ -72,7 +72,7 @@
 /** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
 @ThreadSafe
 public class GrpcCacheClient implements RemoteCacheClient, MissingDigestsFinder {
-  private final CallCredentials credentials;
+  private final CallCredentialsProvider callCredentialsProvider;
   private final ReferenceCountedChannel channel;
   private final RemoteOptions options;
   private final DigestUtil digestUtil;
@@ -85,12 +85,12 @@
   @VisibleForTesting
   public GrpcCacheClient(
       ReferenceCountedChannel channel,
-      CallCredentials credentials,
+      CallCredentialsProvider callCredentialsProvider,
       RemoteOptions options,
       RemoteRetrier retrier,
       DigestUtil digestUtil,
       ByteStreamUploader uploader) {
-    this.credentials = credentials;
+    this.callCredentialsProvider = callCredentialsProvider;
     this.channel = channel;
     this.options = options;
     this.digestUtil = digestUtil;
@@ -121,28 +121,28 @@
   private ContentAddressableStorageFutureStub casFutureStub() {
     return ContentAddressableStorageGrpc.newFutureStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withCallCredentials(credentials)
+        .withCallCredentials(callCredentialsProvider.getCallCredentials())
         .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
   }
 
   private ByteStreamStub bsAsyncStub() {
     return ByteStreamGrpc.newStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withCallCredentials(credentials)
+        .withCallCredentials(callCredentialsProvider.getCallCredentials())
         .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
   }
 
   private ActionCacheBlockingStub acBlockingStub() {
     return ActionCacheGrpc.newBlockingStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withCallCredentials(credentials)
+        .withCallCredentials(callCredentialsProvider.getCallCredentials())
         .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
   }
 
   private ActionCacheFutureStub acFutureStub() {
     return ActionCacheGrpc.newFutureStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withCallCredentials(credentials)
+        .withCallCredentials(callCredentialsProvider.getCallCredentials())
         .withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
   }
 
@@ -216,7 +216,9 @@
   private ListenableFuture<FindMissingBlobsResponse> getMissingDigests(
       FindMissingBlobsRequest request) {
     Context ctx = Context.current();
-    return retrier.executeAsync(() -> ctx.call(() -> casFutureStub().findMissingBlobs(request)));
+    return Utils.refreshIfUnauthenticatedAsync(
+        () -> retrier.executeAsync(() -> ctx.call(() -> casFutureStub().findMissingBlobs(request))),
+        callCredentialsProvider);
   }
 
   private ListenableFuture<ActionResult> handleStatus(ListenableFuture<ActionResult> download) {
@@ -242,23 +244,29 @@
             .setInlineStdout(inlineOutErr)
             .build();
     Context ctx = Context.current();
-    return retrier.executeAsync(
-        () -> ctx.call(() -> handleStatus(acFutureStub().getActionResult(request))));
+    return Utils.refreshIfUnauthenticatedAsync(
+        () ->
+            retrier.executeAsync(
+                () -> ctx.call(() -> handleStatus(acFutureStub().getActionResult(request)))),
+        callCredentialsProvider);
   }
 
   @Override
   public void uploadActionResult(ActionKey actionKey, ActionResult actionResult)
       throws IOException, InterruptedException {
     try {
-      retrier.execute(
+      Utils.refreshIfUnauthenticated(
           () ->
-              acBlockingStub()
-                  .updateActionResult(
-                      UpdateActionResultRequest.newBuilder()
-                          .setInstanceName(options.remoteInstanceName)
-                          .setActionDigest(actionKey.getDigest())
-                          .setActionResult(actionResult)
-                          .build()));
+              retrier.execute(
+                  () ->
+                      acBlockingStub()
+                          .updateActionResult(
+                              UpdateActionResultRequest.newBuilder()
+                                  .setInstanceName(options.remoteInstanceName)
+                                  .setActionDigest(actionKey.getDigest())
+                                  .setActionResult(actionResult)
+                                  .build())),
+          callCredentialsProvider);
     } catch (StatusRuntimeException e) {
       throw new IOException(e);
     }
@@ -287,11 +295,19 @@
     Context ctx = Context.current();
     AtomicLong offset = new AtomicLong(0);
     ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
-    return Futures.catchingAsync(
-        retrier.executeAsync(
+    ListenableFuture<Void> downloadFuture =
+        Utils.refreshIfUnauthenticatedAsync(
             () ->
-                ctx.call(() -> requestRead(offset, progressiveBackoff, digest, out, hashSupplier)),
-            progressiveBackoff),
+                retrier.executeAsync(
+                    () ->
+                        ctx.call(
+                            () ->
+                                requestRead(offset, progressiveBackoff, digest, out, hashSupplier)),
+                    progressiveBackoff),
+            callCredentialsProvider);
+
+    return Futures.catchingAsync(
+        downloadFuture,
         StatusRuntimeException.class,
         (e) -> Futures.immediateFailedFuture(new IOException(e)),
         MoreExecutors.directExecutor());
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
index 573d3b9..f89c461 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
@@ -20,12 +20,12 @@
 import build.bazel.remote.execution.v2.ExecutionGrpc.ExecutionBlockingStub;
 import build.bazel.remote.execution.v2.WaitExecutionRequest;
 import com.google.common.base.Preconditions;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
-import com.google.devtools.build.lib.remote.options.RemoteOptions;
 import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
+import com.google.devtools.build.lib.remote.util.Utils;
 import com.google.longrunning.Operation;
 import com.google.rpc.Status;
-import io.grpc.CallCredentials;
 import io.grpc.Status.Code;
 import io.grpc.StatusRuntimeException;
 import java.io.IOException;
@@ -39,27 +39,24 @@
 class GrpcRemoteExecutor {
 
   private final ReferenceCountedChannel channel;
-  private final CallCredentials callCredentials;
+  private final CallCredentialsProvider callCredentialsProvider;
   private final RemoteRetrier retrier;
 
   private final AtomicBoolean closed = new AtomicBoolean();
-  private final RemoteOptions options;
 
   public GrpcRemoteExecutor(
       ReferenceCountedChannel channel,
-      @Nullable CallCredentials callCredentials,
-      RemoteRetrier retrier,
-      RemoteOptions options) {
+      CallCredentialsProvider callCredentialsProvider,
+      RemoteRetrier retrier) {
     this.channel = channel;
-    this.callCredentials = callCredentials;
+    this.callCredentialsProvider = callCredentialsProvider;
     this.retrier = retrier;
-    this.options = options;
   }
 
   private ExecutionBlockingStub execBlockingStub() {
     return ExecutionGrpc.newBlockingStub(channel)
         .withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
-        .withCallCredentials(callCredentials);
+        .withCallCredentials(callCredentialsProvider.getCallCredentials());
   }
 
   private void handleStatus(Status statusProto, @Nullable ExecuteResponse resp) {
@@ -141,81 +138,96 @@
     final AtomicBoolean waitExecution =
         new AtomicBoolean(false); // Whether we should call WaitExecution.
     try {
-      return retrier.execute(
-          () -> {
-            // Retry calls to Execute()/WaitExecute() "infinitely" if the server terminates one of
-            // them status OK and an Operation that does not have done=True set. This is legal
-            // according to the remote execution protocol i.e. if the execution takes longer
-            // than a connection timeout. This is not an error condition and is thus handled
-            // outside of the retrier.
-            while (true) {
-              final Iterator<Operation> replies;
-              if (waitExecution.get()) {
-                WaitExecutionRequest wr =
-                    WaitExecutionRequest.newBuilder().setName(operation.get().getName()).build();
-                replies = execBlockingStub().waitExecution(wr);
-              } else {
-                replies = execBlockingStub().execute(request);
-              }
-              try {
-                while (replies.hasNext()) {
-                  Operation o = replies.next();
-                  operation.set(o);
-                  waitExecution.set(!operation.get().getDone());
+      return Utils.refreshIfUnauthenticated(
+          () ->
+              retrier.execute(
+                  () -> {
+                    // Retry calls to Execute()/WaitExecute() "infinitely" if the server terminates
+                    // one of
+                    // them status OK and an Operation that does not have done=True set. This is
+                    // legal
+                    // according to the remote execution protocol i.e. if the execution takes longer
+                    // than a connection timeout. This is not an error condition and is thus handled
+                    // outside of the retrier.
+                    while (true) {
+                      final Iterator<Operation> replies;
+                      if (waitExecution.get()) {
+                        WaitExecutionRequest wr =
+                            WaitExecutionRequest.newBuilder()
+                                .setName(operation.get().getName())
+                                .build();
+                        replies = execBlockingStub().waitExecution(wr);
+                      } else {
+                        replies = execBlockingStub().execute(request);
+                      }
+                      try {
+                        while (replies.hasNext()) {
+                          Operation o = replies.next();
+                          operation.set(o);
+                          waitExecution.set(!operation.get().getDone());
 
-                  // Update execution progress to the caller.
-                  //
-                  // After called `execute` above, the action is actually waiting for an available
-                  // gRPC connection to be sent. Once we get a reply from server, we know the
-                  // connection is up and indicate to the caller the fact by forwarding the
-                  // `operation`.
-                  //
-                  // The accurate execution status of the action relies on the server
-                  // implementation:
-                  //   1. Server can reply the accurate status in `operation.metadata.stage`;
-                  //   2. Server may send a reply without metadata. In this case, we assume the
-                  //      action is accepted by the server and will be executed ASAP;
-                  //   3. Server may execute the action silently and send a reply once it is done.
-                  if (receiver != null) {
-                    receiver.onNextOperation(o);
-                  }
+                          // Update execution progress to the caller.
+                          //
+                          // After called `execute` above, the action is actually waiting for an
+                          // available
+                          // gRPC connection to be sent. Once we get a reply from server, we know
+                          // the
+                          // connection is up and indicate to the caller the fact by forwarding the
+                          // `operation`.
+                          //
+                          // The accurate execution status of the action relies on the server
+                          // implementation:
+                          //   1. Server can reply the accurate status in
+                          // `operation.metadata.stage`;
+                          //   2. Server may send a reply without metadata. In this case, we assume
+                          // the
+                          //      action is accepted by the server and will be executed ASAP;
+                          //   3. Server may execute the action silently and send a reply once it is
+                          // done.
+                          if (receiver != null) {
+                            receiver.onNextOperation(o);
+                          }
 
-                  ExecuteResponse r = getOperationResponse(o);
-                  if (r != null) {
-                    return r;
-                  }
-                }
-                // The operation completed successfully but without a result.
-                if (!waitExecution.get()) {
-                  throw new IOException(
-                      String.format(
-                          "Remote server error: execution request for %s terminated with no"
-                              + " result.",
-                          operation.get().getName()));
-                }
-              } catch (StatusRuntimeException e) {
-                if (e.getStatus().getCode() == Code.NOT_FOUND) {
-                  // Operation was lost on the server. Retry Execute.
-                  waitExecution.set(false);
-                }
-                throw e;
-              } finally {
-                // The blocking streaming call closes correctly only when trailers and a Status
-                // are received from the server so that onClose() is called on this call's
-                // CallListener. Under normal circumstances (no cancel/errors), these are
-                // guaranteed to be sent by the server only if replies.hasNext() has been called
-                // after all replies from the stream have been consumed.
-                try {
-                  while (replies.hasNext()) {
-                    replies.next();
-                  }
-                } catch (StatusRuntimeException e) {
-                  // Cleanup: ignore exceptions, because the meaningful errors have already been
-                  // propagated.
-                }
-              }
-            }
-          });
+                          ExecuteResponse r = getOperationResponse(o);
+                          if (r != null) {
+                            return r;
+                          }
+                        }
+                        // The operation completed successfully but without a result.
+                        if (!waitExecution.get()) {
+                          throw new IOException(
+                              String.format(
+                                  "Remote server error: execution request for %s terminated with no"
+                                      + " result.",
+                                  operation.get().getName()));
+                        }
+                      } catch (StatusRuntimeException e) {
+                        if (e.getStatus().getCode() == Code.NOT_FOUND) {
+                          // Operation was lost on the server. Retry Execute.
+                          waitExecution.set(false);
+                        }
+                        throw e;
+                      } finally {
+                        // The blocking streaming call closes correctly only when trailers and a
+                        // Status
+                        // are received from the server so that onClose() is called on this call's
+                        // CallListener. Under normal circumstances (no cancel/errors), these are
+                        // guaranteed to be sent by the server only if replies.hasNext() has been
+                        // called
+                        // after all replies from the stream have been consumed.
+                        try {
+                          while (replies.hasNext()) {
+                            replies.next();
+                          }
+                        } catch (StatusRuntimeException e) {
+                          // Cleanup: ignore exceptions, because the meaningful errors have already
+                          // been
+                          // propagated.
+                        }
+                      }
+                    }
+                  }),
+          callCredentialsProvider);
     } catch (StatusRuntimeException e) {
       throw new IOException(e);
     }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
index 5c52356..b4f309a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
@@ -489,7 +489,7 @@
     RemoteCacheClient cacheClient =
         new GrpcCacheClient(
             cacheChannel.retain(),
-            credentials,
+            callCredentialsProvider,
             remoteOptions,
             retrier,
             digestUtil,
@@ -516,7 +516,7 @@
               retryScheduler,
               Retrier.ALLOW_ALL_CALLS);
       GrpcRemoteExecutor remoteExecutor =
-          new GrpcRemoteExecutor(execChannel.retain(), credentials, execRetrier, remoteOptions);
+          new GrpcRemoteExecutor(execChannel.retain(), callCredentialsProvider, execRetrier);
       execChannel.release();
       RemoteExecutionCache remoteCache =
           new RemoteExecutionCache(cacheClient, remoteOptions, digestUtil);
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/BUILD b/src/main/java/com/google/devtools/build/lib/remote/util/BUILD
index 4a0d8bb..2b88772 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/BUILD
@@ -17,8 +17,8 @@
         "//src/main/java/com/google/devtools/build/lib/actions",
         "//src/main/java/com/google/devtools/build/lib/actions:artifacts",
         "//src/main/java/com/google/devtools/build/lib/actions:execution_requirements",
-        "//src/main/java/com/google/devtools/build/lib/actions:file_metadata",
         "//src/main/java/com/google/devtools/build/lib/analysis:blaze_version_info",
+        "//src/main/java/com/google/devtools/build/lib/authandtls",
         "//src/main/java/com/google/devtools/build/lib/concurrent",
         "//src/main/java/com/google/devtools/build/lib/remote/common",
         "//src/main/java/com/google/devtools/build/lib/remote/options",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
index d6f2ffc..1565373 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
@@ -15,7 +15,10 @@
 
 import build.bazel.remote.execution.v2.ActionResult;
 import build.bazel.remote.execution.v2.Digest;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.AsyncCallable;
 import com.google.common.util.concurrent.FluentFuture;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
@@ -26,6 +29,7 @@
 import com.google.devtools.build.lib.actions.SpawnMetrics;
 import com.google.devtools.build.lib.actions.SpawnResult;
 import com.google.devtools.build.lib.actions.SpawnResult.Status;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
 import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
 import com.google.devtools.build.lib.remote.common.RemoteCacheClient.ActionKey;
 import com.google.devtools.build.lib.remote.options.RemoteOutputsMode;
@@ -40,12 +44,13 @@
 import java.io.OutputStream;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.function.BiFunction;
 import javax.annotation.Nullable;
 
 /** Utility methods for the remote package. * */
-public class Utils {
+public final class Utils {
 
   private Utils() {}
 
@@ -197,4 +202,73 @@
       return contents;
     }
   }
+
+  /**
+   * Call an asynchronous code block. If the block throws unauthenticated error, refresh the
+   * credentials using {@link CallCredentialsProvider} and call it again.
+   *
+   * <p>If any other exception thrown by the code block, it will be caught and wrapped in the
+   * returned {@link ListenableFuture}.
+   */
+  public static <V> ListenableFuture<V> refreshIfUnauthenticatedAsync(
+      AsyncCallable<V> call, CallCredentialsProvider callCredentialsProvider) {
+    Preconditions.checkNotNull(call);
+    Preconditions.checkNotNull(callCredentialsProvider);
+
+    try {
+      return Futures.catchingAsync(
+          call.call(),
+          Throwable.class,
+          (e) -> refreshIfUnauthenticatedAsyncOnException(e, call, callCredentialsProvider),
+          MoreExecutors.directExecutor());
+    } catch (Throwable t) {
+      return refreshIfUnauthenticatedAsyncOnException(t, call, callCredentialsProvider);
+    }
+  }
+
+  private static <V> ListenableFuture<V> refreshIfUnauthenticatedAsyncOnException(
+      Throwable t, AsyncCallable<V> call, CallCredentialsProvider callCredentialsProvider) {
+    io.grpc.Status status = io.grpc.Status.fromThrowable(t);
+    if (status != null
+        && (status.getCode() == io.grpc.Status.Code.UNAUTHENTICATED
+            || status.getCode() == io.grpc.Status.Code.PERMISSION_DENIED)) {
+      try {
+        callCredentialsProvider.refresh();
+        return call.call();
+      } catch (Throwable tt) {
+        t.addSuppressed(tt);
+      }
+    }
+
+    return Futures.immediateFailedFuture(t);
+  }
+
+  /** Same as {@link #refreshIfUnauthenticatedAsync} but calling a synchronous code block. */
+  public static <V> V refreshIfUnauthenticated(
+      Callable<V> call, CallCredentialsProvider callCredentialsProvider)
+      throws IOException, InterruptedException {
+    Preconditions.checkNotNull(call);
+    Preconditions.checkNotNull(callCredentialsProvider);
+
+    try {
+      return call.call();
+    } catch (Exception e) {
+      io.grpc.Status status = io.grpc.Status.fromThrowable(e);
+      if (status != null
+          && (status.getCode() == io.grpc.Status.Code.UNAUTHENTICATED
+              || status.getCode() == io.grpc.Status.Code.PERMISSION_DENIED)) {
+        try {
+          callCredentialsProvider.refresh();
+          return call.call();
+        } catch (Exception ex) {
+          e.addSuppressed(ex);
+        }
+      }
+
+      Throwables.throwIfInstanceOf(e, IOException.class);
+      Throwables.throwIfInstanceOf(e, InterruptedException.class);
+      Throwables.throwIfUnchecked(e);
+      throw new AssertionError(e);
+    }
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
index 82d8ab0..6ab2b63 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
@@ -237,7 +237,7 @@
             remoteOptions.remoteTimeout.getSeconds(),
             retrier);
     return new GrpcCacheClient(
-        channel.retain(), creds, remoteOptions, retrier, DIGEST_UTIL, uploader);
+        channel.retain(), callCredentialsProvider, remoteOptions, retrier, DIGEST_UTIL, uploader);
   }
 
   private static byte[] downloadBlob(GrpcCacheClient cacheClient, Digest digest)
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
index f078105..c2af204 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
@@ -89,7 +89,6 @@
 import com.google.rpc.PreconditionFailure;
 import com.google.rpc.PreconditionFailure.Violation;
 import io.grpc.BindableService;
-import io.grpc.CallCredentials;
 import io.grpc.Metadata;
 import io.grpc.Server;
 import io.grpc.ServerCall;
@@ -259,11 +258,10 @@
                 .directExecutor()
                 .build());
     GrpcRemoteExecutor executor =
-        new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions);
+        new GrpcRemoteExecutor(channel.retain(), CallCredentialsProvider.NO_CREDENTIALS, retrier);
     CallCredentialsProvider callCredentialsProvider =
         GoogleAuthUtils.newCallCredentialsProvider(
             GoogleAuthUtils.newCredentials(Options.getDefaults(AuthAndTLSOptions.class)));
-    CallCredentials creds = callCredentialsProvider.getCallCredentials();
     ByteStreamUploader uploader =
         new ByteStreamUploader(
             remoteOptions.remoteInstanceName,
@@ -272,7 +270,13 @@
             remoteOptions.remoteTimeout.getSeconds(),
             retrier);
     GrpcCacheClient cacheProtocol =
-        new GrpcCacheClient(channel.retain(), creds, remoteOptions, retrier, DIGEST_UTIL, uploader);
+        new GrpcCacheClient(
+            channel.retain(),
+            callCredentialsProvider,
+            remoteOptions,
+            retrier,
+            DIGEST_UTIL,
+            uploader);
     RemoteExecutionCache remoteCache =
         new RemoteExecutionCache(cacheProtocol, remoteOptions, DIGEST_UTIL);
     client =
diff --git a/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java b/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java
index ad1b4c0..6dcb2d5 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java
@@ -14,10 +14,19 @@
 package com.google.devtools.build.lib.remote;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
+import com.google.common.util.concurrent.Futures;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
 import com.google.devtools.build.lib.remote.util.Utils;
 import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
 import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -36,4 +45,140 @@
     assertThat(Utils.grpcAwareErrorMessage(ioError)).isEqualTo("io error");
     assertThat(Utils.grpcAwareErrorMessage(wrappedGrpcError)).isEqualTo("ABORTED: grpc error");
   }
+
+  @Test
+  public void refreshIfUnauthenticatedAsync_unauthenticated_shouldRefresh() throws Exception {
+    CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+    AtomicInteger callTimes = new AtomicInteger();
+
+    Utils.refreshIfUnauthenticatedAsync(
+            () -> {
+              if (callTimes.getAndIncrement() == 0) {
+                throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+              }
+              return Futures.immediateFuture(null);
+            },
+            callCredentialsProvider)
+        .get();
+
+    assertThat(callTimes.get()).isEqualTo(2);
+    verify(callCredentialsProvider, times(1)).refresh();
+  }
+
+  @Test
+  public void refreshIfUnauthenticatedAsync_unauthenticatedFuture_shouldRefresh() throws Exception {
+    CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+    AtomicInteger callTimes = new AtomicInteger();
+
+    Utils.refreshIfUnauthenticatedAsync(
+            () -> {
+              if (callTimes.getAndIncrement() == 0) {
+                return Futures.immediateFailedFuture(
+                    new StatusRuntimeException(Status.UNAUTHENTICATED));
+              }
+              return Futures.immediateFuture(null);
+            },
+            callCredentialsProvider)
+        .get();
+
+    assertThat(callTimes.get()).isEqualTo(2);
+    verify(callCredentialsProvider, times(1)).refresh();
+  }
+
+  @Test
+  public void refreshIfUnauthenticatedAsync_permissionDenied_shouldRefresh() throws Exception {
+    CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+    AtomicInteger callTimes = new AtomicInteger();
+
+    Utils.refreshIfUnauthenticated(
+            () -> {
+              if (callTimes.getAndIncrement() == 0) {
+                throw new StatusRuntimeException(Status.PERMISSION_DENIED);
+              }
+              return Futures.immediateFuture(null);
+            },
+            callCredentialsProvider)
+        .get();
+
+    assertThat(callTimes.get()).isEqualTo(2);
+    verify(callCredentialsProvider, times(1)).refresh();
+  }
+
+  @Test
+  public void refreshIfUnauthenticatedAsync_cantRefresh_shouldRefreshOnceAndFail()
+      throws Exception {
+    CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+    AtomicInteger callTimes = new AtomicInteger();
+
+    assertThrows(
+        ExecutionException.class,
+        () -> {
+          Utils.refreshIfUnauthenticatedAsync(
+                  () -> {
+                    callTimes.getAndIncrement();
+                    throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+                  },
+                  callCredentialsProvider)
+              .get();
+        });
+
+    assertThat(callTimes.get()).isEqualTo(2);
+    verify(callCredentialsProvider, times(1)).refresh();
+  }
+
+  @Test
+  public void refreshIfUnauthenticated_unauthenticated_shouldRefresh() throws Exception {
+    CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+    AtomicInteger callTimes = new AtomicInteger();
+
+    Utils.refreshIfUnauthenticated(
+        () -> {
+          if (callTimes.getAndIncrement() == 0) {
+            throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+          }
+          return null;
+        },
+        callCredentialsProvider);
+
+    assertThat(callTimes.get()).isEqualTo(2);
+    verify(callCredentialsProvider, times(1)).refresh();
+  }
+
+  @Test
+  public void refreshIfUnauthenticated_permissionDenied_shouldRefresh() throws Exception {
+    CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+    AtomicInteger callTimes = new AtomicInteger();
+
+    Utils.refreshIfUnauthenticated(
+        () -> {
+          if (callTimes.getAndIncrement() == 0) {
+            throw new StatusRuntimeException(Status.PERMISSION_DENIED);
+          }
+          return null;
+        },
+        callCredentialsProvider);
+
+    assertThat(callTimes.get()).isEqualTo(2);
+    verify(callCredentialsProvider, times(1)).refresh();
+  }
+
+  @Test
+  public void refreshIfUnauthenticated_cantRefresh_shouldRefreshOnceAndFail() throws Exception {
+    CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+    AtomicInteger callTimes = new AtomicInteger();
+
+    assertThrows(
+        StatusRuntimeException.class,
+        () -> {
+          Utils.refreshIfUnauthenticated(
+              () -> {
+                callTimes.getAndIncrement();
+                throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+              },
+              callCredentialsProvider);
+        });
+
+    assertThat(callTimes.get()).isEqualTo(2);
+    verify(callCredentialsProvider, times(1)).refresh();
+  }
 }