remote: Proactively close the ZstdInputStream in ZstdDecompressingOutputStream.
ZstdInputStream hangs onto some native memory, which should be released as soon as ZstdDecompressingOutputStream is done being used rather than when the finalizer runs.
Closes #15061.
PiperOrigin-RevId: 438521302
diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD
index 7f8d359..06fb72d 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD
@@ -97,7 +97,6 @@
"//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
"//src/main/java/com/google/devtools/common/options",
"//src/main/protobuf:failure_details_java_proto",
- "//third_party:apache_commons_compress",
"//third_party:auth",
"//third_party:caffeine",
"//third_party:flogger",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index 293562d..5dd7dc0 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -37,6 +37,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.flogger.GoogleLogger;
+import com.google.common.io.CountingOutputStream;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
@@ -67,10 +68,8 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import javax.annotation.Nullable;
-import org.apache.commons.compress.utils.CountingOutputStream;
/** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
@ThreadSafe
@@ -303,7 +302,7 @@
public ListenableFuture<Void> downloadBlob(
RemoteActionExecutionContext context, Digest digest, OutputStream out) {
if (digest.getSizeBytes() == 0) {
- return Futures.immediateFuture(null);
+ return Futures.immediateVoidFuture();
}
@Nullable Supplier<Digest> digestSupplier = null;
@@ -313,18 +312,7 @@
out = digestOut;
}
- CountingOutputStream outputStream;
- if (options.cacheCompression) {
- try {
- outputStream = new ZstdDecompressingOutputStream(out);
- } catch (IOException e) {
- return Futures.immediateFailedFuture(e);
- }
- } else {
- outputStream = new CountingOutputStream(out);
- }
-
- return downloadBlob(context, digest, outputStream, digestSupplier);
+ return downloadBlob(context, digest, new CountingOutputStream(out), digestSupplier);
}
private ListenableFuture<Void> downloadBlob(
@@ -332,7 +320,6 @@
Digest digest,
CountingOutputStream out,
@Nullable Supplier<Digest> digestSupplier) {
- AtomicLong offset = new AtomicLong(0);
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
ListenableFuture<Long> downloadFuture =
Utils.refreshIfUnauthenticatedAsync(
@@ -343,7 +330,6 @@
channel ->
requestRead(
context,
- offset,
progressiveBackoff,
digest,
out,
@@ -370,20 +356,25 @@
private ListenableFuture<Long> requestRead(
RemoteActionExecutionContext context,
- AtomicLong offset,
ProgressiveBackoff progressiveBackoff,
Digest digest,
- CountingOutputStream out,
+ CountingOutputStream rawOut,
@Nullable Supplier<Digest> digestSupplier,
Channel channel) {
String resourceName =
getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
SettableFuture<Long> future = SettableFuture.create();
+ OutputStream out;
+ try {
+ out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut;
+ } catch (IOException e) {
+ return Futures.immediateFailedFuture(e);
+ }
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
- .setReadOffset(offset.get())
+ .setReadOffset(rawOut.getCount())
.build(),
new StreamObserver<ReadResponse>() {
@@ -392,7 +383,6 @@
ByteString data = readResponse.getData();
try {
data.writeTo(out);
- offset.set(out.getBytesWritten());
} catch (IOException e) {
// Cancel the call.
throw new RuntimeException(e);
@@ -403,7 +393,7 @@
@Override
public void onError(Throwable t) {
- if (offset.get() == digest.getSizeBytes()) {
+ if (rawOut.getCount() == digest.getSizeBytes()) {
// If the file was fully downloaded, it doesn't matter if there was an error at
// the end of the stream.
logger.atInfo().withCause(t).log(
@@ -411,6 +401,7 @@
onCompleted();
return;
}
+ releaseOut();
Status status = Status.fromThrowable(t);
if (status.getCode() == Status.Code.NOT_FOUND) {
future.setException(new CacheNotFoundException(digest));
@@ -426,12 +417,24 @@
Utils.verifyBlobContents(digest, digestSupplier.get());
}
out.flush();
- future.set(offset.get());
+ future.set(rawOut.getCount());
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
+ } finally {
+ releaseOut();
+ }
+ }
+
+ private void releaseOut() {
+ if (out instanceof ZstdDecompressingOutputStream) {
+ try {
+ ((ZstdDecompressingOutputStream) out).closeShallow();
+ } catch (IOException e) {
+ logger.atWarning().withCause(e).log("failed to cleanly close output stream");
+ }
}
}
});
diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD
index 6108cdd..75691a6 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD
@@ -16,7 +16,6 @@
name = "zstd",
srcs = glob(["*.java"]),
deps = [
- "//third_party:apache_commons_compress",
"//third_party:guava",
"//third_party/protobuf:protobuf_java",
"@zstd-jni",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java
index ad1c333..9fdb6ae 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java
@@ -13,35 +13,35 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.zstd;
-import com.github.luben.zstd.ZstdInputStream;
+import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
import com.google.protobuf.ByteString;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
-import org.apache.commons.compress.utils.CountingOutputStream;
-/** A {@link CountingOutputStream} that use zstd to decompress the content. */
-public class ZstdDecompressingOutputStream extends CountingOutputStream {
+/** An {@link OutputStream} that use zstd to decompress the content. */
+public final class ZstdDecompressingOutputStream extends OutputStream {
+ private final OutputStream out;
private ByteArrayInputStream inner;
- private final ZstdInputStream zis;
+ private final ZstdInputStreamNoFinalizer zis;
public ZstdDecompressingOutputStream(OutputStream out) throws IOException {
- super(out);
+ this.out = out;
zis =
- new ZstdInputStream(
- new InputStream() {
- @Override
- public int read() {
- return inner.read();
- }
+ new ZstdInputStreamNoFinalizer(
+ new InputStream() {
+ @Override
+ public int read() {
+ return inner.read();
+ }
- @Override
- public int read(byte[] b, int off, int len) {
- return inner.read(b, off, len);
- }
- });
- zis.setContinuous(true);
+ @Override
+ public int read(byte[] b, int off, int len) {
+ return inner.read(b, off, len);
+ }
+ })
+ .setContinuous(true);
}
@Override
@@ -58,6 +58,19 @@
public void write(byte[] b, int off, int len) throws IOException {
inner = new ByteArrayInputStream(b, off, len);
byte[] data = ByteString.readFrom(zis).toByteArray();
- super.write(data, 0, data.length);
+ out.write(data, 0, data.length);
+ }
+
+ @Override
+ public void close() throws IOException {
+ closeShallow();
+ out.close();
+ }
+
+ /**
+ * Free resources related to decompression without closing the underlying {@link OutputStream}.
+ */
+ public void closeShallow() throws IOException {
+ zis.close();
}
}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java
index 51effa0..80d55ed 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java
@@ -15,14 +15,12 @@
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
-import static org.mockito.ArgumentMatchers.any;
import build.bazel.remote.execution.v2.Digest;
import com.github.luben.zstd.Zstd;
import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
import com.google.bytestream.ByteStreamProto.ReadRequest;
import com.google.bytestream.ByteStreamProto.ReadResponse;
-import com.google.devtools.build.lib.remote.Retrier.Backoff;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.common.options.Options;
import com.google.protobuf.ByteString;
@@ -31,7 +29,6 @@
import java.io.IOException;
import java.util.Arrays;
import org.junit.Test;
-import org.mockito.Mockito;
/** Extra tests for {@link GrpcCacheClient} that are not tested internally. */
public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {
@@ -39,30 +36,43 @@
@Test
public void compressedDownloadBlobIsRetriedWithProgress()
throws IOException, InterruptedException {
- Backoff mockBackoff = Mockito.mock(Backoff.class);
RemoteOptions options = Options.getDefaults(RemoteOptions.class);
options.cacheCompression = true;
- final GrpcCacheClient client = newClient(options, () -> mockBackoff);
+ final GrpcCacheClient client = newClient(options);
final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
- ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8)));
+ ByteString chunk1 = ByteString.copyFrom(Zstd.compress("abc".getBytes(UTF_8)));
+ ByteString chunk2 = ByteString.copyFrom(Zstd.compress("def".getBytes(UTF_8)));
+ ByteString chunk3 = ByteString.copyFrom(Zstd.compress("g".getBytes(UTF_8)));
serviceRegistry.addService(
new ByteStreamImplBase() {
+ private boolean first = true;
+
@Override
public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
assertThat(request.getResourceName().contains(digest.getHash())).isTrue();
- int off = (int) request.getReadOffset();
- // Zstd header size is 9 bytes
- ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off);
- responseObserver.onNext(ReadResponse.newBuilder().setData(data).build());
- if (off == 0) {
+ if (first) {
+ first = false;
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
- } else {
- responseObserver.onCompleted();
+ return;
}
+ switch (Math.toIntExact(request.getReadOffset())) {
+ case 0:
+ responseObserver.onNext(ReadResponse.newBuilder().setData(chunk1).build());
+ break;
+ case 3:
+ responseObserver.onNext(ReadResponse.newBuilder().setData(chunk2).build());
+ break;
+ case 6:
+ responseObserver.onNext(ReadResponse.newBuilder().setData(chunk3).build());
+ responseObserver.onCompleted();
+ return;
+ default:
+ throw new IllegalStateException("unexpected offset " + request.getReadOffset());
+ }
+ responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
}
});
assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg");
- Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
}
@Test
diff --git a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
index 22cba85..62352dd 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
@@ -63,7 +63,6 @@
for (byte b : compressed.toByteArray()) {
zdos.write(b);
zdos.flush();
- assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length);
}
assertThat(decompressed.toByteArray()).isEqualTo(data);
}