Add CallCredentialsProvider and refresh credentials in ByteStreamUploader
Users may get authentication errors in the middle of a long remote build because their credentials time out. This PR:
1. Add `CallCredentialsProvider` which can be used by gRPC clients to refresh credentials.
2. Update `ByteStreamUploader.java` to use the provider and refresh credentials on authentication error.
The next step is to update the other places that currently use `CallCredentials` (e.g. `RemoteCacheClient.java`) so that they use this provider and refresh credentials when necessary.
Closes #12106.
PiperOrigin-RevId: 332394306
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
index 8bed33c..dcd9975 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -15,6 +15,7 @@
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;
import build.bazel.remote.execution.v2.Digest;
@@ -33,12 +34,14 @@
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
+import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.TestUtils;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.vfs.DigestHashFunction;
import com.google.protobuf.ByteString;
import io.grpc.BindableService;
+import io.grpc.CallCredentials;
import io.grpc.Context;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
@@ -72,6 +75,7 @@
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -146,7 +150,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -223,7 +227,11 @@
TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
ByteStreamUploader uploader =
new ByteStreamUploader(
- INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(channel),
+ CallCredentialsProvider.NO_CREDENTIALS,
+ 3,
+ retrier);
byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
new Random().nextBytes(blob);
@@ -339,7 +347,11 @@
TestUtils.newRemoteRetrier(() -> new FixedBackoff(1, 0), (e) -> true, retryService);
ByteStreamUploader uploader =
new ByteStreamUploader(
- INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 1, retrier);
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(channel),
+ CallCredentialsProvider.NO_CREDENTIALS,
+ 1,
+ retrier);
byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
new Random().nextBytes(blob);
@@ -397,7 +409,11 @@
TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
ByteStreamUploader uploader =
new ByteStreamUploader(
- INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(channel),
+ CallCredentialsProvider.NO_CREDENTIALS,
+ 3,
+ retrier);
byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
new Random().nextBytes(blob);
@@ -467,7 +483,11 @@
TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
ByteStreamUploader uploader =
new ByteStreamUploader(
- INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(channel),
+ CallCredentialsProvider.NO_CREDENTIALS,
+ 3,
+ retrier);
byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
new Random().nextBytes(blob);
@@ -504,7 +524,11 @@
TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService);
ByteStreamUploader uploader =
new ByteStreamUploader(
- INSTANCE_NAME, new ReferenceCountedChannel(channel), null, 3, retrier);
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(channel),
+ CallCredentialsProvider.NO_CREDENTIALS,
+ 3,
+ retrier);
byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
new Random().nextBytes(blob);
@@ -547,7 +571,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -585,7 +609,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -702,7 +726,7 @@
InProcessChannelBuilder.forName("Server for " + this.getClass())
.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata))
.build()),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -765,7 +789,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -833,7 +857,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -868,7 +892,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -935,7 +959,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -975,7 +999,7 @@
new ByteStreamUploader(
/* instanceName */ null,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -1022,7 +1046,7 @@
new ByteStreamUploader(
/* instanceName */ null,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -1060,7 +1084,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -1141,7 +1165,7 @@
new ByteStreamUploader(
INSTANCE_NAME,
new ReferenceCountedChannel(channel),
- null, /* timeout seconds */
+ CallCredentialsProvider.NO_CREDENTIALS, /* timeout seconds */
60,
retrier);
@@ -1199,6 +1223,158 @@
withEmptyMetadata.detach(prevContext);
}
+ @Test // Server always rejects with UNAUTHENTICATED: expect one refresh, two attempts, IOException.
+ public void unauthenticatedErrorShouldNotBeRetried() throws Exception {
+ Context prevContext = withEmptyMetadata.attach();
+ RemoteRetrier retrier =
+ TestUtils.newRemoteRetrier(
+ () -> mockBackoff, RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryService);
+ // Fake provider: supplies no credentials and only counts how often refresh() is called.
+ AtomicInteger refreshTimes = new AtomicInteger();
+ CallCredentialsProvider callCredentialsProvider =
+ new CallCredentialsProvider() {
+ @Nullable
+ @Override
+ public CallCredentials getCallCredentials() {
+ return null; // no real credentials are attached in this test
+ }
+
+ @Override
+ public void refresh() throws IOException {
+ refreshTimes.incrementAndGet();
+ }
+ };
+ ByteStreamUploader uploader =
+ new ByteStreamUploader(
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(channel),
+ callCredentialsProvider,
+ /* timeout seconds */ 60,
+ retrier);
+
+ byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
+ new Random().nextBytes(blob);
+
+ Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build();
+ HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+
+ AtomicInteger numUploads = new AtomicInteger();
+ serviceRegistry.addService(
+ new ByteStreamImplBase() {
+ @Override
+ public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+ numUploads.incrementAndGet();
+ // Reject every attempt, including the one made after the credentials are refreshed.
+ streamObserver.onError(Status.UNAUTHENTICATED.asException());
+ return new NoopStreamObserver();
+ }
+ });
+
+ assertThrows(
+ IOException.class,
+ () -> {
+ uploader.uploadBlob(hash, chunker, true);
+ });
+ // Exactly one refresh and one extra attempt happen before the error surfaces.
+ assertThat(refreshTimes.get()).isEqualTo(1);
+ assertThat(numUploads.get()).isEqualTo(2);
+
+ // UNAUTHENTICATED is not backoff-retried; the only re-attempt is the one after the refresh.
+ Mockito.verifyZeroInteractions(mockBackoff);
+
+ blockUntilInternalStateConsistent(uploader);
+
+ withEmptyMetadata.detach(prevContext);
+ }
+
+ @Test // First attempt fails with UNAUTHENTICATED; after one refresh the upload succeeds.
+ public void shouldRefreshCredentialsOnAuthenticationError() throws Exception {
+ Context prevContext = withEmptyMetadata.attach();
+ RemoteRetrier retrier =
+ TestUtils.newRemoteRetrier(
+ () -> mockBackoff, RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryService);
+ // Fake provider: supplies no credentials and only counts how often refresh() is called.
+ AtomicInteger refreshTimes = new AtomicInteger();
+ CallCredentialsProvider callCredentialsProvider =
+ new CallCredentialsProvider() {
+ @Nullable
+ @Override
+ public CallCredentials getCallCredentials() {
+ return null; // no real credentials are attached in this test
+ }
+
+ @Override
+ public void refresh() throws IOException {
+ refreshTimes.incrementAndGet();
+ }
+ };
+ ByteStreamUploader uploader =
+ new ByteStreamUploader(
+ INSTANCE_NAME,
+ new ReferenceCountedChannel(channel),
+ callCredentialsProvider,
+ /* timeout seconds */ 60,
+ retrier);
+
+ byte[] blob = new byte[CHUNK_SIZE * 2 + 1];
+ new Random().nextBytes(blob);
+
+ Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build();
+ HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash());
+
+ AtomicInteger numUploads = new AtomicInteger();
+ serviceRegistry.addService(
+ new ByteStreamImplBase() {
+ @Override
+ public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) {
+ numUploads.incrementAndGet();
+ // Reject the call until the client has refreshed its credentials at least once.
+ if (refreshTimes.get() == 0) {
+ streamObserver.onError(Status.UNAUTHENTICATED.asException());
+ return new NoopStreamObserver();
+ }
+ // After the refresh: accept the stream and check that only the last chunk sets finishWrite.
+ return new StreamObserver<WriteRequest>() {
+ long nextOffset = 0;
+
+ @Override
+ public void onNext(WriteRequest writeRequest) {
+ nextOffset += writeRequest.getData().size();
+ boolean lastWrite = blob.length == nextOffset;
+ assertThat(writeRequest.getFinishWrite()).isEqualTo(lastWrite);
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ fail("onError should never be called.");
+ }
+
+ @Override
+ public void onCompleted() {
+ assertThat(nextOffset).isEqualTo(blob.length);
+
+ WriteResponse response =
+ WriteResponse.newBuilder().setCommittedSize(nextOffset).build();
+ streamObserver.onNext(response);
+ streamObserver.onCompleted();
+ }
+ };
+ }
+ });
+ // Must complete without throwing: the refresh lets the second attempt succeed.
+ uploader.uploadBlob(hash, chunker, true);
+
+ assertThat(refreshTimes.get()).isEqualTo(1);
+ assertThat(numUploads.get()).isEqualTo(2);
+
+ // The re-attempt after the refresh must not go through the retrier's backoff.
+ Mockito.verifyZeroInteractions(mockBackoff);
+
+ blockUntilInternalStateConsistent(uploader);
+
+ withEmptyMetadata.detach(prevContext);
+ }
+
private static class NoopStreamObserver implements StreamObserver<WriteRequest> {
@Override
public void onNext(WriteRequest writeRequest) {