Fix bug in NestedSetStore where racing deserializations could create multiple futures for the same fingerprint and add a test showing that racing serializations can result in duplicate writes.
PiperOrigin-RevId: 200860099
diff --git a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
index 173420a..9aa3a99 100644
--- a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
+++ b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
@@ -22,6 +22,7 @@
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.SettableFuture;
import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
import com.google.devtools.build.lib.skyframe.serialization.SerializationConstants;
import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
@@ -68,11 +69,17 @@
/**
* Associates a fingerprint with the serialized representation of some NestedSet contents.
* Returns a future that completes when the write completes.
+ *
+ * <p>It is the responsibility of the caller to deduplicate {@code put} calls, to avoid multiple
+ * writes of the same fingerprint.
*/
ListenableFuture<Void> put(ByteString fingerprint, byte[] serializedBytes) throws IOException;
/**
* Retrieves the serialized bytes for the NestedSet contents associated with this fingerprint.
+ *
+ * <p>It is the responsibility of the caller to deduplicate {@code get} calls, to avoid multiple
+ * fetches of the same fingerprint.
*/
ListenableFuture<byte[]> get(ByteString fingerprint) throws IOException;
}
@@ -112,43 +119,72 @@
.build();
/**
- * Returns the NestedSet contents associated with the given fingerprint. Returns null if the
- * fingerprint is not known.
+ * Returns a {@link ListenableFuture} for NestedSet contents associated with the given
+ * fingerprint if there was already one. Otherwise associates {@code future} with {@code
+ * fingerprint} and returns null.
+ *
+ * <p>Since the associated future is used as the basis for equality comparisons for deserialized
+ * nested sets, it is critical that multiple calls with the same fingerprint don't override the
+ * association.
*/
+ @VisibleForTesting
@Nullable
- public ListenableFuture<Object[]> contentsForFingerprint(ByteString fingerprint) {
- return fingerprintToContents.getIfPresent(fingerprint);
+ ListenableFuture<Object[]> putIfAbsent(
+ ByteString fingerprint, ListenableFuture<Object[]> future) {
+ ListenableFuture<Object[]> result;
+ // Guava's Cache doesn't have a #putIfAbsent method, so we emulate it here.
+ try {
+ result = fingerprintToContents.get(fingerprint, () -> future);
+ } catch (ExecutionException e) {
+ throw new IllegalStateException(e);
+ }
+ if (result.equals(future)) {
+ // First request for this fingerprint: record its contents once the future completes.
+ putAsync(fingerprint, future);
+ return null;
+ }
+ return result;
}
/**
* Retrieves the fingerprint associated with the given NestedSet contents, or null if the given
* contents are not known.
*/
+ @VisibleForTesting
@Nullable
- public FingerprintComputationResult fingerprintForContents(Object[] contents) {
+ FingerprintComputationResult fingerprintForContents(Object[] contents) {
return contentsToFingerprint.getIfPresent(contents);
}
- /** Associates the provided fingerprint and NestedSet contents. */
- public void put(
- FingerprintComputationResult fingerprintComputationResult,
- ListenableFuture<Object[]> contents) {
- contents.addListener(
+ /**
+ * Associates the provided {@code fingerprint} and contents of the future, when it completes.
+ *
+ * <p>There may be a race between this call and calls to {@link #put}. Those races are benign,
+ * since the fingerprint should be the same regardless. We may pessimistically end up having a
+ * future to wait on for serialization that isn't actually necessary, but that isn't a big
+ * concern.
+ */
+ private void putAsync(ByteString fingerprint, ListenableFuture<Object[]> futureContents) {
+ futureContents.addListener(
() -> {
+ // There may already be an entry here, but it's better to put a fingerprint result with
+ // an immediate future, since then later readers won't need to block unnecessarily. It
+ // would be nice to sanity check the old value, but Cache#put doesn't provide it to us.
try {
- contentsToFingerprint.put(Futures.getDone(contents), fingerprintComputationResult);
+ contentsToFingerprint.put(
+ Futures.getDone(futureContents),
+ FingerprintComputationResult.create(fingerprint, Futures.immediateFuture(null)));
+
} catch (ExecutionException e) {
throw new AssertionError(
- "Expected write for "
- + fingerprintComputationResult.fingerprint()
- + " to be complete",
- e.getCause());
+ "Expected write for " + fingerprint + " to be complete", e.getCause());
}
},
MoreExecutors.directExecutor());
- fingerprintToContents.put(fingerprintComputationResult.fingerprint(), contents);
}
+ // TODO(janakr): Currently, racing threads can overwrite each other's
+ // fingerprintComputationResult, leading to confusion and potential performance drag. Fix this.
public void put(FingerprintComputationResult fingerprintComputationResult, Object[] contents) {
contentsToFingerprint.put(contents, fingerprintComputationResult);
fingerprintToContents.put(
@@ -157,9 +193,8 @@
}
/** The result of a fingerprint computation, including the status of its storage. */
- @VisibleForTesting
@AutoValue
- public abstract static class FingerprintComputationResult {
+ abstract static class FingerprintComputationResult {
static FingerprintComputationResult create(
ByteString fingerprint, ListenableFuture<Void> writeStatus) {
return new AutoValue_NestedSetStore_FingerprintComputationResult(fingerprint, writeStatus);
@@ -168,7 +203,7 @@
abstract ByteString fingerprint();
@VisibleForTesting
- public abstract ListenableFuture<Void> writeStatus();
+ abstract ListenableFuture<Void> writeStatus();
}
private final NestedSetCache nestedSetCache;
@@ -176,6 +211,7 @@
private final Executor executor;
/** Creates a NestedSetStore with the provided {@link NestedSetStorageEndpoint} as a backend. */
+ @VisibleForTesting
public NestedSetStore(NestedSetStorageEndpoint nestedSetStorageEndpoint) {
this(nestedSetStorageEndpoint, new NestedSetCache(), MoreExecutors.directExecutor());
}
@@ -189,7 +225,7 @@
}
@VisibleForTesting
- public NestedSetStore(
+ NestedSetStore(
NestedSetStorageEndpoint nestedSetStorageEndpoint,
NestedSetCache nestedSetCache,
Executor executor) {
@@ -208,9 +244,16 @@
* SerializationContext}, while also associating the contents with the computed fingerprint in the
* store. Recursively does the same for all transitive members (i.e. Object[] members) of the
* provided contents.
+ *
+ * <p>We wish to serialize each nested set only once. However, this is not currently enforced, due
+ * to the check-then-act race below, where we check nestedSetCache and then, significantly later,
+ * insert a result into the cache. This is a bug, but since any thread that redoes unnecessary
+ * work will return the {@link FingerprintComputationResult} containing its own futures, the
+ * serialization work that must wait on remote storage writes to complete will wait on the correct
+ * futures. Thus it is a performance bug, not a correctness bug.
*/
- @VisibleForTesting
- public FingerprintComputationResult computeFingerprintAndStore(
+ // TODO(janakr): fix this, if for no other reason than to make the semantics cleaner.
+ FingerprintComputationResult computeFingerprintAndStore(
Object[] contents, SerializationContext serializationContext)
throws SerializationException, IOException {
FingerprintComputationResult priorFingerprint = nestedSetCache.fingerprintForContents(contents);
@@ -270,15 +313,23 @@
return fingerprintComputationResult;
}
- /** Retrieves and deserializes the NestedSet contents associated with the given fingerprint. */
- public ListenableFuture<Object[]> getContentsAndDeserialize(
+ /**
+ * Retrieves and deserializes the NestedSet contents associated with the given fingerprint.
+ *
+ * <p>We wish to only do one deserialization per fingerprint. This is enforced by the {@link
+ * #nestedSetCache}, which is responsible for returning the canonical future that will contain the
+ * results of the deserialization. If that future is not owned by the current call of this method,
+ * it doesn't have to do anything further.
+ */
+ ListenableFuture<Object[]> getContentsAndDeserialize(
ByteString fingerprint, DeserializationContext deserializationContext) throws IOException {
- ListenableFuture<Object[]> contents = nestedSetCache.contentsForFingerprint(fingerprint);
+ SettableFuture<Object[]> future = SettableFuture.create();
+ ListenableFuture<Object[]> contents = nestedSetCache.putIfAbsent(fingerprint, future);
if (contents != null) {
return contents;
}
ListenableFuture<byte[]> retrieved = nestedSetStorageEndpoint.get(fingerprint);
- ListenableFuture<Object[]> result =
+ future.setFuture(
Futures.transformAsync(
retrieved,
bytes -> {
@@ -314,11 +365,7 @@
},
executor);
},
- executor);
-
- FingerprintComputationResult fingerprintComputationResult =
- FingerprintComputationResult.create(fingerprint, Futures.immediateFuture(null));
- nestedSetCache.put(fingerprintComputationResult, result);
- return result;
+ executor));
+ return future;
}
}
diff --git a/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java b/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
index 9341084..2ff9206 100644
--- a/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
+++ b/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
@@ -14,6 +14,7 @@
package com.google.devtools.build.lib.collect.nestedset;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.times;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
@@ -24,11 +25,17 @@
import com.google.devtools.build.lib.collect.nestedset.NestedSetStore.NestedSetCache;
import com.google.devtools.build.lib.collect.nestedset.NestedSetStore.NestedSetStorageEndpoint;
import com.google.devtools.build.lib.skyframe.serialization.AutoRegistry;
+import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecs;
import com.google.devtools.build.lib.skyframe.serialization.SerializationConstants;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
import com.google.devtools.build.lib.skyframe.serialization.SerializationResult;
import com.google.protobuf.ByteString;
+import java.io.IOException;
import java.nio.charset.Charset;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -222,6 +229,8 @@
Mockito.doReturn(subset2Future)
.when(nestedSetStorageEndpoint)
.get(fingerprintCaptor.getAllValues().get(1));
+ Mockito.when(emptyNestedSetCache.putIfAbsent(Mockito.any(), Mockito.any()))
+ .thenAnswer(invocation -> null);
ListenableFuture<Object[]> deserializationFuture =
nestedSetStore.getContentsAndDeserialize(
@@ -236,4 +245,96 @@
subset2Future.set(ByteString.copyFrom("mock bytes", Charset.defaultCharset()).toByteArray());
assertThat(deserializationFuture.isDone()).isTrue();
}
+
+ @Test
+ public void racingDeserialization() throws Exception {
+ NestedSetStorageEndpoint nestedSetStorageEndpoint =
+ Mockito.mock(NestedSetStorageEndpoint.class);
+ NestedSetCache nestedSetCache = Mockito.spy(new NestedSetCache());
+ NestedSetStore nestedSetStore =
+ new NestedSetStore(
+ nestedSetStorageEndpoint, nestedSetCache, MoreExecutors.directExecutor());
+ DeserializationContext deserializationContext = Mockito.mock(DeserializationContext.class);
+ ByteString fingerprint = ByteString.copyFromUtf8("fingerprint");
+ // Future never completes, so we don't have to exercise that code in NestedSetStore.
+ SettableFuture<byte[]> storageFuture = SettableFuture.create();
+ Mockito.when(nestedSetStorageEndpoint.get(fingerprint)).thenReturn(storageFuture);
+ CountDownLatch fingerprintRequested = new CountDownLatch(2);
+ Mockito.doAnswer(
+ invocation -> {
+ fingerprintRequested.countDown();
+ @SuppressWarnings("unchecked")
+ ListenableFuture<Object[]> result =
+ (ListenableFuture<Object[]>) invocation.callRealMethod();
+ fingerprintRequested.await();
+ return result;
+ })
+ .when(nestedSetCache)
+ .putIfAbsent(Mockito.eq(fingerprint), Mockito.any());
+ AtomicReference<ListenableFuture<Object[]>> asyncResult = new AtomicReference<>();
+ Thread asyncThread =
+ new Thread(
+ () -> {
+ try {
+ asyncResult.set(
+ nestedSetStore.getContentsAndDeserialize(fingerprint, deserializationContext));
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ });
+ asyncThread.start();
+ ListenableFuture<Object[]> result =
+ nestedSetStore.getContentsAndDeserialize(fingerprint, deserializationContext);
+ asyncThread.join();
+ Mockito.verify(nestedSetStorageEndpoint, times(1)).get(Mockito.eq(fingerprint));
+ assertThat(result).isSameAs(asyncResult.get());
+ assertThat(result.isDone()).isFalse();
+ }
+
+ @Test
+ public void bugInRacingSerialization() throws Exception {
+ NestedSetStorageEndpoint nestedSetStorageEndpoint =
+ Mockito.mock(NestedSetStorageEndpoint.class);
+ NestedSetCache nestedSetCache = Mockito.spy(new NestedSetCache());
+ NestedSetStore nestedSetStore =
+ new NestedSetStore(
+ nestedSetStorageEndpoint, nestedSetCache, MoreExecutors.directExecutor());
+ SerializationContext serializationContext = Mockito.mock(SerializationContext.class);
+ Object[] contents = {new Object()};
+ Mockito.when(serializationContext.getNewMemoizingContext()).thenReturn(serializationContext);
+ Mockito.when(nestedSetStorageEndpoint.put(Mockito.any(), Mockito.any()))
+ .thenAnswer(invocation -> SettableFuture.create());
+ CountDownLatch fingerprintRequested = new CountDownLatch(2);
+ Mockito.doAnswer(
+ invocation -> {
+ fingerprintRequested.countDown();
+ NestedSetStore.FingerprintComputationResult result =
+ (NestedSetStore.FingerprintComputationResult) invocation.callRealMethod();
+ assertThat(result).isNull();
+ fingerprintRequested.await();
+ return null;
+ })
+ .when(nestedSetCache)
+ .fingerprintForContents(contents);
+ AtomicReference<NestedSetStore.FingerprintComputationResult> asyncResult =
+ new AtomicReference<>();
+ Thread asyncThread =
+ new Thread(
+ () -> {
+ try {
+ asyncResult.set(
+ nestedSetStore.computeFingerprintAndStore(contents, serializationContext));
+ } catch (IOException | SerializationException e) {
+ throw new IllegalStateException(e);
+ }
+ });
+ asyncThread.start();
+ NestedSetStore.FingerprintComputationResult result =
+ nestedSetStore.computeFingerprintAndStore(contents, serializationContext);
+ asyncThread.join();
+ // TODO(janakr): This should be one write, but we currently do two.
+ Mockito.verify(nestedSetStorageEndpoint, times(2)).put(Mockito.any(), Mockito.any());
+ // TODO(janakr): These should be the same result object.
+ assertThat(result).isNotEqualTo(asyncResult.get());
+ }
}