remote: add a ByteStreamBuildEventArtifactUploader This change allows local files referenced by the BEP/BES protocol to be uploaded to a ByteStream gRPC service. The ByteStreamUploader is now implicitly also used by the BES module, which has a different lifecycle than the remote module. We introduce reference counting to ensure that the channel is closed once it is no longer needed. This also fixes a bug where we leaked one socket per remote build until the Bazel server was shut down. RELNOTES: None PiperOrigin-RevId: 204275316
diff --git a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java index 21b6588..a6c51da 100644 --- a/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java +++ b/src/main/java/com/google/devtools/build/lib/authandtls/GoogleAuthUtils.java
@@ -19,6 +19,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.grpc.CallCredentials; +import io.grpc.ClientInterceptor; import io.grpc.ManagedChannel; import io.grpc.auth.MoreCallCredentials; import io.grpc.netty.GrpcSslContexts; @@ -42,10 +43,12 @@ * * @throws IOException in case the channel can't be constructed. */ - public static ManagedChannel newChannel(String target, AuthAndTLSOptions options) + public static ManagedChannel newChannel(String target, AuthAndTLSOptions options, + ClientInterceptor... interceptors) throws IOException { Preconditions.checkNotNull(target); Preconditions.checkNotNull(options); + Preconditions.checkNotNull(interceptors); final SslContext sslContext = options.tlsEnabled ? createSSlContext(options.tlsCertificate) : null; @@ -54,7 +57,8 @@ NettyChannelBuilder builder = NettyChannelBuilder.forTarget(target) .negotiationType(options.tlsEnabled ? NegotiationType.TLS : NegotiationType.PLAINTEXT) - .loadBalancerFactory(RoundRobinLoadBalancerFactory.getInstance()); + .loadBalancerFactory(RoundRobinLoadBalancerFactory.getInstance()) + .intercept(interceptors); if (sslContext != null) { builder.sslContext(sslContext); if (options.tlsAuthorityOverride != null) {
diff --git a/src/main/java/com/google/devtools/build/lib/buildeventstream/BuildEventArtifactUploaderFactory.java b/src/main/java/com/google/devtools/build/lib/buildeventstream/BuildEventArtifactUploaderFactory.java index d2cc64e..66b55e6 100644 --- a/src/main/java/com/google/devtools/build/lib/buildeventstream/BuildEventArtifactUploaderFactory.java +++ b/src/main/java/com/google/devtools/build/lib/buildeventstream/BuildEventArtifactUploaderFactory.java
@@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package com.google.devtools.build.lib.buildeventstream; import static com.google.devtools.build.lib.buildeventstream.BuildEventArtifactUploader.LOCAL_FILES_UPLOADER; @@ -29,3 +30,4 @@ */ BuildEventArtifactUploader create(OptionsProvider options); } +
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploader.java new file mode 100644 index 0000000..3f3308c --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploader.java
@@ -0,0 +1,144 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote; + +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.devtools.build.lib.buildeventstream.BuildEvent.LocalFile; +import com.google.devtools.build.lib.buildeventstream.BuildEventArtifactUploader; +import com.google.devtools.build.lib.buildeventstream.PathConverter; +import com.google.devtools.build.lib.vfs.Path; +import com.google.devtools.remoteexecution.v1test.Digest; +import io.grpc.Context; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.Nullable; + +/** + * A {@link BuildEventArtifactUploader} backed by {@link ByteStreamUploader}. 
+ */ +class ByteStreamBuildEventArtifactUploader implements BuildEventArtifactUploader { + + private final Context ctx; + private final ByteStreamUploader uploader; + private final String remoteServerInstanceName; + + private final AtomicBoolean shutdown = new AtomicBoolean(); + + ByteStreamBuildEventArtifactUploader( + ByteStreamUploader uploader, String remoteServerName, Context ctx, + @Nullable String remoteInstanceName) { + this.uploader = Preconditions.checkNotNull(uploader); + String remoteServerInstanceName = Preconditions.checkNotNull(remoteServerName); + if (!Strings.isNullOrEmpty(remoteInstanceName)) { + remoteServerInstanceName += "/" + remoteInstanceName; + } + this.ctx = Preconditions.checkNotNull(ctx); + this.remoteServerInstanceName = remoteServerInstanceName; + } + + @Override + public ListenableFuture<PathConverter> upload(Map<Path, LocalFile> files) { + if (files.isEmpty()) { + return Futures.immediateFuture(PathConverter.NO_CONVERSION); + } + List<ListenableFuture<PathDigestPair>> uploads = new ArrayList<>(files.size()); + + Context prevCtx = ctx.attach(); + try { + for (Path file : files.keySet()) { + Chunker chunker = new Chunker(file); + Digest digest = chunker.digest(); + ListenableFuture<PathDigestPair> upload = + Futures.transform( + uploader.uploadBlobAsync(chunker, /*forceUpload=*/false), + unused -> new PathDigestPair(file, digest), + MoreExecutors.directExecutor()); + uploads.add(upload); + } + + return Futures.transform(Futures.allAsList(uploads), + (uploadsDone) -> new PathConverterImpl(remoteServerInstanceName, uploadsDone), + MoreExecutors.directExecutor()); + } catch (IOException e) { + return Futures.immediateFailedFuture(e); + } finally { + ctx.detach(prevCtx); + } + } + + @Override + public void shutdown() { + if (shutdown.getAndSet(true)) { + return; + } + uploader.release(); + } + + private static class PathConverterImpl implements PathConverter { + + private final String remoteServerInstanceName; + private final Map<Path, Digest> pathToDigest; +
PathConverterImpl(String remoteServerInstanceName, + List<PathDigestPair> uploads) { + Preconditions.checkNotNull(uploads); + this.remoteServerInstanceName = remoteServerInstanceName; + pathToDigest = new HashMap<>(uploads.size()); + for (PathDigestPair pair : uploads) { + pathToDigest.put(pair.getPath(), pair.getDigest()); + } + } + + @Override + public String apply(Path path) { + Preconditions.checkNotNull(path); + Digest digest = pathToDigest.get(path); + if (digest == null) { + // It's a programming error to reference a file that has not been uploaded. + throw new IllegalStateException( + String.format("Illegal file reference: '%s'", path.getPathString())); + } + return String.format( + "bytestream://%s/blobs/%s/%d", + remoteServerInstanceName, digest.getHash(), digest.getSizeBytes()); + } + } + + private static class PathDigestPair { + + private final Path path; + private final Digest digest; + + PathDigestPair(Path path, Digest digest) { + this.path = path; + this.digest = digest; + } + + public Path getPath() { + return path; + } + + public Digest getDigest() { + return digest; + } + } +}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderFactory.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderFactory.java new file mode 100644 index 0000000..be43302 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderFactory.java
@@ -0,0 +1,47 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote; + +import com.google.devtools.build.lib.buildeventstream.BuildEventArtifactUploader; +import com.google.devtools.build.lib.buildeventstream.BuildEventArtifactUploaderFactory; +import com.google.devtools.common.options.OptionsProvider; +import io.grpc.Context; +import javax.annotation.Nullable; + +/** + * A factory for {@link ByteStreamBuildEventArtifactUploader}. + */ +class ByteStreamBuildEventArtifactUploaderFactory implements + BuildEventArtifactUploaderFactory { + + private final ByteStreamUploader uploader; + private final String remoteServerName; + private final Context ctx; + private final @Nullable String remoteInstanceName; + + ByteStreamBuildEventArtifactUploaderFactory( + ByteStreamUploader uploader, String remoteServerName, Context ctx, + @Nullable String remoteInstanceName) { + this.uploader = uploader; + this.remoteServerName = remoteServerName; + this.ctx = ctx; + this.remoteInstanceName = remoteInstanceName; + } + + @Override + public BuildEventArtifactUploader create(OptionsProvider options) { + return new ByteStreamBuildEventArtifactUploader(uploader.retain(), remoteServerName, ctx, + remoteInstanceName); + } +}
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java index 001ba22..42129a4 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java
@@ -26,6 +26,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; import com.google.common.base.Throwables; +import com.google.common.hash.HashCode; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; @@ -39,11 +41,15 @@ import io.grpc.Context; import io.grpc.Metadata; import io.grpc.Status; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCounted; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -55,20 +61,27 @@ /** * A client implementing the {@code Write} method of the {@code ByteStream} gRPC service. * - * <p>Users must call {@link #shutdown()} before exiting. + * <p>The uploader supports reference counting to easily be shared between components with + * different lifecycles. After instantiation the reference count is {@code 1}. + * + * <p>See {@link ReferenceCounted} for more information on reference counting. */ -final class ByteStreamUploader { +class ByteStreamUploader extends AbstractReferenceCounted { private static final Logger logger = Logger.getLogger(ByteStreamUploader.class.getName()); private final String instanceName; - private final Channel channel; + private final ReferenceCountedChannel channel; private final CallCredentials callCredentials; private final long callTimeoutSecs; private final RemoteRetrier retrier; private final Object lock = new Object(); + /** Contains the hash codes of already uploaded blobs.
**/ + @GuardedBy("lock") + private final Set<HashCode> uploadedBlobs = new HashSet<>(); + @GuardedBy("lock") private final Map<Digest, ListenableFuture<Void>> uploadsInProgress = new HashMap<>(); @@ -89,7 +102,7 @@ */ public ByteStreamUploader( @Nullable String instanceName, - Channel channel, + ReferenceCountedChannel channel, @Nullable CallCredentials callCredentials, long callTimeoutSecs, RemoteRetrier retrier) { @@ -112,11 +125,15 @@ * <p>Trying to upload the same BLOB multiple times concurrently, results in only one upload being * performed. This is transparent to the user of this API. * + * @param chunker the data to upload. + * @param forceUpload if {@code false} the blob is not uploaded if it has previously been + * uploaded, if {@code true} the blob is uploaded. * @throws IOException when reading of the {@link Chunker}s input source fails * @throws RetryException when the upload failed after a retry */ - public void uploadBlob(Chunker chunker) throws IOException, InterruptedException { - uploadBlobs(singletonList(chunker)); + public void uploadBlob(Chunker chunker, boolean forceUpload) throws IOException, + InterruptedException { + uploadBlobs(singletonList(chunker), forceUpload); } /** @@ -131,14 +148,18 @@ * <p>Trying to upload the same BLOB multiple times concurrently, results in only one upload being * performed. This is transparent to the user of this API. * + * @param chunkers the data to upload. + * @param forceUpload if {@code false} the blob is not uploaded if it has previously been + * uploaded, if {@code true} the blob is uploaded. 
* @throws IOException when reading of the {@link Chunker}s input source fails * @throws RetryException when the upload failed after a retry */ - public void uploadBlobs(Iterable<Chunker> chunkers) throws IOException, InterruptedException { + public void uploadBlobs(Iterable<Chunker> chunkers, boolean forceUpload) throws IOException, + InterruptedException { List<ListenableFuture<Void>> uploads = new ArrayList<>(); for (Chunker chunker : chunkers) { - uploads.add(uploadBlobAsync(chunker)); + uploads.add(uploadBlobAsync(chunker, forceUpload)); } try { @@ -162,9 +183,11 @@ * Cancels all running uploads. The method returns immediately and does NOT wait for the uploads * to be cancelled. * - * <p>This method must be the last method called. + * <p>This method should not be called directly, but will be called implicitly when the + * reference count reaches {@code 0}. */ - public void shutdown() { + @VisibleForTesting + void shutdown() { synchronized (lock) { if (isShutdown) { return; @@ -180,13 +203,33 @@ } } - @VisibleForTesting - ListenableFuture<Void> uploadBlobAsync(Chunker chunker) { + /** + * Uploads a BLOB asynchronously to the remote {@code ByteStream} service. The call returns + * immediately and one can listen to the returned future for the success/failure of the upload. + * + * <p>Uploads are retried according to the specified {@link RemoteRetrier}. Retrying is + * transparent to the user of this API. + * + * <p>Trying to upload the same BLOB multiple times concurrently, results in only one upload being + * performed. This is transparent to the user of this API. + * + * @param chunker the data to upload. + * @param forceUpload if {@code false} the blob is not uploaded if it has previously been + * uploaded, if {@code true} the blob is uploaded. 
+ * @throws IOException when reading of the {@link Chunker}s input source fails + * @throws RetryException when the upload failed after a retry + */ + public ListenableFuture<Void> uploadBlobAsync(Chunker chunker, boolean forceUpload) { Digest digest = checkNotNull(chunker.digest()); + HashCode hash = HashCode.fromString(digest.getHash()); synchronized (lock) { checkState(!isShutdown, "Must not call uploadBlobs after shutdown."); + if (!forceUpload && uploadedBlobs.contains(hash)) { + return Futures.immediateFuture(null); + } + ListenableFuture<Void> inProgress = uploadsInProgress.get(digest); if (inProgress != null) { return inProgress; @@ -197,6 +240,7 @@ () -> { synchronized (lock) { uploadsInProgress.remove(digest); + uploadedBlobs.add(hash); } }, MoreExecutors.directExecutor()); @@ -243,6 +287,27 @@ return currUpload; } + @Override + public ByteStreamUploader retain() { + return (ByteStreamUploader) super.retain(); + } + + @Override + public ByteStreamUploader retain(int increment) { + return (ByteStreamUploader) super.retain(increment); + } + + @Override + protected void deallocate() { + shutdown(); + channel.release(); + } + + @Override + public ReferenceCounted touch(Object o) { + return this; + } + private static class AsyncUpload { private final Channel channel;
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java index 48bd4c8..253ab41 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteCache.java
@@ -49,7 +49,6 @@ import com.google.devtools.remoteexecution.v1test.GetActionResultRequest; import com.google.devtools.remoteexecution.v1test.UpdateActionResultRequest; import io.grpc.CallCredentials; -import io.grpc.Channel; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; @@ -61,30 +60,31 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; /** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */ @ThreadSafe public class GrpcRemoteCache extends AbstractRemoteActionCache { private final CallCredentials credentials; - private final Channel channel; + private final ReferenceCountedChannel channel; private final RemoteRetrier retrier; private final ByteStreamUploader uploader; + private final AtomicBoolean closed = new AtomicBoolean(); + @VisibleForTesting public GrpcRemoteCache( - Channel channel, + ReferenceCountedChannel channel, CallCredentials credentials, RemoteOptions options, RemoteRetrier retrier, - DigestUtil digestUtil) { + DigestUtil digestUtil, + ByteStreamUploader uploader) { super(options, digestUtil, retrier); this.credentials = credentials; this.channel = channel; this.retrier = retrier; - - uploader = - new ByteStreamUploader( - options.remoteInstanceName, channel, credentials, options.remoteTimeout, retrier); + this.uploader = uploader; } private ContentAddressableStorageBlockingStub casBlockingStub() { @@ -110,7 +110,11 @@ @Override public void close() { - uploader.shutdown(); + if (closed.getAndSet(true)) { + return; + } + uploader.release(); + channel.release(); } public static boolean isRemoteCacheOptions(RemoteOptions options) { @@ -168,7 +172,7 @@ toUpload.add(new Chunker(actionInput, inputFileCache, execRoot, digestUtil)); } } - uploader.uploadBlobs(toUpload); + uploader.uploadBlobs(toUpload, /*forceUpload=*/true); } @Override @@ -293,7 +297,7 @@ } if (!filesToUpload.isEmpty()) { -
uploader.uploadBlobs(filesToUpload); + uploader.uploadBlobs(filesToUpload, /*forceUpload=*/true); } // TODO(olaola): inline small stdout/stderr here. @@ -317,7 +321,7 @@ Digest digest = digestUtil.compute(file); ImmutableSet<Digest> missing = getMissingDigests(ImmutableList.of(digest)); if (!missing.isEmpty()) { - uploader.uploadBlob(new Chunker(file)); + uploader.uploadBlob(new Chunker(file), true); } return digest; } @@ -333,7 +337,7 @@ Digest digest = DigestUtil.getFromInputCache(input, inputCache); ImmutableSet<Digest> missing = getMissingDigests(ImmutableList.of(digest)); if (!missing.isEmpty()) { - uploader.uploadBlob(new Chunker(input, inputCache, execRoot, digestUtil)); + uploader.uploadBlob(new Chunker(input, inputCache, execRoot, digestUtil), true); } return digest; } @@ -342,7 +346,7 @@ Digest digest = digestUtil.compute(blob); ImmutableSet<Digest> missing = getMissingDigests(ImmutableList.of(digest)); if (!missing.isEmpty()) { - uploader.uploadBlob(new Chunker(blob, digestUtil)); + uploader.uploadBlob(new Chunker(blob, digestUtil), true); } return digest; }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java index c98384c..858f574 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutor.java
@@ -30,26 +30,29 @@ import com.google.watcher.v1.WatcherGrpc; import com.google.watcher.v1.WatcherGrpc.WatcherBlockingStub; import io.grpc.CallCredentials; -import io.grpc.Channel; +import io.grpc.ManagedChannel; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import io.grpc.protobuf.StatusProto; import java.io.IOException; import java.util.Iterator; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; /** A remote work executor that uses gRPC for communicating the work, inputs and outputs. */ @ThreadSafe class GrpcRemoteExecutor { - private final Channel channel; + private final ManagedChannel channel; private final CallCredentials callCredentials; private final int callTimeoutSecs; private final RemoteRetrier retrier; + private final AtomicBoolean closed = new AtomicBoolean(); + public GrpcRemoteExecutor( - Channel channel, + ManagedChannel channel, @Nullable CallCredentials callCredentials, int callTimeoutSecs, RemoteRetrier retrier) { @@ -73,7 +76,7 @@ .withCallCredentials(callCredentials); } - private void handleStatus(Status statusProto, @Nullable ExecuteResponse resp) throws IOException { + private void handleStatus(Status statusProto, @Nullable ExecuteResponse resp) { if (statusProto.getCode() == Code.OK.value()) { return; } @@ -206,4 +209,11 @@ }); }); } + + public void close() { + if (closed.getAndSet(true)) { + return; + } + channel.shutdown(); + } }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java new file mode 100644 index 0000000..eff9621 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/ReferenceCountedChannel.java
@@ -0,0 +1,125 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote; + +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCounted; +import java.util.concurrent.TimeUnit; + +/** A wrapper around a {@link io.grpc.ManagedChannel} exposing a reference count. + * When instantiated the reference count is 1. {@link ManagedChannel#shutdown()} will be called + * on the wrapped channel when the reference count reaches 0. + * + * See {@link ReferenceCounted} for more information about reference counting. 
+ */ +class ReferenceCountedChannel extends ManagedChannel implements ReferenceCounted { + + private final ManagedChannel channel; + private final AbstractReferenceCounted referenceCounted = new AbstractReferenceCounted() { + @Override + protected void deallocate() { + channel.shutdown(); + } + + @Override + public ReferenceCounted touch(Object o) { + return this; + } + }; + + public ReferenceCountedChannel(ManagedChannel channel) { + this.channel = channel; + } + + @Override + public ManagedChannel shutdown() { + throw new UnsupportedOperationException("Don't call shutdown() directly, but use release() " + + "instead."); + } + + @Override + public boolean isShutdown() { + return channel.isShutdown(); + } + + @Override + public boolean isTerminated() { + return channel.isTerminated(); + } + + @Override + public ManagedChannel shutdownNow() { + throw new UnsupportedOperationException("Don't call shutdownNow() directly, but use release() " + + "instead."); + } + + @Override + public boolean awaitTermination(long l, TimeUnit timeUnit) throws InterruptedException { + return channel.awaitTermination(l, timeUnit); + } + + @Override + public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall( + MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) { + return channel.<RequestT, ResponseT>newCall(methodDescriptor, callOptions); + } + + @Override + public String authority() { + return channel.authority(); + } + + @Override + public int refCnt() { + return referenceCounted.refCnt(); + } + + @Override + public ReferenceCountedChannel retain() { + referenceCounted.retain(); + return this; + } + + @Override + public ReferenceCountedChannel retain(int increment) { + referenceCounted.retain(increment); + return this; + } + + @Override + public ReferenceCounted touch() { + referenceCounted.touch(); + return this; + } + + @Override + public ReferenceCounted touch(Object hint) { + referenceCounted.touch(hint); + return this; + } + + @Override + 
public boolean release() { + return referenceCounted.release(); + } + + @Override + public boolean release(int decrement) { + return referenceCounted.release(decrement); + } +} \ No newline at end of file
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionContextProvider.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionContextProvider.java index 145ae9b..88bc907 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionContextProvider.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionContextProvider.java
@@ -120,5 +120,8 @@ if (cache != null) { cache.close(); } + if (executor != null) { + executor.close(); + } } }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java index ca1b2b5..7141d35 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteModule.java
@@ -14,23 +14,21 @@ package com.google.devtools.build.lib.remote; -import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.authandtls.AuthAndTLSOptions; import com.google.devtools.build.lib.authandtls.GoogleAuthUtils; -import com.google.devtools.build.lib.buildeventstream.BuildEvent.LocalFile; import com.google.devtools.build.lib.buildeventstream.BuildEventArtifactUploader; -import com.google.devtools.build.lib.buildeventstream.PathConverter; +import com.google.devtools.build.lib.buildeventstream.BuildEventArtifactUploaderFactory; import com.google.devtools.build.lib.buildtool.BuildRequest; import com.google.devtools.build.lib.events.Event; import com.google.devtools.build.lib.exec.ExecutorBuilder; import com.google.devtools.build.lib.remote.Retrier.RetryException; import com.google.devtools.build.lib.remote.logging.LoggingInterceptor; import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import com.google.devtools.build.lib.runtime.BlazeModule; import com.google.devtools.build.lib.runtime.Command; import com.google.devtools.build.lib.runtime.CommandEnvironment; @@ -43,17 +41,19 @@ import com.google.devtools.build.lib.vfs.Path; import com.google.devtools.common.options.OptionsBase; import com.google.devtools.common.options.OptionsProvider; -import com.google.devtools.remoteexecution.v1test.Digest; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.rpc.PreconditionFailure; import com.google.rpc.PreconditionFailure.Violation; -import io.grpc.Channel; -import 
io.grpc.ClientInterceptors; +import io.grpc.CallCredentials; +import io.grpc.ClientInterceptor; +import io.grpc.Context; +import io.grpc.ManagedChannel; import io.grpc.Status.Code; import io.grpc.protobuf.StatusProto; import java.io.IOException; -import java.util.Map; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.Executors; import java.util.function.Predicate; import java.util.logging.Logger; @@ -63,63 +63,16 @@ private static final Logger logger = Logger.getLogger(RemoteModule.class.getName()); private AsynchronousFileOutputStream rpcLogFile; - @VisibleForTesting - static final class CasPathConverter implements PathConverter { - // Not final; unfortunately, the Bazel startup process requires us to create this object before - // we have the options available, so we have to create it first, and then set the options - // afterwards. At the time of this writing, I believe that we aren't using the PathConverter - // before the options are available, so this should be safe. - // TODO(ulfjack): Change the Bazel startup process to make the options available when we create - // the PathConverter. - RemoteOptions options; - DigestUtil digestUtil; - PathConverter fallbackConverter = new FileUriPathConverter(); - - @Override - public String apply(Path path) { - if (options == null || digestUtil == null || !remoteEnabled(options)) { - return fallbackConverter.apply(path); - } - String server = options.remoteCache; - String remoteInstanceName = options.remoteInstanceName; - try { - Digest digest = digestUtil.compute(path); - return remoteInstanceName.isEmpty() - ? String.format( - "bytestream://%s/blobs/%s/%d", server, digest.getHash(), digest.getSizeBytes()) - : String.format( - "bytestream://%s/%s/blobs/%s/%d", - server, remoteInstanceName, digest.getHash(), digest.getSizeBytes()); - } catch (IOException e) { - // TODO(ulfjack): Don't fail silently! 
- return fallbackConverter.apply(path); - } - } - } - - private final CasPathConverter converter = new CasPathConverter(); private final ListeningScheduledExecutorService retryScheduler = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(1)); private RemoteActionContextProvider actionContextProvider; + private final BuildEventArtifactUploaderFactoryDelegate + buildEventArtifactUploaderFactoryDelegate = new BuildEventArtifactUploaderFactoryDelegate(); + @Override public void serverInit(OptionsProvider startupOptions, ServerBuilder builder) { - builder.addBuildEventArtifactUploaderFactory( - (OptionsProvider options) -> - new BuildEventArtifactUploader() { - - @Override - public ListenableFuture<PathConverter> upload(Map<Path, LocalFile> files) { - // TODO(ulfjack): Actually hook up upload here. - return Futures.immediateFuture(converter); - } - - @Override - public void shutdown() { - // Intentionally left empty. - } - }, - "remote"); + builder.addBuildEventArtifactUploaderFactory(buildEventArtifactUploaderFactoryDelegate, "remote"); } private static final String VIOLATION_TYPE_MISSING = "MISSING"; @@ -179,8 +132,6 @@ AuthAndTLSOptions authAndTlsOptions = env.getOptions().getOptions(AuthAndTLSOptions.class); DigestHashFunction hashFn = env.getRuntime().getFileSystem().getDigestFunction(); DigestUtil digestUtil = new DigestUtil(hashFn); - converter.options = remoteOptions; - converter.digestUtil = digestUtil; // Quit if no remote options specified. 
if (remoteOptions == null) { @@ -203,10 +154,10 @@ } try { - LoggingInterceptor logger = null; + List<ClientInterceptor> interceptors = new ArrayList<>(); if (!remoteOptions.experimentalRemoteGrpcLog.isEmpty()) { rpcLogFile = new AsynchronousFileOutputStream(remoteOptions.experimentalRemoteGrpcLog); - logger = new LoggingInterceptor(rpcLogFile, env.getRuntime().getClock()); + interceptors.add(new LoggingInterceptor(rpcLogFile, env.getRuntime().getClock())); } final RemoteRetrier executeRetrier; @@ -231,24 +182,42 @@ } else if (enableGrpcCache || remoteOptions.remoteExecutor != null) { // If a remote executor but no remote cache is specified, assume both at the same target. String target = enableGrpcCache ? remoteOptions.remoteCache : remoteOptions.remoteExecutor; - Channel ch = GoogleAuthUtils.newChannel(target, authAndTlsOptions); - if (logger != null) { - ch = ClientInterceptors.intercept(ch, logger); - } - RemoteRetrier retrier = + ReferenceCountedChannel channel = + new ReferenceCountedChannel( + GoogleAuthUtils.newChannel( + target, + authAndTlsOptions, + interceptors.toArray(new ClientInterceptor[0]))); + RemoteRetrier rpcRetrier = new RemoteRetrier( remoteOptions, RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryScheduler, Retrier.ALLOW_ALL_CALLS); executeRetrier = createExecuteRetrier(remoteOptions, retryScheduler); + CallCredentials credentials = GoogleAuthUtils.newCallCredentials(authAndTlsOptions); + ByteStreamUploader uploader = + new ByteStreamUploader( + remoteOptions.remoteInstanceName, + channel.retain(), + credentials, + remoteOptions.remoteTimeout, + rpcRetrier); cache = new GrpcRemoteCache( - ch, - GoogleAuthUtils.newCallCredentials(authAndTlsOptions), + channel.retain(), + credentials, remoteOptions, - retrier, - digestUtil); + rpcRetrier, + digestUtil, + uploader.retain()); + Context requestContext = + TracingMetadataUtils.contextWithMetadata(buildRequestId, commandId, "bes-upload"); + buildEventArtifactUploaderFactoryDelegate.init( + new 
ByteStreamBuildEventArtifactUploaderFactory( + uploader, target, requestContext, remoteOptions.remoteInstanceName)); + uploader.release(); + channel.release(); } else { executeRetrier = null; cache = null; @@ -256,19 +225,20 @@ final GrpcRemoteExecutor executor; if (remoteOptions.remoteExecutor != null) { - Channel ch = GoogleAuthUtils.newChannel(remoteOptions.remoteExecutor, authAndTlsOptions); + ManagedChannel channel = + GoogleAuthUtils.newChannel( + remoteOptions.remoteExecutor, + authAndTlsOptions, + interceptors.toArray(new ClientInterceptor[0])); RemoteRetrier retrier = new RemoteRetrier( remoteOptions, RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryScheduler, Retrier.ALLOW_ALL_CALLS); - if (logger != null) { - ch = ClientInterceptors.intercept(ch, logger); - } executor = new GrpcRemoteExecutor( - ch, + channel, GoogleAuthUtils.newCallCredentials(authAndTlsOptions), remoteOptions.remoteTimeout, retrier); @@ -297,6 +267,7 @@ rpcLogFile = null; } } + buildEventArtifactUploaderFactoryDelegate.reset(); } @Override @@ -319,7 +290,7 @@ || GrpcRemoteCache.isRemoteCacheOptions(options); } - public static RemoteRetrier createExecuteRetrier( + static RemoteRetrier createExecuteRetrier( RemoteOptions options, ListeningScheduledExecutorService retryService) { return new RemoteRetrier( options.experimentalRemoteRetry @@ -329,4 +300,28 @@ retryService, Retrier.ALLOW_ALL_CALLS); } + + private static class BuildEventArtifactUploaderFactoryDelegate + implements BuildEventArtifactUploaderFactory { + + private volatile BuildEventArtifactUploaderFactory uploaderFactory; + + public void init(BuildEventArtifactUploaderFactory uploaderFactory) { + Preconditions.checkState(this.uploaderFactory == null); + this.uploaderFactory = uploaderFactory; + } + + public void reset() { + this.uploaderFactory = null; + } + + @Override + public BuildEventArtifactUploader create(OptionsProvider options) { + BuildEventArtifactUploaderFactory uploaderFactory0 = this.uploaderFactory; + if 
(uploaderFactory0 == null) { + return BuildEventArtifactUploader.LOCAL_FILES_UPLOADER; + } + return uploaderFactory0.create(options); + } + } }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java b/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java index eac9e5a..5955d5d 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/TracingMetadataUtils.java
@@ -14,6 +14,7 @@ package com.google.devtools.build.lib.remote.util; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.devtools.build.lib.analysis.BlazeVersionInfo; import com.google.devtools.build.lib.remote.util.DigestUtil.ActionKey; import com.google.devtools.remoteexecution.v1test.RequestMetadata; @@ -51,17 +52,42 @@ */ public static Context contextWithMetadata( String buildRequestId, String commandId, ActionKey actionKey) { - RequestMetadata metadata = + Preconditions.checkNotNull(buildRequestId); + Preconditions.checkNotNull(commandId); + Preconditions.checkNotNull(actionKey); + RequestMetadata.Builder metadata = RequestMetadata.newBuilder() .setCorrelatedInvocationsId(buildRequestId) - .setToolInvocationId(commandId) - .setActionId(actionKey.getDigest().getHash()) - .setToolDetails( - ToolDetails.newBuilder() - .setToolName("bazel") - .setToolVersion(BlazeVersionInfo.instance().getVersion())) + .setToolInvocationId(commandId); + metadata.setActionId(actionKey.getDigest().getHash()); + metadata.setToolDetails(ToolDetails.newBuilder() + .setToolName("bazel") + .setToolVersion(BlazeVersionInfo.instance().getVersion())) .build(); - return Context.current().withValue(CONTEXT_KEY, metadata); + return Context.current().withValue(CONTEXT_KEY, metadata.build()); + } + + /** + * Returns a new gRPC context derived from the current context, with {@link RequestMetadata} + * accessible by the {@link fromCurrentContext()} method. + * + * <p>The {@link RequestMetadata} is constructed using the provided arguments and the current tool + * version. 
+ */ + public static Context contextWithMetadata( + String buildRequestId, String commandId, String actionId) { + Preconditions.checkNotNull(buildRequestId); + Preconditions.checkNotNull(commandId); + RequestMetadata.Builder metadata = + RequestMetadata.newBuilder() + .setCorrelatedInvocationsId(buildRequestId) + .setToolInvocationId(commandId); + metadata.setActionId(actionId); + metadata.setToolDetails(ToolDetails.newBuilder() + .setToolName("bazel") + .setToolVersion(BlazeVersionInfo.instance().getVersion())) + .build(); + return Context.current().withValue(CONTEXT_KEY, metadata.build()); } /**
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderTest.java new file mode 100644 index 0000000..75b46c4 --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderTest.java
@@ -0,0 +1,236 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.bytestream.ByteStreamProto.WriteRequest; +import com.google.bytestream.ByteStreamProto.WriteResponse; +import com.google.common.io.BaseEncoding; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.devtools.build.lib.buildeventstream.BuildEvent.LocalFile; +import com.google.devtools.build.lib.buildeventstream.BuildEvent.LocalFile.LocalFileType; +import com.google.devtools.build.lib.buildeventstream.PathConverter; +import com.google.devtools.build.lib.clock.JavaClock; +import com.google.devtools.build.lib.remote.ByteStreamUploaderTest.FixedBackoff; +import com.google.devtools.build.lib.remote.ByteStreamUploaderTest.MaybeFailOnceUploadService; +import com.google.devtools.build.lib.remote.Retrier.RetryException; +import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; +import com.google.devtools.build.lib.vfs.DigestHashFunction; +import com.google.devtools.build.lib.vfs.FileSystem; +import com.google.devtools.build.lib.vfs.Path; +import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem; 
+import com.google.devtools.remoteexecution.v1test.Digest; +import io.grpc.Context; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import io.grpc.util.MutableHandlerRegistry; +import java.io.OutputStream; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.MockitoAnnotations; + +/** Test for {@link ByteStreamBuildEventArtifactUploader}. */ +@RunWith(JUnit4.class) +public class ByteStreamBuildEventArtifactUploaderTest { + + private static final DigestUtil DIGEST_UTIL = new DigestUtil(DigestHashFunction.SHA256); + + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + private static ListeningScheduledExecutorService retryService; + + private Server server; + private ManagedChannel channel; + private Context withEmptyMetadata; + private Context prevContext; + private final FileSystem fs = new InMemoryFileSystem(new JavaClock(), DigestHashFunction.SHA256); + + @BeforeClass + public static void beforeEverything() { + retryService = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(1)); + } + + @Before + public final void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + + String serverName = "Server for " + this.getClass(); + server = + InProcessServerBuilder.forName(serverName) + .fallbackHandlerRegistry(serviceRegistry) + .build() + .start(); + channel = InProcessChannelBuilder.forName(serverName).build(); + withEmptyMetadata = + TracingMetadataUtils.contextWithMetadata( + "none", "none", 
DIGEST_UTIL.asActionKey(Digest.getDefaultInstance())); + // Needs to be repeated in every test that uses the timeout setting, since the tests run + // on different threads than the setUp. + prevContext = withEmptyMetadata.attach(); + } + + @After + public void tearDown() throws Exception { + // Needs to be repeated in every test that uses the timeout setting, since the tests run + // on different threads than the tearDown. + withEmptyMetadata.detach(prevContext); + + server.shutdownNow(); + server.awaitTermination(); + } + + @AfterClass + public static void afterEverything() { + retryService.shutdownNow(); + } + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void uploadsShouldWork() throws Exception { + int numUploads = 2; + Map<String, byte[]> blobsByHash = new HashMap<>(); + Map<Path, LocalFile> filesToUpload = new HashMap<>(); + Random rand = new Random(); + for (int i = 0; i < numUploads; i++) { + Path file = fs.getPath("/file" + i); + OutputStream out = file.getOutputStream(); + int blobSize = rand.nextInt(100) + 1; + byte[] blob = new byte[blobSize]; + rand.nextBytes(blob); + out.write(blob); + out.close(); + blobsByHash.put(DIGEST_UTIL.compute(file).getHash(), blob); + filesToUpload.put(file, new LocalFile(file, LocalFileType.OUTPUT)); + } + serviceRegistry.addService(new MaybeFailOnceUploadService(blobsByHash)); + + RemoteRetrier retrier = + new RemoteRetrier( + () -> new FixedBackoff(1, 0), (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); + ReferenceCountedChannel refCntChannel = new ReferenceCountedChannel(channel); + ByteStreamUploader uploader = + new ByteStreamUploader("instance", refCntChannel, null, 3, retrier); + ByteStreamBuildEventArtifactUploader artifactUploader = + new ByteStreamBuildEventArtifactUploader( + uploader, "localhost", withEmptyMetadata, "instance"); + + PathConverter pathConverter = artifactUploader.upload(filesToUpload).get(); + for (Path file : filesToUpload.keySet()) { + String 
hash = BaseEncoding.base16().lowerCase().encode(file.getDigest()); + long size = file.getFileSize(); + String conversion = pathConverter.apply(file); + assertThat(conversion) + .isEqualTo("bytestream://localhost/instance/blobs/" + hash + "/" + size); + } + + artifactUploader.shutdown(); + + assertThat(uploader.refCnt()).isEqualTo(0); + assertThat(refCntChannel.isShutdown()).isTrue(); + } + + @Test + public void someUploadsFail() throws Exception { + // Test that if one of multiple file uploads fails, the upload future fails and that the + // error is propagated correctly. + + int numUploads = 10; + Map<String, byte[]> blobsByHash = new HashMap<>(); + Map<Path, LocalFile> filesToUpload = new HashMap<>(); + Random rand = new Random(); + for (int i = 0; i < numUploads; i++) { + Path file = fs.getPath("/file" + i); + OutputStream out = file.getOutputStream(); + int blobSize = rand.nextInt(100) + 1; + byte[] blob = new byte[blobSize]; + rand.nextBytes(blob); + out.write(blob); + out.flush(); + out.close(); + blobsByHash.put(DIGEST_UTIL.compute(file).getHash(), blob); + filesToUpload.put(file, new LocalFile(file, LocalFileType.OUTPUT)); + } + String hashOfBlobThatShouldFail = blobsByHash.keySet().iterator().next(); + serviceRegistry.addService(new MaybeFailOnceUploadService(blobsByHash) { + @Override + public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) { + StreamObserver<WriteRequest> delegate = super.write(response); + return new StreamObserver<WriteRequest>() { + @Override + public void onNext(WriteRequest value) { + if (value.getResourceName().contains(hashOfBlobThatShouldFail)) { + response.onError(Status.CANCELLED.asException()); + } else { + delegate.onNext(value); + } + } + + @Override + public void onError(Throwable t) { + delegate.onError(t); + } + + @Override + public void onCompleted() { + delegate.onCompleted(); + } + }; + } + }); + + RemoteRetrier retrier = + new RemoteRetrier( + () -> new FixedBackoff(1, 0), (e) -> true, 
retryService, Retrier.ALLOW_ALL_CALLS); + ReferenceCountedChannel refCntChannel = new ReferenceCountedChannel(channel); + ByteStreamUploader uploader = + new ByteStreamUploader("instance", refCntChannel, null, 3, retrier); + ByteStreamBuildEventArtifactUploader artifactUploader = + new ByteStreamBuildEventArtifactUploader( + uploader, "localhost", withEmptyMetadata, "instance"); + + try { + artifactUploader.upload(filesToUpload).get(); + fail("exception expected."); + } catch (ExecutionException e) { + assertThat(e.getCause()).isInstanceOf(RetryException.class); + assertThat(Status.fromThrowable(e).getCode()).isEqualTo(Status.CANCELLED.getCode()); + } + + artifactUploader.shutdown(); + + assertThat(uploader.refCnt()).isEqualTo(0); + assertThat(refCntChannel.isShutdown()).isTrue(); + } +}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java index 7b2dc7a..06ad391 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java
@@ -34,8 +34,8 @@ import com.google.devtools.remoteexecution.v1test.RequestMetadata; import com.google.protobuf.ByteString; import io.grpc.BindableService; -import io.grpc.Channel; import io.grpc.Context; +import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Server; import io.grpc.ServerCall; @@ -90,7 +90,7 @@ private static ListeningScheduledExecutorService retryService; private Server server; - private Channel channel; + private ManagedChannel channel; private Context withEmptyMetadata; private Context prevContext; @@ -137,7 +137,8 @@ Context prevContext = withEmptyMetadata.attach(); RemoteRetrier retrier = new RemoteRetrier(() -> mockBackoff, (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); - ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); byte[] blob = new byte[CHUNK_SIZE * 2 + 1]; new Random().nextBytes(blob); @@ -193,7 +194,7 @@ } }); - uploader.uploadBlob(chunker); + uploader.uploadBlob(chunker, true); // This test should not have triggered any retries. 
Mockito.verifyZeroInteractions(mockBackoff); @@ -209,7 +210,8 @@ RemoteRetrier retrier = new RemoteRetrier( () -> new FixedBackoff(1, 0), (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); - ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); int numUploads = 10; Map<String, byte[]> blobsByHash = new HashMap<>(); @@ -224,70 +226,9 @@ blobsByHash.put(chunker.digest().getHash(), blob); } - Set<String> uploadsFailedOnce = Collections.synchronizedSet(new HashSet<>()); + serviceRegistry.addService(new MaybeFailOnceUploadService(blobsByHash)); - serviceRegistry.addService(new ByteStreamImplBase() { - @Override - public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) { - return new StreamObserver<WriteRequest>() { - - private String digestHash; - private byte[] receivedData; - private long nextOffset; - - @Override - public void onNext(WriteRequest writeRequest) { - if (nextOffset == 0) { - String resourceName = writeRequest.getResourceName(); - assertThat(resourceName).isNotEmpty(); - - String[] components = resourceName.split("/"); - assertThat(components).hasLength(6); - digestHash = components[4]; - assertThat(blobsByHash).containsKey(digestHash); - receivedData = new byte[Integer.parseInt(components[5])]; - } - assertThat(digestHash).isNotNull(); - // An upload for a given blob has a 10% chance to fail once during its lifetime. - // This is to exercise the retry mechanism a bit. 
- boolean shouldFail = - rand.nextInt(10) == 0 && !uploadsFailedOnce.contains(digestHash); - if (shouldFail) { - uploadsFailedOnce.add(digestHash); - response.onError(Status.INTERNAL.asException()); - return; - } - - ByteString data = writeRequest.getData(); - System.arraycopy( - data.toByteArray(), 0, receivedData, (int) nextOffset, data.size()); - nextOffset += data.size(); - - boolean lastWrite = nextOffset == receivedData.length; - assertThat(writeRequest.getFinishWrite()).isEqualTo(lastWrite); - } - - @Override - public void onError(Throwable throwable) { - fail("onError should never be called."); - } - - @Override - public void onCompleted() { - byte[] expectedBlob = blobsByHash.get(digestHash); - assertThat(receivedData).isEqualTo(expectedBlob); - - WriteResponse writeResponse = - WriteResponse.newBuilder().setCommittedSize(receivedData.length).build(); - - response.onNext(writeResponse); - response.onCompleted(); - } - }; - } - }); - - uploader.uploadBlobs(builders); + uploader.uploadBlobs(builders, true); blockUntilInternalStateConsistent(uploader); @@ -302,7 +243,8 @@ RemoteRetrier retrier = new RemoteRetrier( () -> new FixedBackoff(5, 0), (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); - ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); List<String> toUpload = ImmutableList.of("aaaaaaaaaa", "bbbbbbbbbb", "cccccccccc"); List<Chunker> builders = new ArrayList<>(toUpload.size()); @@ -372,7 +314,7 @@ "build-req-id", "command-id", DIGEST_UTIL.asActionKey(chunker.digest())); ctx.call( () -> { - uploads.add(uploader.uploadBlobAsync(chunker)); + uploads.add(uploader.uploadBlobAsync(chunker, true)); return null; }); } @@ -393,7 +335,8 @@ Context prevContext = withEmptyMetadata.attach(); RemoteRetrier retrier = new RemoteRetrier(() -> mockBackoff, (e) -> true, retryService, 
Retrier.ALLOW_ALL_CALLS); - ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); byte[] blob = new byte[CHUNK_SIZE * 10]; Chunker chunker = new Chunker(blob, CHUNK_SIZE, DIGEST_UTIL); @@ -435,8 +378,8 @@ } }); - Future<?> upload1 = uploader.uploadBlobAsync(chunker); - Future<?> upload2 = uploader.uploadBlobAsync(chunker); + Future<?> upload1 = uploader.uploadBlobAsync(chunker, true); + Future<?> upload2 = uploader.uploadBlobAsync(chunker, true); blocker.countDown(); @@ -455,7 +398,8 @@ RemoteRetrier retrier = new RemoteRetrier( () -> new FixedBackoff(1, 10), (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); - ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); byte[] blob = new byte[CHUNK_SIZE]; Chunker chunker = new Chunker(blob, CHUNK_SIZE, DIGEST_UTIL); @@ -469,7 +413,7 @@ }); try { - uploader.uploadBlob(chunker); + uploader.uploadBlob(chunker, true); fail("Should have thrown an exception."); } catch (RetryException e) { assertThat(e.getAttempts()).isEqualTo(2); @@ -485,7 +429,8 @@ RemoteRetrier retrier = new RemoteRetrier( () -> new FixedBackoff(1, 10), (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); - ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); CountDownLatch cancellations = new CountDownLatch(2); @@ -520,8 +465,8 @@ byte[] blob2 = new byte[CHUNK_SIZE + 1]; Chunker chunker2 = new Chunker(blob2, CHUNK_SIZE, DIGEST_UTIL); - ListenableFuture<Void> f1 = uploader.uploadBlobAsync(chunker1); - ListenableFuture<Void> f2 = 
uploader.uploadBlobAsync(chunker2); + ListenableFuture<Void> f1 = uploader.uploadBlobAsync(chunker1, true); + ListenableFuture<Void> f2 = uploader.uploadBlobAsync(chunker2, true); assertThat(uploader.uploadsInProgress()).isTrue(); @@ -545,7 +490,8 @@ RemoteRetrier retrier = new RemoteRetrier( () -> new FixedBackoff(1, 10), (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); - ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, channel, null, 3, retrier); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); serviceRegistry.addService(new ByteStreamImplBase() { @Override @@ -564,7 +510,7 @@ byte[] blob = new byte[1]; Chunker chunker = new Chunker(blob, CHUNK_SIZE, DIGEST_UTIL); try { - uploader.uploadBlob(chunker); + uploader.uploadBlob(chunker, true); fail("Should have thrown an exception."); } catch (RetryException e) { assertThat(e).hasCauseThat().isInstanceOf(RejectedExecutionException.class); @@ -579,7 +525,8 @@ RemoteRetrier retrier = new RemoteRetrier(() -> mockBackoff, (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); ByteStreamUploader uploader = - new ByteStreamUploader(/* instanceName */ null, channel, null, 3, retrier); + new ByteStreamUploader(/* instanceName */ null, + new ReferenceCountedChannel(channel), null, 3, retrier); serviceRegistry.addService(new ByteStreamImplBase() { @Override @@ -608,7 +555,7 @@ byte[] blob = new byte[1]; Chunker chunker = new Chunker(blob, CHUNK_SIZE, DIGEST_UTIL); - uploader.uploadBlob(chunker); + uploader.uploadBlob(chunker, true); withEmptyMetadata.detach(prevContext); } @@ -623,7 +570,8 @@ retryService, Retrier.ALLOW_ALL_CALLS); ByteStreamUploader uploader = - new ByteStreamUploader(/* instanceName */ null, channel, null, 3, retrier); + new ByteStreamUploader(/* instanceName */ null, + new ReferenceCountedChannel(channel), null, 3, retrier); AtomicInteger numCalls = new AtomicInteger(); @@ -640,7 +588,7 @@ Chunker chunker = 
new Chunker(blob, CHUNK_SIZE, DIGEST_UTIL); try { - uploader.uploadBlob(chunker); + uploader.uploadBlob(chunker, true); fail("Should have thrown an exception."); } catch (RetryException e) { assertThat(numCalls.get()).isEqualTo(1); @@ -649,6 +597,67 @@ withEmptyMetadata.detach(prevContext); } + @Test + public void deduplicationOfUploadsShouldWork() throws Exception { + Context prevContext = withEmptyMetadata.attach(); + RemoteRetrier retrier = + new RemoteRetrier(() -> mockBackoff, (e) -> true, retryService, Retrier.ALLOW_ALL_CALLS); + ByteStreamUploader uploader = new ByteStreamUploader(INSTANCE_NAME, + new ReferenceCountedChannel(channel), null, 3, retrier); + + byte[] blob = new byte[CHUNK_SIZE * 2 + 1]; + new Random().nextBytes(blob); + + Chunker chunker = new Chunker(blob, CHUNK_SIZE, DIGEST_UTIL); + + AtomicInteger numUploads = new AtomicInteger(); + serviceRegistry.addService(new ByteStreamImplBase() { + @Override + public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> streamObserver) { + numUploads.incrementAndGet(); + return new StreamObserver<WriteRequest>() { + + long nextOffset = 0; + + @Override + public void onNext(WriteRequest writeRequest) { + nextOffset += writeRequest.getData().size(); + boolean lastWrite = blob.length == nextOffset; + assertThat(writeRequest.getFinishWrite()).isEqualTo(lastWrite); + } + + @Override + public void onError(Throwable throwable) { + fail("onError should never be called."); + } + + @Override + public void onCompleted() { + assertThat(nextOffset).isEqualTo(blob.length); + + WriteResponse response = + WriteResponse.newBuilder().setCommittedSize(nextOffset).build(); + streamObserver.onNext(response); + streamObserver.onCompleted(); + } + }; + } + }); + + uploader.uploadBlob(chunker, true); + // This should not trigger an upload. + uploader.uploadBlob(chunker, false); + + assertThat(numUploads.get()).isEqualTo(1); + + // This test should not have triggered any retries. 
+ Mockito.verifyZeroInteractions(mockBackoff); + + blockUntilInternalStateConsistent(uploader); + + withEmptyMetadata.detach(prevContext); + } + private static class NoopStreamObserver implements StreamObserver<WriteRequest> { @Override public void onNext(WriteRequest writeRequest) { @@ -663,7 +672,7 @@ } } - private static class FixedBackoff implements Retrier.Backoff { + static class FixedBackoff implements Retrier.Backoff { private final int maxRetries; private final int delayMillis; @@ -690,6 +699,80 @@ } } + /** + * An byte stream service where an upload for a given blob may or may not fail on the first + * attempt but is guaranteed to succeed on the second try. + */ + static class MaybeFailOnceUploadService extends ByteStreamImplBase { + + private final Map<String, byte[]> blobsByHash; + private final Set<String> uploadsFailedOnce = Collections.synchronizedSet(new HashSet<>()); + private final Random rand = new Random(); + + MaybeFailOnceUploadService(Map<String, byte[]> blobsByHash) { + this.blobsByHash = blobsByHash; + } + + @Override + public StreamObserver<WriteRequest> write(StreamObserver<WriteResponse> response) { + return new StreamObserver<WriteRequest>() { + + private String digestHash; + private byte[] receivedData; + private long nextOffset; + + @Override + public void onNext(WriteRequest writeRequest) { + if (nextOffset == 0) { + String resourceName = writeRequest.getResourceName(); + assertThat(resourceName).isNotEmpty(); + + String[] components = resourceName.split("/"); + assertThat(components).hasLength(6); + digestHash = components[4]; + assertThat(blobsByHash).containsKey(digestHash); + receivedData = new byte[Integer.parseInt(components[5])]; + } + assertThat(digestHash).isNotNull(); + // An upload for a given blob has a 10% chance to fail once during its lifetime. + // This is to exercise the retry mechanism a bit. 
+ boolean shouldFail = + rand.nextInt(10) == 0 && !uploadsFailedOnce.contains(digestHash); + if (shouldFail) { + uploadsFailedOnce.add(digestHash); + response.onError(Status.INTERNAL.asException()); + return; + } + + ByteString data = writeRequest.getData(); + System.arraycopy( + data.toByteArray(), 0, receivedData, (int) nextOffset, data.size()); + nextOffset += data.size(); + + boolean lastWrite = nextOffset == receivedData.length; + assertThat(writeRequest.getFinishWrite()).isEqualTo(lastWrite); + } + + @Override + public void onError(Throwable throwable) { + fail("onError should never be called."); + } + + @Override + public void onCompleted() { + byte[] expectedBlob = blobsByHash.get(digestHash); + assertThat(receivedData).isEqualTo(expectedBlob); + + WriteResponse writeResponse = + WriteResponse.newBuilder().setCommittedSize(receivedData.length).build(); + + response.onNext(writeResponse); + response.onCompleted(); + } + }; + } + } + private void blockUntilInternalStateConsistent(ByteStreamUploader uploader) throws Exception { // Poll until all upload futures have been removed from the internal hash map. The polling is // necessary, as listeners are executed after Future.get() calls are notified about completion.
diff --git a/src/test/java/com/google/devtools/build/lib/remote/CasPathConverterTest.java b/src/test/java/com/google/devtools/build/lib/remote/CasPathConverterTest.java deleted file mode 100644 index 5c6dcf2..0000000 --- a/src/test/java/com/google/devtools/build/lib/remote/CasPathConverterTest.java +++ /dev/null
@@ -1,79 +0,0 @@ -// Copyright 2017 The Bazel Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package com.google.devtools.build.lib.remote; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.devtools.build.lib.remote.RemoteModule.CasPathConverter; -import com.google.devtools.build.lib.remote.util.DigestUtil; -import com.google.devtools.build.lib.vfs.DigestHashFunction; -import com.google.devtools.build.lib.vfs.FileSystem; -import com.google.devtools.build.lib.vfs.FileSystemUtils; -import com.google.devtools.build.lib.vfs.Path; -import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem; -import com.google.devtools.common.options.Options; -import com.google.devtools.common.options.OptionsParser; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link CasPathConverter}. 
*/ -@RunWith(JUnit4.class) -public class CasPathConverterTest { - private final FileSystem fs = new InMemoryFileSystem(); - private final CasPathConverter converter = new CasPathConverter(); - - @Test - public void noOptionsShouldntCrash() { - converter.digestUtil = new DigestUtil(DigestHashFunction.SHA256); - assertThat(converter.apply(fs.getPath("/foo"))).isEqualTo("file:///foo"); - } - - @Test - public void noDigestUtilShouldntCrash() { - converter.options = Options.getDefaults(RemoteOptions.class); - assertThat(converter.apply(fs.getPath("/foo"))).isEqualTo("file:///foo"); - } - - @Test - public void disabledRemote() { - converter.options = Options.getDefaults(RemoteOptions.class); - converter.digestUtil = new DigestUtil(DigestHashFunction.SHA256); - assertThat(converter.apply(fs.getPath("/foo"))).isEqualTo("file:///foo"); - } - - @Test - public void enabledRemoteExecutorNoRemoteInstance() throws Exception { - OptionsParser parser = OptionsParser.newOptionsParser(RemoteOptions.class); - parser.parse("--remote_cache=machine"); - converter.options = parser.getOptions(RemoteOptions.class); - converter.digestUtil = new DigestUtil(DigestHashFunction.SHA256); - Path path = fs.getPath("/foo"); - FileSystemUtils.writeContentAsLatin1(path, "foobar"); - assertThat(converter.apply(fs.getPath("/foo"))) - .isEqualTo("bytestream://machine/blobs/3858f62230ac3c915f300c664312c63f/6"); - } - - @Test - public void enabledRemoteExecutorWithRemoteInstance() throws Exception { - OptionsParser parser = OptionsParser.newOptionsParser(RemoteOptions.class); - parser.parse("--remote_cache=machine", "--remote_instance_name=projects/bazel"); - converter.options = parser.getOptions(RemoteOptions.class); - converter.digestUtil = new DigestUtil(DigestHashFunction.SHA256); - Path path = fs.getPath("/foo"); - FileSystemUtils.writeContentAsLatin1(path, "foobar"); - assertThat(converter.apply(fs.getPath("/foo"))) - 
.isEqualTo("bytestream://machine/projects/bazel/blobs/3858f62230ac3c915f300c664312c63f/6"); - } -}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteCacheTest.java index 4fc5869..3c61fd1 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteCacheTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteCacheTest.java
@@ -64,7 +64,6 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; -import io.grpc.ClientInterceptors; import io.grpc.Context; import io.grpc.MethodDescriptor; import io.grpc.Server; @@ -177,9 +176,9 @@ Scratch scratch = new Scratch(); scratch.file(authTlsOptions.googleCredentials, new JacksonFactory().toString(json)); - CallCredentials creds = null; + CallCredentials creds; try (InputStream in = scratch.resolve(authTlsOptions.googleCredentials).getInputStream()) { - GoogleAuthUtils.newCallCredentials(in, authTlsOptions.googleAuthScopes); + creds = GoogleAuthUtils.newCallCredentials(in, authTlsOptions.googleAuthScopes); } RemoteOptions remoteOptions = Options.getDefaults(RemoteOptions.class); RemoteRetrier retrier = @@ -188,14 +187,21 @@ RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryService, Retrier.ALLOW_ALL_CALLS); - return new GrpcRemoteCache( - ClientInterceptors.intercept( - InProcessChannelBuilder.forName(fakeServerName).directExecutor().build(), - ImmutableList.of(new CallCredentialsInterceptor(creds))), + ReferenceCountedChannel channel = + new ReferenceCountedChannel(InProcessChannelBuilder.forName(fakeServerName).directExecutor() + .intercept(new CallCredentialsInterceptor(creds)).build()); + ByteStreamUploader uploader = + new ByteStreamUploader(remoteOptions.remoteInstanceName, channel.retain(), creds, + remoteOptions.remoteTimeout, retrier); + GrpcRemoteCache remoteCache = new GrpcRemoteCache(channel.retain(), creds, remoteOptions, retrier, - DIGEST_UTIL); + DIGEST_UTIL, + uploader); + // Drop the construction-time reference; the uploader and the cache hold their own. + channel.release(); + return remoteCache; } @Test
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java index 849259f..93d1cdd 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcRemoteExecutionClientTest.java
@@ -80,7 +80,6 @@ import com.google.watcher.v1.WatcherGrpc.WatcherImplBase; import io.grpc.BindableService; import io.grpc.CallCredentials; -import io.grpc.Channel; import io.grpc.Metadata; import io.grpc.Server; import io.grpc.ServerCall; @@ -258,13 +257,17 @@ RemoteRetrier.RETRIABLE_GRPC_ERRORS, retryService, Retrier.ALLOW_ALL_CALLS); - Channel channel = InProcessChannelBuilder.forName(fakeServerName).directExecutor().build(); + ReferenceCountedChannel channel = + new ReferenceCountedChannel(InProcessChannelBuilder.forName(fakeServerName).directExecutor().build()); GrpcRemoteExecutor executor = - new GrpcRemoteExecutor(channel, null, remoteOptions.remoteTimeout, retrier); + new GrpcRemoteExecutor(channel.retain(), null, remoteOptions.remoteTimeout, retrier); CallCredentials creds = GoogleAuthUtils.newCallCredentials(Options.getDefaults(AuthAndTLSOptions.class)); + ByteStreamUploader uploader = + new ByteStreamUploader(remoteOptions.remoteInstanceName, channel.retain(), creds, + remoteOptions.remoteTimeout, retrier); GrpcRemoteCache remoteCache = - new GrpcRemoteCache(channel, creds, remoteOptions, retrier, DIGEST_UTIL); + new GrpcRemoteCache(channel.retain(), creds, remoteOptions, retrier, DIGEST_UTIL, uploader); client = new RemoteSpawnRunner( execRoot, @@ -281,6 +284,7 @@ DIGEST_UTIL, logDir); inputDigest = fakeFileCache.createScratchInput(simpleSpawn.getInputFiles().get(0), "xyz"); + channel.release(); } @After