Fix bug in NestedSetStore where racing deserializations could create multiple futures for the same fingerprint. Also add a test showing that racing serializations can result in duplicate writes.

PiperOrigin-RevId: 200860099
diff --git a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
index 173420a..9aa3a99 100644
--- a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
+++ b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
@@ -22,6 +22,7 @@
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.SettableFuture;
 import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationConstants;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
@@ -68,11 +69,17 @@
     /**
      * Associates a fingerprint with the serialized representation of some NestedSet contents.
      * Returns a future that completes when the write completes.
+     *
+     * <p>It is the responsibility of the caller to deduplicate {@code put} calls, to avoid multiple
+     * writes of the same fingerprint.
      */
     ListenableFuture<Void> put(ByteString fingerprint, byte[] serializedBytes) throws IOException;
 
     /**
      * Retrieves the serialized bytes for the NestedSet contents associated with this fingerprint.
+     *
+     * <p>It is the responsibility of the caller to deduplicate {@code get} calls, to avoid multiple
+     * fetches of the same fingerprint.
      */
     ListenableFuture<byte[]> get(ByteString fingerprint) throws IOException;
   }
@@ -112,43 +119,72 @@
             .build();
 
     /**
-     * Returns the NestedSet contents associated with the given fingerprint. Returns null if the
-     * fingerprint is not known.
+     * Returns a {@link ListenableFuture} for NestedSet contents associated with the given
+     * fingerprint if there was already one. Otherwise associates {@code future} with {@code
+     * fingerprint} and returns null.
+     *
+     * <p>Since the associated future is used as the basis for equality comparisons for deserialized
+     * nested sets, it is critical that multiple calls with the same fingerprint don't override the
+     * association.
      */
+    @VisibleForTesting
     @Nullable
