Use CallCredentialsProvider and refresh credentials when needed for GrpcCacheClient and GrpcRemoteExecutor
Follow up on #12106, this PR make changes to `GrpcCacheClient` and `GrpcRemoteExecutor` to use `CallCredentialsProvider` and refresh credentials when needed to avoid unauthenticated error during a long remote build.
Closes #12156.
PiperOrigin-RevId: 336619591
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
index a9d5808..a69a4df 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
@@ -36,6 +36,7 @@
import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.remote.RemoteRetrier.ProgressiveBackoff;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
+import com.google.devtools.build.lib.remote.util.Utils;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
@@ -352,29 +353,18 @@
AtomicLong committedOffset = new AtomicLong(0);
ListenableFuture<Void> callFuture =
- callAndQueryOnFailureWithRetrier(ctx, committedOffset, progressiveBackoff);
-
- callFuture =
- Futures.catchingAsync(
- callFuture,
- Exception.class,
- (e) -> {
- Status status = Status.fromThrowable(e);
- if (status != null
- && (status.getCode() == Code.UNAUTHENTICATED
- || status.getCode() == Code.PERMISSION_DENIED)) {
- try {
- callCredentialsProvider.refresh();
- } catch (IOException ioe) {
- e.addSuppressed(ioe);
- return Futures.immediateFailedFuture(e);
- }
- return callAndQueryOnFailureWithRetrier(ctx, committedOffset, progressiveBackoff);
- } else {
- return Futures.immediateFailedFuture(e);
- }
- },
- MoreExecutors.directExecutor());
+ Utils.refreshIfUnauthenticatedAsync(
+ () ->
+ retrier.executeAsync(
+ () -> {
+ if (committedOffset.get() < chunker.getSize()) {
+ return ctx.call(
+ () -> callAndQueryOnFailure(committedOffset, progressiveBackoff));
+ }
+ return Futures.immediateFuture(null);
+ },
+ progressiveBackoff),
+ callCredentialsProvider);
return Futures.transformAsync(
callFuture,
@@ -399,18 +389,6 @@
.withDeadlineAfter(callTimeoutSecs, SECONDS);
}
- private ListenableFuture<Void> callAndQueryOnFailureWithRetrier(
- Context ctx, AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
- return retrier.executeAsync(
- () -> {
- if (committedOffset.get() < chunker.getSize()) {
- return ctx.call(() -> callAndQueryOnFailure(committedOffset, progressiveBackoff));
- }
- return Futures.immediateFuture(null);
- },
- progressiveBackoff);
- }
-
private ListenableFuture<Void> callAndQueryOnFailure(
AtomicLong committedOffset, ProgressiveBackoff progressiveBackoff) {
return Futures.catchingAsync(
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index ac8e48a..13c3a45 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -42,6 +42,7 @@
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.remote.RemoteRetrier.ProgressiveBackoff;
import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
@@ -53,7 +54,6 @@
import com.google.devtools.build.lib.remote.util.Utils;
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
-import io.grpc.CallCredentials;
import io.grpc.Context;
import io.grpc.Status;
import io.grpc.Status.Code;
@@ -72,7 +72,7 @@
/** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
@ThreadSafe
public class GrpcCacheClient implements RemoteCacheClient, MissingDigestsFinder {
- private final CallCredentials credentials;
+ private final CallCredentialsProvider callCredentialsProvider;
private final ReferenceCountedChannel channel;
private final RemoteOptions options;
private final DigestUtil digestUtil;
@@ -85,12 +85,12 @@
@VisibleForTesting
public GrpcCacheClient(
ReferenceCountedChannel channel,
- CallCredentials credentials,
+ CallCredentialsProvider callCredentialsProvider,
RemoteOptions options,
RemoteRetrier retrier,
DigestUtil digestUtil,
ByteStreamUploader uploader) {
- this.credentials = credentials;
+ this.callCredentialsProvider = callCredentialsProvider;
this.channel = channel;
this.options = options;
this.digestUtil = digestUtil;
@@ -121,28 +121,28 @@
private ContentAddressableStorageFutureStub casFutureStub() {
return ContentAddressableStorageGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withCallCredentials(credentials)
+ .withCallCredentials(callCredentialsProvider.getCallCredentials())
.withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
}
private ByteStreamStub bsAsyncStub() {
return ByteStreamGrpc.newStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withCallCredentials(credentials)
+ .withCallCredentials(callCredentialsProvider.getCallCredentials())
.withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
}
private ActionCacheBlockingStub acBlockingStub() {
return ActionCacheGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withCallCredentials(credentials)
+ .withCallCredentials(callCredentialsProvider.getCallCredentials())
.withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
}
private ActionCacheFutureStub acFutureStub() {
return ActionCacheGrpc.newFutureStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withCallCredentials(credentials)
+ .withCallCredentials(callCredentialsProvider.getCallCredentials())
.withDeadlineAfter(options.remoteTimeout.getSeconds(), TimeUnit.SECONDS);
}
@@ -216,7 +216,9 @@
private ListenableFuture<FindMissingBlobsResponse> getMissingDigests(
FindMissingBlobsRequest request) {
Context ctx = Context.current();
- return retrier.executeAsync(() -> ctx.call(() -> casFutureStub().findMissingBlobs(request)));
+ return Utils.refreshIfUnauthenticatedAsync(
+ () -> retrier.executeAsync(() -> ctx.call(() -> casFutureStub().findMissingBlobs(request))),
+ callCredentialsProvider);
}
private ListenableFuture<ActionResult> handleStatus(ListenableFuture<ActionResult> download) {
@@ -242,23 +244,29 @@
.setInlineStdout(inlineOutErr)
.build();
Context ctx = Context.current();
- return retrier.executeAsync(
- () -> ctx.call(() -> handleStatus(acFutureStub().getActionResult(request))));
+ return Utils.refreshIfUnauthenticatedAsync(
+ () ->
+ retrier.executeAsync(
+ () -> ctx.call(() -> handleStatus(acFutureStub().getActionResult(request)))),
+ callCredentialsProvider);
}
@Override
public void uploadActionResult(ActionKey actionKey, ActionResult actionResult)
throws IOException, InterruptedException {
try {
- retrier.execute(
+ Utils.refreshIfUnauthenticated(
() ->
- acBlockingStub()
- .updateActionResult(
- UpdateActionResultRequest.newBuilder()
- .setInstanceName(options.remoteInstanceName)
- .setActionDigest(actionKey.getDigest())
- .setActionResult(actionResult)
- .build()));
+ retrier.execute(
+ () ->
+ acBlockingStub()
+ .updateActionResult(
+ UpdateActionResultRequest.newBuilder()
+ .setInstanceName(options.remoteInstanceName)
+ .setActionDigest(actionKey.getDigest())
+ .setActionResult(actionResult)
+ .build())),
+ callCredentialsProvider);
} catch (StatusRuntimeException e) {
throw new IOException(e);
}
@@ -287,11 +295,19 @@
Context ctx = Context.current();
AtomicLong offset = new AtomicLong(0);
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
- return Futures.catchingAsync(
- retrier.executeAsync(
+ ListenableFuture<Void> downloadFuture =
+ Utils.refreshIfUnauthenticatedAsync(
() ->
- ctx.call(() -> requestRead(offset, progressiveBackoff, digest, out, hashSupplier)),
- progressiveBackoff),
+ retrier.executeAsync(
+ () ->
+ ctx.call(
+ () ->
+ requestRead(offset, progressiveBackoff, digest, out, hashSupplier)),
+ progressiveBackoff),
+ callCredentialsProvider);
+
+ return Futures.catchingAsync(
+ downloadFuture,
StatusRuntimeException.class,
(e) -> Futures.immediateFailedFuture(new IOException(e)),
MoreExecutors.directExecutor());
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
index 573d3b9..f89c461 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
@@ -20,12 +20,12 @@
import build.bazel.remote.execution.v2.ExecutionGrpc.ExecutionBlockingStub;
import build.bazel.remote.execution.v2.WaitExecutionRequest;
import com.google.common.base.Preconditions;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
-import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
+import com.google.devtools.build.lib.remote.util.Utils;
import com.google.longrunning.Operation;
import com.google.rpc.Status;
-import io.grpc.CallCredentials;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import java.io.IOException;
@@ -39,27 +39,24 @@
class GrpcRemoteExecutor {
private final ReferenceCountedChannel channel;
- private final CallCredentials callCredentials;
+ private final CallCredentialsProvider callCredentialsProvider;
private final RemoteRetrier retrier;
private final AtomicBoolean closed = new AtomicBoolean();
- private final RemoteOptions options;
public GrpcRemoteExecutor(
ReferenceCountedChannel channel,
- @Nullable CallCredentials callCredentials,
- RemoteRetrier retrier,
- RemoteOptions options) {
+ CallCredentialsProvider callCredentialsProvider,
+ RemoteRetrier retrier) {
this.channel = channel;
- this.callCredentials = callCredentials;
+ this.callCredentialsProvider = callCredentialsProvider;
this.retrier = retrier;
- this.options = options;
}
private ExecutionBlockingStub execBlockingStub() {
return ExecutionGrpc.newBlockingStub(channel)
.withInterceptors(TracingMetadataUtils.attachMetadataFromContextInterceptor())
- .withCallCredentials(callCredentials);
+ .withCallCredentials(callCredentialsProvider.getCallCredentials());
}
private void handleStatus(Status statusProto, @Nullable ExecuteResponse resp) {
@@ -141,81 +138,96 @@
final AtomicBoolean waitExecution =
new AtomicBoolean(false); // Whether we should call WaitExecution.
try {
- return retrier.execute(
- () -> {
- // Retry calls to Execute()/WaitExecute() "infinitely" if the server terminates one of
- // them status OK and an Operation that does not have done=True set. This is legal
- // according to the remote execution protocol i.e. if the execution takes longer
- // than a connection timeout. This is not an error condition and is thus handled
- // outside of the retrier.
- while (true) {
- final Iterator<Operation> replies;
- if (waitExecution.get()) {
- WaitExecutionRequest wr =
- WaitExecutionRequest.newBuilder().setName(operation.get().getName()).build();
- replies = execBlockingStub().waitExecution(wr);
- } else {
- replies = execBlockingStub().execute(request);
- }
- try {
- while (replies.hasNext()) {
- Operation o = replies.next();
- operation.set(o);
- waitExecution.set(!operation.get().getDone());
+ return Utils.refreshIfUnauthenticated(
+ () ->
+ retrier.execute(
+ () -> {
+ // Retry calls to Execute()/WaitExecute() "infinitely" if the server terminates
+ // one of
+ // them status OK and an Operation that does not have done=True set. This is
+ // legal
+ // according to the remote execution protocol i.e. if the execution takes longer
+ // than a connection timeout. This is not an error condition and is thus handled
+ // outside of the retrier.
+ while (true) {
+ final Iterator<Operation> replies;
+ if (waitExecution.get()) {
+ WaitExecutionRequest wr =
+ WaitExecutionRequest.newBuilder()
+ .setName(operation.get().getName())
+ .build();
+ replies = execBlockingStub().waitExecution(wr);
+ } else {
+ replies = execBlockingStub().execute(request);
+ }
+ try {
+ while (replies.hasNext()) {
+ Operation o = replies.next();
+ operation.set(o);
+ waitExecution.set(!operation.get().getDone());
- // Update execution progress to the caller.
- //
- // After called `execute` above, the action is actually waiting for an available
- // gRPC connection to be sent. Once we get a reply from server, we know the
- // connection is up and indicate to the caller the fact by forwarding the
- // `operation`.
- //
- // The accurate execution status of the action relies on the server
- // implementation:
- // 1. Server can reply the accurate status in `operation.metadata.stage`;
- // 2. Server may send a reply without metadata. In this case, we assume the
- // action is accepted by the server and will be executed ASAP;
- // 3. Server may execute the action silently and send a reply once it is done.
- if (receiver != null) {
- receiver.onNextOperation(o);
- }
+ // Update execution progress to the caller.
+ //
+ // After called `execute` above, the action is actually waiting for an
+ // available
+ // gRPC connection to be sent. Once we get a reply from server, we know
+ // the
+ // connection is up and indicate to the caller the fact by forwarding the
+ // `operation`.
+ //
+ // The accurate execution status of the action relies on the server
+ // implementation:
+ // 1. Server can reply the accurate status in
+ // `operation.metadata.stage`;
+ // 2. Server may send a reply without metadata. In this case, we assume
+ // the
+ // action is accepted by the server and will be executed ASAP;
+ // 3. Server may execute the action silently and send a reply once it is
+ // done.
+ if (receiver != null) {
+ receiver.onNextOperation(o);
+ }
- ExecuteResponse r = getOperationResponse(o);
- if (r != null) {
- return r;
- }
- }
- // The operation completed successfully but without a result.
- if (!waitExecution.get()) {
- throw new IOException(
- String.format(
- "Remote server error: execution request for %s terminated with no"
- + " result.",
- operation.get().getName()));
- }
- } catch (StatusRuntimeException e) {
- if (e.getStatus().getCode() == Code.NOT_FOUND) {
- // Operation was lost on the server. Retry Execute.
- waitExecution.set(false);
- }
- throw e;
- } finally {
- // The blocking streaming call closes correctly only when trailers and a Status
- // are received from the server so that onClose() is called on this call's
- // CallListener. Under normal circumstances (no cancel/errors), these are
- // guaranteed to be sent by the server only if replies.hasNext() has been called
- // after all replies from the stream have been consumed.
- try {
- while (replies.hasNext()) {
- replies.next();
- }
- } catch (StatusRuntimeException e) {
- // Cleanup: ignore exceptions, because the meaningful errors have already been
- // propagated.
- }
- }
- }
- });
+ ExecuteResponse r = getOperationResponse(o);
+ if (r != null) {
+ return r;
+ }
+ }
+ // The operation completed successfully but without a result.
+ if (!waitExecution.get()) {
+ throw new IOException(
+ String.format(
+ "Remote server error: execution request for %s terminated with no"
+ + " result.",
+ operation.get().getName()));
+ }
+ } catch (StatusRuntimeException e) {
+ if (e.getStatus().getCode() == Code.NOT_FOUND) {
+ // Operation was lost on the server. Retry Execute.
+ waitExecution.set(false);
+ }
+ throw e;
+ } finally {
+ // The blocking streaming call closes correctly only when trailers and a
+ // Status
+ // are received from the server so that onClose() is called on this call's
+ // CallListener. Under normal circumstances (no cancel/errors), these are
+ // guaranteed to be sent by the server only if replies.hasNext() has been
+ // called
+ // after all replies from the stream have been consumed.
+ try {
+ while (replies.hasNext()) {
+ replies.next();
+ }
+ } catch (StatusRuntimeException e) {
+ // Cleanup: ignore exceptions, because the meaningful errors have already
+ // been
+ // propagated.
+ }
+ }
+ }
+ }),
+ callCredentialsProvider);
} catch (StatusRuntimeException e) {
throw new IOException(e);
}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
index 5c52356..b4f309a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
@@ -489,7 +489,7 @@
RemoteCacheClient cacheClient =
new GrpcCacheClient(
cacheChannel.retain(),
- credentials,
+ callCredentialsProvider,
remoteOptions,
retrier,
digestUtil,
@@ -516,7 +516,7 @@
retryScheduler,
Retrier.ALLOW_ALL_CALLS);
GrpcRemoteExecutor remoteExecutor =
- new GrpcRemoteExecutor(execChannel.retain(), credentials, execRetrier, remoteOptions);
+ new GrpcRemoteExecutor(execChannel.retain(), callCredentialsProvider, execRetrier);
execChannel.release();
RemoteExecutionCache remoteCache =
new RemoteExecutionCache(cacheClient, remoteOptions, digestUtil);
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/BUILD b/src/main/java/com/google/devtools/build/lib/remote/util/BUILD
index 4a0d8bb..2b88772 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/BUILD
@@ -17,8 +17,8 @@
"//src/main/java/com/google/devtools/build/lib/actions",
"//src/main/java/com/google/devtools/build/lib/actions:artifacts",
"//src/main/java/com/google/devtools/build/lib/actions:execution_requirements",
- "//src/main/java/com/google/devtools/build/lib/actions:file_metadata",
"//src/main/java/com/google/devtools/build/lib/analysis:blaze_version_info",
+ "//src/main/java/com/google/devtools/build/lib/authandtls",
"//src/main/java/com/google/devtools/build/lib/concurrent",
"//src/main/java/com/google/devtools/build/lib/remote/common",
"//src/main/java/com/google/devtools/build/lib/remote/options",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
index d6f2ffc..1565373 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/Utils.java
@@ -15,7 +15,10 @@
import build.bazel.remote.execution.v2.ActionResult;
import build.bazel.remote.execution.v2.Digest;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.AsyncCallable;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
@@ -26,6 +29,7 @@
import com.google.devtools.build.lib.actions.SpawnMetrics;
import com.google.devtools.build.lib.actions.SpawnResult;
import com.google.devtools.build.lib.actions.SpawnResult.Status;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
import com.google.devtools.build.lib.remote.common.RemoteCacheClient.ActionKey;
import com.google.devtools.build.lib.remote.options.RemoteOutputsMode;
@@ -40,12 +44,13 @@
import java.io.OutputStream;
import java.util.Collection;
import java.util.Collections;
+import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import javax.annotation.Nullable;
/** Utility methods for the remote package. * */
-public class Utils {
+public final class Utils {
private Utils() {}
@@ -197,4 +202,73 @@
return contents;
}
}
+
+ /**
+ * Call an asynchronous code block. If the block throws unauthenticated error, refresh the
+ * credentials using {@link CallCredentialsProvider} and call it again.
+ *
+ * <p>If any other exception thrown by the code block, it will be caught and wrapped in the
+ * returned {@link ListenableFuture}.
+ */
+ public static <V> ListenableFuture<V> refreshIfUnauthenticatedAsync(
+ AsyncCallable<V> call, CallCredentialsProvider callCredentialsProvider) {
+ Preconditions.checkNotNull(call);
+ Preconditions.checkNotNull(callCredentialsProvider);
+
+ try {
+ return Futures.catchingAsync(
+ call.call(),
+ Throwable.class,
+ (e) -> refreshIfUnauthenticatedAsyncOnException(e, call, callCredentialsProvider),
+ MoreExecutors.directExecutor());
+ } catch (Throwable t) {
+ return refreshIfUnauthenticatedAsyncOnException(t, call, callCredentialsProvider);
+ }
+ }
+
+ private static <V> ListenableFuture<V> refreshIfUnauthenticatedAsyncOnException(
+ Throwable t, AsyncCallable<V> call, CallCredentialsProvider callCredentialsProvider) {
+ io.grpc.Status status = io.grpc.Status.fromThrowable(t);
+ if (status != null
+ && (status.getCode() == io.grpc.Status.Code.UNAUTHENTICATED
+ || status.getCode() == io.grpc.Status.Code.PERMISSION_DENIED)) {
+ try {
+ callCredentialsProvider.refresh();
+ return call.call();
+ } catch (Throwable tt) {
+ t.addSuppressed(tt);
+ }
+ }
+
+ return Futures.immediateFailedFuture(t);
+ }
+
+ /** Same as {@link #refreshIfUnauthenticatedAsync} but calling a synchronous code block. */
+ public static <V> V refreshIfUnauthenticated(
+ Callable<V> call, CallCredentialsProvider callCredentialsProvider)
+ throws IOException, InterruptedException {
+ Preconditions.checkNotNull(call);
+ Preconditions.checkNotNull(callCredentialsProvider);
+
+ try {
+ return call.call();
+ } catch (Exception e) {
+ io.grpc.Status status = io.grpc.Status.fromThrowable(e);
+ if (status != null
+ && (status.getCode() == io.grpc.Status.Code.UNAUTHENTICATED
+ || status.getCode() == io.grpc.Status.Code.PERMISSION_DENIED)) {
+ try {
+ callCredentialsProvider.refresh();
+ return call.call();
+ } catch (Exception ex) {
+ e.addSuppressed(ex);
+ }
+ }
+
+ Throwables.throwIfInstanceOf(e, IOException.class);
+ Throwables.throwIfInstanceOf(e, InterruptedException.class);
+ Throwables.throwIfUnchecked(e);
+ throw new AssertionError(e);
+ }
+ }
}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
index 82d8ab0..6ab2b63 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java
@@ -237,7 +237,7 @@
remoteOptions.remoteTimeout.getSeconds(),
retrier);
return new GrpcCacheClient(
- channel.retain(), creds, remoteOptions, retrier, DIGEST_UTIL, uploader);
+ channel.retain(), callCredentialsProvider, remoteOptions, retrier, DIGEST_UTIL, uploader);
}
private static byte[] downloadBlob(GrpcCacheClient cacheClient, Digest digest)
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
index f078105..c2af204 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
@@ -89,7 +89,6 @@
import com.google.rpc.PreconditionFailure;
import com.google.rpc.PreconditionFailure.Violation;
import io.grpc.BindableService;
-import io.grpc.CallCredentials;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerCall;
@@ -259,11 +258,10 @@
.directExecutor()
.build());
GrpcRemoteExecutor executor =
- new GrpcRemoteExecutor(channel.retain(), null, retrier, remoteOptions);
+ new GrpcRemoteExecutor(channel.retain(), CallCredentialsProvider.NO_CREDENTIALS, retrier);
CallCredentialsProvider callCredentialsProvider =
GoogleAuthUtils.newCallCredentialsProvider(
GoogleAuthUtils.newCredentials(Options.getDefaults(AuthAndTLSOptions.class)));
- CallCredentials creds = callCredentialsProvider.getCallCredentials();
ByteStreamUploader uploader =
new ByteStreamUploader(
remoteOptions.remoteInstanceName,
@@ -272,7 +270,13 @@
remoteOptions.remoteTimeout.getSeconds(),
retrier);
GrpcCacheClient cacheProtocol =
- new GrpcCacheClient(channel.retain(), creds, remoteOptions, retrier, DIGEST_UTIL, uploader);
+ new GrpcCacheClient(
+ channel.retain(),
+ callCredentialsProvider,
+ remoteOptions,
+ retrier,
+ DIGEST_UTIL,
+ uploader);
RemoteExecutionCache remoteCache =
new RemoteExecutionCache(cacheProtocol, remoteOptions, DIGEST_UTIL);
client =
diff --git a/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java b/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java
index ad1b4c0..6dcb2d5 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/UtilsTest.java
@@ -14,10 +14,19 @@
package com.google.devtools.build.lib.remote;
import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import com.google.common.util.concurrent.Futures;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.remote.util.Utils;
import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -36,4 +45,140 @@
assertThat(Utils.grpcAwareErrorMessage(ioError)).isEqualTo("io error");
assertThat(Utils.grpcAwareErrorMessage(wrappedGrpcError)).isEqualTo("ABORTED: grpc error");
}
+
+ @Test
+ public void refreshIfUnauthenticatedAsync_unauthenticated_shouldRefresh() throws Exception {
+ CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+ AtomicInteger callTimes = new AtomicInteger();
+
+ Utils.refreshIfUnauthenticatedAsync(
+ () -> {
+ if (callTimes.getAndIncrement() == 0) {
+ throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+ }
+ return Futures.immediateFuture(null);
+ },
+ callCredentialsProvider)
+ .get();
+
+ assertThat(callTimes.get()).isEqualTo(2);
+ verify(callCredentialsProvider, times(1)).refresh();
+ }
+
+ @Test
+ public void refreshIfUnauthenticatedAsync_unauthenticatedFuture_shouldRefresh() throws Exception {
+ CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+ AtomicInteger callTimes = new AtomicInteger();
+
+ Utils.refreshIfUnauthenticatedAsync(
+ () -> {
+ if (callTimes.getAndIncrement() == 0) {
+ return Futures.immediateFailedFuture(
+ new StatusRuntimeException(Status.UNAUTHENTICATED));
+ }
+ return Futures.immediateFuture(null);
+ },
+ callCredentialsProvider)
+ .get();
+
+ assertThat(callTimes.get()).isEqualTo(2);
+ verify(callCredentialsProvider, times(1)).refresh();
+ }
+
+ @Test
+ public void refreshIfUnauthenticatedAsync_permissionDenied_shouldRefresh() throws Exception {
+ CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+ AtomicInteger callTimes = new AtomicInteger();
+
+ Utils.refreshIfUnauthenticated(
+ () -> {
+ if (callTimes.getAndIncrement() == 0) {
+ throw new StatusRuntimeException(Status.PERMISSION_DENIED);
+ }
+ return Futures.immediateFuture(null);
+ },
+ callCredentialsProvider)
+ .get();
+
+ assertThat(callTimes.get()).isEqualTo(2);
+ verify(callCredentialsProvider, times(1)).refresh();
+ }
+
+ @Test
+ public void refreshIfUnauthenticatedAsync_cantRefresh_shouldRefreshOnceAndFail()
+ throws Exception {
+ CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+ AtomicInteger callTimes = new AtomicInteger();
+
+ assertThrows(
+ ExecutionException.class,
+ () -> {
+ Utils.refreshIfUnauthenticatedAsync(
+ () -> {
+ callTimes.getAndIncrement();
+ throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+ },
+ callCredentialsProvider)
+ .get();
+ });
+
+ assertThat(callTimes.get()).isEqualTo(2);
+ verify(callCredentialsProvider, times(1)).refresh();
+ }
+
+ @Test
+ public void refreshIfUnauthenticated_unauthenticated_shouldRefresh() throws Exception {
+ CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+ AtomicInteger callTimes = new AtomicInteger();
+
+ Utils.refreshIfUnauthenticated(
+ () -> {
+ if (callTimes.getAndIncrement() == 0) {
+ throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+ }
+ return null;
+ },
+ callCredentialsProvider);
+
+ assertThat(callTimes.get()).isEqualTo(2);
+ verify(callCredentialsProvider, times(1)).refresh();
+ }
+
+ @Test
+ public void refreshIfUnauthenticated_permissionDenied_shouldRefresh() throws Exception {
+ CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+ AtomicInteger callTimes = new AtomicInteger();
+
+ Utils.refreshIfUnauthenticated(
+ () -> {
+ if (callTimes.getAndIncrement() == 0) {
+ throw new StatusRuntimeException(Status.PERMISSION_DENIED);
+ }
+ return null;
+ },
+ callCredentialsProvider);
+
+ assertThat(callTimes.get()).isEqualTo(2);
+ verify(callCredentialsProvider, times(1)).refresh();
+ }
+
+ @Test
+ public void refreshIfUnauthenticated_cantRefresh_shouldRefreshOnceAndFail() throws Exception {
+ CallCredentialsProvider callCredentialsProvider = mock(CallCredentialsProvider.class);
+ AtomicInteger callTimes = new AtomicInteger();
+
+ assertThrows(
+ StatusRuntimeException.class,
+ () -> {
+ Utils.refreshIfUnauthenticated(
+ () -> {
+ callTimes.getAndIncrement();
+ throw new StatusRuntimeException(Status.UNAUTHENTICATED);
+ },
+ callCredentialsProvider);
+ });
+
+ assertThat(callTimes.get()).isEqualTo(2);
+ verify(callCredentialsProvider, times(1)).refresh();
+ }
}