-    public ListenableFuture<Object[]> contentsForFingerprint(ByteString fingerprint) {
-      return fingerprintToContents.getIfPresent(fingerprint);
+    ListenableFuture<Object[]> putIfAbsent(
+        ByteString fingerprint, ListenableFuture<Object[]> future) {
+      ListenableFuture<Object[]> result;
+      // Guava's Cache doesn't have a #putIfAbsent method, so we emulate it here.
+      try {
+        result = fingerprintToContents.get(fingerprint, () -> future);
+      } catch (ExecutionException e) {
+        throw new IllegalStateException(e);
+      }
+      if (result.equals(future)) {
+        // This is the first request for this fingerprint, so also record the reverse mapping.
+        putAsync(fingerprint, future);
+        return null;
+      }
+      return result;
     }
 
     /**
      * Retrieves the fingerprint associated with the given NestedSet contents, or null if the given
      * contents are not known.
      */
+    @VisibleForTesting
     @Nullable
-    public FingerprintComputationResult fingerprintForContents(Object[] contents) {
+    FingerprintComputationResult fingerprintForContents(Object[] contents) {
       return contentsToFingerprint.getIfPresent(contents);
     }
 
-    /** Associates the provided fingerprint and NestedSet contents. */
-    public void put(
-        FingerprintComputationResult fingerprintComputationResult,
-        ListenableFuture<Object[]> contents) {
-      contents.addListener(
+    /**
+     * Associates the provided {@code fingerprint} and contents of the future, when it completes.
+     *
+     * <p>There may be a race between this call and calls to {@link #put}. Those races are benign,
+     * since the fingerprint should be the same regardless. We may pessimistically end up having a
+     * future to wait on for serialization that isn't actually necessary, but that isn't a big
+     * concern.
+     */
+    private void putAsync(ByteString fingerprint, ListenableFuture<Object[]> futureContents) {
+      futureContents.addListener(
           () -> {
+            // There may already be an entry here, but it's better to put a fingerprint result with
+            // an immediate future, since then later readers won't need to block unnecessarily. It
+            // would be nice to sanity check the old value, but Cache#put doesn't provide it to us.
             try {
-              contentsToFingerprint.put(Futures.getDone(contents), fingerprintComputationResult);
+              contentsToFingerprint.put(
+                  Futures.getDone(futureContents),
+                  FingerprintComputationResult.create(fingerprint, Futures.immediateFuture(null)));
+
             } catch (ExecutionException e) {
               throw new AssertionError(
-                  "Expected write for "
-                      + fingerprintComputationResult.fingerprint()
-                      + " to be complete",
-                  e.getCause());
+                  "Expected write for " + fingerprint + " to be complete", e.getCause());
             }
           },
           MoreExecutors.directExecutor());
-      fingerprintToContents.put(fingerprintComputationResult.fingerprint(), contents);
     }
 
+    // TODO(janakr): Currently, racing threads can overwrite each other's
+    // fingerprintComputationResult, leading to confusion and potential performance drag. Fix this.
     public void put(FingerprintComputationResult fingerprintComputationResult, Object[] contents) {
       contentsToFingerprint.put(contents, fingerprintComputationResult);
       fingerprintToContents.put(
@@ -157,9 +193,8 @@
   }
 
   /** The result of a fingerprint computation, including the status of its storage. */
-  @VisibleForTesting
   @AutoValue
-  public abstract static class FingerprintComputationResult {
+  abstract static class FingerprintComputationResult {
     static FingerprintComputationResult create(
         ByteString fingerprint, ListenableFuture<Void> writeStatus) {
       return new AutoValue_NestedSetStore_FingerprintComputationResult(fingerprint, writeStatus);
@@ -168,7 +203,7 @@
     abstract ByteString fingerprint();
 
     @VisibleForTesting
-    public abstract ListenableFuture<Void> writeStatus();
+    abstract ListenableFuture<Void> writeStatus();
   }
 
   private final NestedSetCache nestedSetCache;
@@ -176,6 +211,7 @@
   private final Executor executor;
 
   /** Creates a NestedSetStore with the provided {@link NestedSetStorageEndpoint} as a backend. */
+  @VisibleForTesting
   public NestedSetStore(NestedSetStorageEndpoint nestedSetStorageEndpoint) {
     this(nestedSetStorageEndpoint, new NestedSetCache(), MoreExecutors.directExecutor());
   }
@@ -189,7 +225,7 @@
   }
 
   @VisibleForTesting
-  public NestedSetStore(
+  NestedSetStore(
       NestedSetStorageEndpoint nestedSetStorageEndpoint,
       NestedSetCache nestedSetCache,
       Executor executor) {
@@ -208,9 +244,16 @@
    * SerializationContext}, while also associating the contents with the computed fingerprint in the
    * store. Recursively does the same for all transitive members (i.e. Object[] members) of the
    * provided contents.
+   *
+   * <p>We wish to serialize each nested set only once. However, this is not currently enforced, due
+   * to the check-then-act race below, where we check nestedSetCache and then, significantly later,
+   * insert a result into the cache. This is a bug, but since any thread that redoes unnecessary
+   * work will return the {@link FingerprintComputationResult} containing its own futures, the
+   * serialization work that must wait on remote storage writes to complete will wait on the correct
+   * futures. Thus it is a performance bug, not a correctness bug.
    */
-  @VisibleForTesting
-  public FingerprintComputationResult computeFingerprintAndStore(
+  // TODO(janakr): fix this, if for no other reason than to make the semantics cleaner.
+  FingerprintComputationResult computeFingerprintAndStore(
       Object[] contents, SerializationContext serializationContext)
       throws SerializationException, IOException {
     FingerprintComputationResult priorFingerprint = nestedSetCache.fingerprintForContents(contents);
@@ -270,15 +313,23 @@
     return fingerprintComputationResult;
   }
 
-  /** Retrieves and deserializes the NestedSet contents associated with the given fingerprint. */
-  public ListenableFuture<Object[]> getContentsAndDeserialize(
+  /**
+   * Retrieves and deserializes the NestedSet contents associated with the given fingerprint.
+   *
+   * <p>We wish to only do one deserialization per fingerprint. This is enforced by the {@link
+   * #nestedSetCache}, which is responsible for returning the canonical future that will contain the
+   * results of the deserialization. If that future is not owned by the current call of this method,
+   * it doesn't have to do anything further.
+   */
+  ListenableFuture<Object[]> getContentsAndDeserialize(
       ByteString fingerprint, DeserializationContext deserializationContext) throws IOException {
-    ListenableFuture<Object[]> contents = nestedSetCache.contentsForFingerprint(fingerprint);
+    SettableFuture<Object[]> future = SettableFuture.create();
+    ListenableFuture<Object[]> contents = nestedSetCache.putIfAbsent(fingerprint, future);
     if (contents != null) {
       return contents;
     }
     ListenableFuture<byte[]> retrieved = nestedSetStorageEndpoint.get(fingerprint);
-    ListenableFuture<Object[]> result =
+    future.setFuture(
         Futures.transformAsync(
             retrieved,
             bytes -> {
@@ -314,11 +365,7 @@
                       },
                       executor);
             },
-            executor);
-
-    FingerprintComputationResult fingerprintComputationResult =
-        FingerprintComputationResult.create(fingerprint, Futures.immediateFuture(null));
-    nestedSetCache.put(fingerprintComputationResult, result);
-    return result;
+            executor));
+    return future;
   }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java b/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
index 9341084..2ff9206 100644
--- a/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
+++ b/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
@@ -14,6 +14,7 @@
 package com.google.devtools.build.lib.collect.nestedset;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.times;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.Futures;
@@ -24,11 +25,17 @@
 import com.google.devtools.build.lib.collect.nestedset.NestedSetStore.NestedSetCache;
 import com.google.devtools.build.lib.collect.nestedset.NestedSetStore.NestedSetStorageEndpoint;
 import com.google.devtools.build.lib.skyframe.serialization.AutoRegistry;
+import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecs;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationConstants;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationResult;
 import com.google.protobuf.ByteString;
+import java.io.IOException;
 import java.nio.charset.Charset;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -222,6 +229,8 @@
     Mockito.doReturn(subset2Future)
         .when(nestedSetStorageEndpoint)
         .get(fingerprintCaptor.getAllValues().get(1));
+    Mockito.when(emptyNestedSetCache.putIfAbsent(Mockito.any(), Mockito.any()))
+        .thenAnswer(invocation -> null);
 
     ListenableFuture<Object[]> deserializationFuture =
         nestedSetStore.getContentsAndDeserialize(
@@ -236,4 +245,96 @@
     subset2Future.set(ByteString.copyFrom("mock bytes", Charset.defaultCharset()).toByteArray());
     assertThat(deserializationFuture.isDone()).isTrue();
   }
+
+  @Test
+  public void racingDeserialization() throws Exception {
+    NestedSetStorageEndpoint nestedSetStorageEndpoint =
+        Mockito.mock(NestedSetStorageEndpoint.class);
+    NestedSetCache nestedSetCache = Mockito.spy(new NestedSetCache());
+    NestedSetStore nestedSetStore =
+        new NestedSetStore(
+            nestedSetStorageEndpoint, nestedSetCache, MoreExecutors.directExecutor());
+    DeserializationContext deserializationContext = Mockito.mock(DeserializationContext.class);
+    ByteString fingerprint = ByteString.copyFromUtf8("fingerprint");
+    // Future never completes, so we don't have to exercise that code in NestedSetStore.
+    SettableFuture<byte[]> storageFuture = SettableFuture.create();
+    Mockito.when(nestedSetStorageEndpoint.get(fingerprint)).thenReturn(storageFuture);
+    CountDownLatch fingerprintRequested = new CountDownLatch(2);
+    Mockito.doAnswer(
+            invocation -> {
+              fingerprintRequested.countDown();
+              @SuppressWarnings("unchecked")
+              ListenableFuture<Object[]> result =
+                  (ListenableFuture<Object[]>) invocation.callRealMethod();
+              fingerprintRequested.await();
+              return result;
+            })
+        .when(nestedSetCache)
+        .putIfAbsent(Mockito.eq(fingerprint), Mockito.any());
+    AtomicReference<ListenableFuture<Object[]>> asyncResult = new AtomicReference<>();
+    Thread asyncThread =
+        new Thread(
+            () -> {
+              try {
+                asyncResult.set(
+                    nestedSetStore.getContentsAndDeserialize(fingerprint, deserializationContext));
+              } catch (IOException e) {
+                throw new IllegalStateException(e);
+              }
+            });
+    asyncThread.start();
+    ListenableFuture<Object[]> result =
+        nestedSetStore.getContentsAndDeserialize(fingerprint, deserializationContext);
+    asyncThread.join();
+    Mockito.verify(nestedSetStorageEndpoint, times(1)).get(Mockito.eq(fingerprint));
+    assertThat(result).isSameAs(asyncResult.get());
+    assertThat(result.isDone()).isFalse();
+  }
+
+  @Test
+  public void bugInRacingSerialization() throws Exception {
+    NestedSetStorageEndpoint nestedSetStorageEndpoint =
+        Mockito.mock(NestedSetStorageEndpoint.class);
+    NestedSetCache nestedSetCache = Mockito.spy(new NestedSetCache());
+    NestedSetStore nestedSetStore =
+        new NestedSetStore(
+            nestedSetStorageEndpoint, nestedSetCache, MoreExecutors.directExecutor());
+    SerializationContext serializationContext = Mockito.mock(SerializationContext.class);
+    Object[] contents = {new Object()};
+    Mockito.when(serializationContext.getNewMemoizingContext()).thenReturn(serializationContext);
+    Mockito.when(nestedSetStorageEndpoint.put(Mockito.any(), Mockito.any()))
+        .thenAnswer(invocation -> SettableFuture.create());
+    CountDownLatch fingerprintRequested = new CountDownLatch(2);
+    Mockito.doAnswer(
+            invocation -> {
+              fingerprintRequested.countDown();
+              NestedSetStore.FingerprintComputationResult result =
+                  (NestedSetStore.FingerprintComputationResult) invocation.callRealMethod();
+              assertThat(result).isNull();
+              fingerprintRequested.await();
+              return null;
+            })
+        .when(nestedSetCache)
+        .fingerprintForContents(contents);
+    AtomicReference<NestedSetStore.FingerprintComputationResult> asyncResult =
+        new AtomicReference<>();
+    Thread asyncThread =
+        new Thread(
+            () -> {
+              try {
+                asyncResult.set(
+                    nestedSetStore.computeFingerprintAndStore(contents, serializationContext));
+              } catch (IOException | SerializationException e) {
+                throw new IllegalStateException(e);
+              }
+            });
+    asyncThread.start();
+    NestedSetStore.FingerprintComputationResult result =
+        nestedSetStore.computeFingerprintAndStore(contents, serializationContext);
+    asyncThread.join();
+    // TODO(janakr): This should be one write, but we currently do two.
+    Mockito.verify(nestedSetStorageEndpoint, times(2)).put(Mockito.any(), Mockito.any());
+    // TODO(janakr): These should be the same element.
+    assertThat(result).isNotEqualTo(asyncResult.get());
+  }
 }