Wire a caching context through from `NestedSetStore` to handle potentially ambiguous serialization.

`NestedSetSerializationCache` was recently updated to support the case of different contents having the same serialized representation. Use that feature from `NestedSetStore`, accepting a function to go from serialization context to cache context.

PiperOrigin-RevId: 403496100
diff --git a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecWithStore.java b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecWithStore.java
index 091bc47..2f62287 100644
--- a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecWithStore.java
+++ b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecWithStore.java
@@ -98,7 +98,7 @@
       context.serialize(obj.getChildren(), codedOut);
     } else {
       codedOut.writeEnumNoTag(NestedSetSize.NONLEAF.ordinal());
-      context.serialize(obj.getApproxDepth(), codedOut);
+      codedOut.writeInt32NoTag(obj.getApproxDepth());
       FingerprintComputationResult fingerprintComputationResult =
           nestedSetStore.computeFingerprintAndStore((Object[]) obj.getChildren(), context);
       context.addFutureToBlockWritingOn(
@@ -118,9 +118,9 @@
         return NestedSetBuilder.emptySet(order);
       case LEAF:
         Object contents = context.deserialize(codedIn);
-        return intern(order, 1, contents);
+        return intern(order, /*depth=*/ 1, contents);
       case NONLEAF:
-        int depth = context.deserialize(codedIn);
+        int depth = codedIn.readInt32();
         ByteString fingerprint = ByteString.copyFrom(codedIn.readByteArray());
         return intern(order, depth, nestedSetStore.getContentsAndDeserialize(fingerprint, context));
     }
diff --git a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
index a35ab01..c22308c 100644
--- a/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
+++ b/src/main/java/com/google/devtools/build/lib/collect/nestedset/NestedSetStore.java
@@ -29,6 +29,7 @@
 import com.google.devtools.build.lib.bugreport.BugReporter;
 import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationDependencyProvider;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.CodedInputStream;
@@ -36,10 +37,10 @@
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.time.Duration;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
+import java.util.function.Function;
 import javax.annotation.Nullable;
 
 /**
@@ -140,38 +141,50 @@
     abstract ListenableFuture<Void> writeStatus();
   }
 
-  private final NestedSetSerializationCache nestedSetCache;
+  public static final Function<SerializationDependencyProvider, ?> NO_CONTEXT = ctx -> "";
+
   private final NestedSetStorageEndpoint endpoint;
   private final Executor executor;
-
-  /** Creates a NestedSetStore with the provided {@link NestedSetStorageEndpoint} as a backend. */
-  @VisibleForTesting
-  public NestedSetStore(NestedSetStorageEndpoint endpoint) {
-    this(endpoint, directExecutor(), BugReporter.defaultInstance());
-  }
+  private final NestedSetSerializationCache nestedSetCache;
+  private final Function<SerializationDependencyProvider, ?> cacheContextFn;
 
   /**
    * Creates a NestedSetStore with the provided {@link NestedSetStorageEndpoint} and executor for
    * deserialization.
+   *
+   * <p>Takes a function that produces a caching context object from a {@link
+   * SerializationDependencyProvider}. The context should work as described in {@link
+   * NestedSetSerializationCache} to disambiguate different contents that have the same serialized
+   * representation. If a one-to-one correspondence between contents and serialized representation
+   * is guaranteed, use {@link #NO_CONTEXT}, which uses a constant object for the cache context.
    */
   public NestedSetStore(
-      NestedSetStorageEndpoint endpoint, Executor executor, BugReporter bugReporter) {
-    this(endpoint, new NestedSetSerializationCache(bugReporter), executor);
+      NestedSetStorageEndpoint endpoint,
+      Executor executor,
+      BugReporter bugReporter,
+      Function<SerializationDependencyProvider, ?> cacheContextFn) {
+    this(endpoint, executor, new NestedSetSerializationCache(bugReporter), cacheContextFn);
   }
 
   @VisibleForTesting
   NestedSetStore(
       NestedSetStorageEndpoint endpoint,
+      Executor executor,
       NestedSetSerializationCache nestedSetCache,
-      Executor executor) {
+      Function<SerializationDependencyProvider, ?> cacheContextFn) {
     this.endpoint = checkNotNull(endpoint);
-    this.nestedSetCache = checkNotNull(nestedSetCache);
     this.executor = checkNotNull(executor);
+    this.nestedSetCache = checkNotNull(nestedSetCache);
+    this.cacheContextFn = checkNotNull(cacheContextFn);
   }
 
-  /** Creates a NestedSetStore with an in-memory storage backend. */
+  /** Creates a NestedSetStore with an in-memory storage backend and no caching context. */
   public static NestedSetStore inMemory() {
-    return new NestedSetStore(new InMemoryNestedSetStorageEndpoint());
+    return new NestedSetStore(
+        new InMemoryNestedSetStorageEndpoint(),
+        directExecutor(),
+        BugReporter.defaultInstance(),
+        NO_CONTEXT);
   }
 
   /**
@@ -192,6 +205,13 @@
   FingerprintComputationResult computeFingerprintAndStore(
       Object[] contents, SerializationContext serializationContext)
       throws SerializationException, IOException {
+    return computeFingerprintAndStore(
+        contents, serializationContext, cacheContextFn.apply(serializationContext));
+  }
+
+  private FingerprintComputationResult computeFingerprintAndStore(
+      Object[] contents, SerializationContext serializationContext, Object cacheContext)
+      throws SerializationException, IOException {
     FingerprintComputationResult priorFingerprint = nestedSetCache.fingerprintForContents(contents);
     if (priorFingerprint != null) {
       return priorFingerprint;
@@ -213,7 +233,7 @@
       for (Object child : contents) {
         if (child instanceof Object[]) {
           FingerprintComputationResult fingerprintComputationResult =
-              computeFingerprintAndStore((Object[]) child, serializationContext);
+              computeFingerprintAndStore((Object[]) child, serializationContext, cacheContext);
           futureBuilder.add(fingerprintComputationResult.writeStatus());
           newSerializationContext.serialize(
               fingerprintComputationResult.fingerprint(), codedOutputStream);
@@ -244,9 +264,8 @@
     FingerprintComputationResult result =
         FingerprintComputationResult.create(fingerprint, writeFuture);
 
-    // TODO(b/202438580): Pass through relevant context.
     FingerprintComputationResult existingResult =
-        nestedSetCache.putIfAbsent(contents, result, /*context=*/ "");
+        nestedSetCache.putIfAbsent(contents, result, cacheContext);
     if (existingResult != null) {
       return existingResult; // Another thread won the fingerprint computation race.
     }
@@ -275,13 +294,19 @@
    * <p>The return value is either an {@code Object[]} or a {@code ListenableFuture<Object[]>},
    * which may be completed with a {@link MissingNestedSetException}.
    */
-  // All callers will test on type and check return value if it's a future.
-  @SuppressWarnings("FutureReturnValueIgnored")
   Object getContentsAndDeserialize(
       ByteString fingerprint, DeserializationContext deserializationContext) throws IOException {
+    return getContentsAndDeserialize(
+        fingerprint, deserializationContext, cacheContextFn.apply(deserializationContext));
+  }
+
+  // All callers will test on type and check return value if it's a future.
+  @SuppressWarnings("FutureReturnValueIgnored")
+  private Object getContentsAndDeserialize(
+      ByteString fingerprint, DeserializationContext deserializationContext, Object cacheContext)
+      throws IOException {
     SettableFuture<Object[]> future = SettableFuture.create();
-    // TODO(b/202438580): Pass through relevant context.
-    Object contents = nestedSetCache.putFutureIfAbsent(fingerprint, future, /*context=*/ "");
+    Object contents = nestedSetCache.putFutureIfAbsent(fingerprint, future, cacheContext);
     if (contents != null) {
       return contents;
     }
@@ -304,32 +329,25 @@
                   deserializationContext.getNewMemoizingContext();
 
               // The elements of this list are futures for the deserialized values of these
-              // NestedSet contents.  For direct members, the futures complete immediately and yield
-              // an Object.  For transitive members (fingerprints), the futures complete with the
+              // NestedSet contents. For direct members, the futures complete immediately and yield
+              // an Object. For transitive members (fingerprints), the futures complete with the
               // underlying fetch, and yield Object[]s.
-              List<ListenableFuture<?>> deserializationFutures = new ArrayList<>();
+              ImmutableList.Builder<ListenableFuture<?>> deserializationFutures =
+                  ImmutableList.builderWithExpectedSize(numberOfElements);
               for (int i = 0; i < numberOfElements; i++) {
                 Object deserializedElement = newDeserializationContext.deserialize(codedIn);
                 if (deserializedElement instanceof ByteString) {
-                  deserializationFutures.add(
-                      maybeWrapInFuture(
-                          getContentsAndDeserialize(
-                              (ByteString) deserializedElement, deserializationContext)));
+                  Object innerContents =
+                      getContentsAndDeserialize(
+                          (ByteString) deserializedElement, deserializationContext, cacheContext);
+                  deserializationFutures.add(maybeWrapInFuture(innerContents));
                 } else {
                   deserializationFutures.add(Futures.immediateFuture(deserializedElement));
                 }
               }
 
-              return Futures.whenAllComplete(deserializationFutures)
-                  .call(
-                      () -> {
-                        Object[] deserializedContents = new Object[deserializationFutures.size()];
-                        for (int i = 0; i < deserializationFutures.size(); i++) {
-                          deserializedContents[i] = Futures.getDone(deserializationFutures.get(i));
-                        }
-                        return deserializedContents;
-                      },
-                      executor);
+              return Futures.transform(
+                  Futures.allAsList(deserializationFutures.build()), List::toArray, executor);
             },
             executor));
     return future;
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
index 195d850..be28bac 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
@@ -34,7 +34,7 @@
  * thread-safe and should only be accessed on a single thread for deserializing one serialized
  * object (that may contain other serialized objects inside it).
  */
-public class DeserializationContext {
+public class DeserializationContext implements SerializationDependencyProvider {
   private final ObjectCodecRegistry registry;
   private final ImmutableClassToInstanceMap<Object> dependencies;
   @Nullable private final Memoizer.Deserializer deserializer;
@@ -113,6 +113,7 @@
     }
   }
 
+  @Override
   public <T> T getDependency(Class<T> type) {
     return checkNotNull(dependencies.getInstance(type), "Missing dependency of type %s", type);
   }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
index ee12af0..f5c188966d 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
@@ -20,6 +20,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableClassToInstanceMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
@@ -43,7 +44,7 @@
  * should only be accessed on a single thread for serializing one object (that may involve
  * serializing other objects contained in it).
  */
-public class SerializationContext {
+public class SerializationContext implements SerializationDependencyProvider {
   private final ObjectCodecRegistry registry;
   private final ImmutableClassToInstanceMap<Object> dependencies;
   @Nullable private final Memoizer.Serializer serializer;
@@ -108,6 +109,7 @@
     }
   }
 
+  @Override
   public <T> T getDependency(Class<T> type) {
     return checkNotNull(dependencies.getInstance(type), "Missing dependency of type %s", type);
   }
@@ -152,12 +154,31 @@
 
   private SerializationContext getNewMemoizingContext(boolean allowFuturesToBlockWritingOn) {
     return new SerializationContext(
-        this.registry, this.dependencies, new Memoizer.Serializer(), allowFuturesToBlockWritingOn);
+        registry, dependencies, new Memoizer.Serializer(), allowFuturesToBlockWritingOn);
   }
 
-  public SerializationContext getNewNonMemoizingContext() {
+  /**
+   * Returns a new {@link SerializationContext} mostly identical to this one, but with a dependency
+   * map composed by applying overrides to this context's dependencies.
+   *
+   * <p>The given {@code dependencyOverrides} may contain keys already present (in which case the
+   * dependency will be replaced) or new keys (in which case the dependency will be added).
+   *
+   * <p>Must only be called on a base context (no memoization state), since changing dependencies
+   * may change deserialization semantics.
+   */
+  @CheckReturnValue
+  public SerializationContext withDependencyOverrides(
+      ImmutableClassToInstanceMap<?> dependencyOverrides) {
+    checkState(serializer == null, "Must only be called on base SerializationContext");
     return new SerializationContext(
-        this.registry, this.dependencies, null, this.allowFuturesToBlockWritingOn);
+        registry,
+        ImmutableClassToInstanceMap.builder()
+            .putAll(Maps.filterKeys(dependencies, k -> !dependencyOverrides.containsKey(k)))
+            .putAll(dependencyOverrides)
+            .build(),
+        /*serializer=*/ null,
+        allowFuturesToBlockWritingOn);
   }
 
   /**
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationDependencyProvider.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationDependencyProvider.java
new file mode 100644
index 0000000..c776f39
--- /dev/null
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationDependencyProvider.java
@@ -0,0 +1,29 @@
+// Copyright 2021 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.devtools.build.lib.skyframe.serialization;
+
+/**
+ * Common interface for {@link SerializationContext} and {@link DeserializationContext}, which both
+ * provide access to dependencies required for serialization.
+ */
+public interface SerializationDependencyProvider {
+
+  /**
+   * Returns the dependency associated with the given type.
+   *
+   * @throws NullPointerException if there is no dependency registered for the given type
+   */
+  <T> T getDependency(Class<T> type);
+}
diff --git a/src/test/java/com/google/devtools/build/lib/collect/BUILD b/src/test/java/com/google/devtools/build/lib/collect/BUILD
index 2c3b447..906cabe 100644
--- a/src/test/java/com/google/devtools/build/lib/collect/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/collect/BUILD
@@ -29,9 +29,11 @@
         "//src/main/java/com/google/devtools/build/lib/collect/nestedset:testutils",
         "//src/main/java/com/google/devtools/build/lib/skyframe/serialization",
         "//src/main/java/com/google/devtools/build/lib/util",
+        "//src/main/java/com/google/devtools/build/lib/util/io",
         "//src/main/java/com/google/devtools/build/lib/vfs",
         "//src/main/java/net/starlark/java/eval",
         "//src/test/java/com/google/devtools/build/lib/starlark/util",
+        "//third_party:auto_value",
         "//third_party:guava",
         "//third_party:guava-testlib",
         "//third_party:junit4",
diff --git a/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java b/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
index e9bb090..f627a51 100644
--- a/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
+++ b/src/test/java/com/google/devtools/build/lib/collect/nestedset/NestedSetCodecTest.java
@@ -28,7 +28,10 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import com.google.auto.value.AutoValue;
 import com.google.common.collect.ImmutableClassToInstanceMap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 import com.google.common.testing.GcFinalization;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
@@ -38,17 +41,25 @@
 import com.google.devtools.build.lib.collect.nestedset.NestedSetStore.NestedSetStorageEndpoint;
 import com.google.devtools.build.lib.skyframe.serialization.AutoRegistry;
 import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
+import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec;
+import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecRegistry;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecs;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationDependencyProvider;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationResult;
+import com.google.devtools.build.lib.util.io.AnsiTerminal.Color;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.CodedInputStream;
+import com.google.protobuf.CodedOutputStream;
 import java.io.IOException;
 import java.lang.ref.WeakReference;
 import java.nio.charset.Charset;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -95,12 +106,12 @@
           }
         };
 
-    ObjectCodecs serializer = createCodecs(new NestedSetStore(endpoint));
+    ObjectCodecs serializer = createCodecs(createStore(endpoint));
     ByteString serializedBase = serializer.serializeMemoizedAndBlocking(base).getObject();
     ByteString serializedTop = serializer.serializeMemoizedAndBlocking(top).getObject();
 
     // When deserializing top, we should perform 2 reads, one for each array in [[a, b], c].
-    ObjectCodecs deserializer = createCodecs(new NestedSetStore(endpoint));
+    ObjectCodecs deserializer = createCodecs(createStore(endpoint));
     NestedSet<?> deserializedTop = (NestedSet<?>) deserializer.deserializeMemoized(serializedTop);
     assertThat(deserializedTop.toList()).containsExactly("a", "b", "c");
     assertThat(reads.get()).isEqualTo(2);
@@ -115,6 +126,8 @@
 
   @Test
   public void missingNestedSetException_hiddenUntilNestedSetIsConsumed() throws Exception {
+    MissingNestedSetException missingNestedSetException =
+        new MissingNestedSetException(ByteString.copyFromUtf8("fingerprint"));
     NestedSetStorageEndpoint storageEndpoint =
         new NestedSetStorageEndpoint() {
           @Override
@@ -124,12 +137,13 @@
 
           @Override
           public ListenableFuture<byte[]> get(ByteString fingerprint) {
-            return immediateFailedFuture(
-                new MissingNestedSetException(ByteString.copyFromUtf8("fingerprint")));
+            return immediateFailedFuture(missingNestedSetException);
           }
         };
-    ObjectCodecs serializer = createCodecs(new NestedSetStore(storageEndpoint));
-    ObjectCodecs deserializer = createCodecs(new NestedSetStore(storageEndpoint));
+    BugReporter bugReporter = mock(BugReporter.class);
+    ObjectCodecs serializer = createCodecs(createStore(storageEndpoint));
+    ObjectCodecs deserializer =
+        createCodecs(createStoreWithBugReporter(storageEndpoint, bugReporter));
 
     NestedSet<?> serialized = NestedSetBuilder.create(Order.STABLE_ORDER, "a", "b");
     SerializationResult<ByteString> result = serializer.serializeMemoizedAndBlocking(serialized);
@@ -138,6 +152,7 @@
     assertThat(deserialized).isInstanceOf(NestedSet.class);
     assertThrows(
         MissingNestedSetException.class, ((NestedSet<?>) deserialized)::toListInterruptibly);
+    verify(bugReporter).sendBugReport(missingNestedSetException);
   }
 
   @Test
@@ -154,8 +169,8 @@
             return immediateFailedFuture(new RuntimeException("Something went wrong"));
           }
         };
-    ObjectCodecs serializer = createCodecs(new NestedSetStore(storageEndpoint));
-    ObjectCodecs deserializer = createCodecs(new NestedSetStore(storageEndpoint));
+    ObjectCodecs serializer = createCodecs(createStore(storageEndpoint));
+    ObjectCodecs deserializer = createCodecs(createStore(storageEndpoint));
 
     NestedSet<?> serialized = NestedSetBuilder.create(Order.STABLE_ORDER, "a", "b");
     SerializationResult<ByteString> result = serializer.serializeMemoizedAndBlocking(serialized);
@@ -182,7 +197,7 @@
         .thenReturn(innerWrite)
         // The write of the outer NestedSet {{"a", "b"}, {"c", "d"}}
         .thenReturn(outerWrite);
-    ObjectCodecs objectCodecs = createCodecs(new NestedSetStore(mockStorage));
+    ObjectCodecs objectCodecs = createCodecs(createStore(mockStorage));
 
     NestedSet<NestedSet<String>> nestedNestedSet =
         NestedSetBuilder.create(
@@ -212,7 +227,7 @@
         .thenReturn(outerWrite)
         // The write of the inner NestedSet {"e", "f"}
         .thenReturn(immediateVoidFuture());
-    ObjectCodecs objectCodecs = createCodecs(new NestedSetStore(mockStorage));
+    ObjectCodecs objectCodecs = createCodecs(createStore(mockStorage));
 
     NestedSet<String> sharedInnerNestedSet = NestedSetBuilder.create(Order.STABLE_ORDER, "a", "b");
     NestedSet<NestedSet<String>> nestedNestedSet1 =
@@ -255,7 +270,7 @@
     // Avoid NestedSetBuilder.wrap/create - they use their own cache which interferes with what
     // we're testing.
     NestedSet<?> nestedSet = NestedSetBuilder.stableOrder().add("a").add("b").build();
-    ObjectCodecs codecs = createCodecs(new NestedSetStore(new InMemoryNestedSetStorageEndpoint()));
+    ObjectCodecs codecs = createCodecs(createStore(new InMemoryNestedSetStorageEndpoint()));
     codecs.serializeMemoizedAndBlocking(nestedSet);
     WeakReference<?> ref = new WeakReference<>(nestedSet);
     nestedSet = null;
@@ -267,7 +282,7 @@
     NestedSetStorageEndpoint nestedSetStorageEndpoint = spy(new InMemoryNestedSetStorageEndpoint());
     NestedSetSerializationCache emptyNestedSetCache = mock(NestedSetSerializationCache.class);
     NestedSetStore nestedSetStore =
-        new NestedSetStore(nestedSetStorageEndpoint, emptyNestedSetCache, directExecutor());
+        createStoreWithCache(nestedSetStorageEndpoint, emptyNestedSetCache);
 
     ObjectCodecs objectCodecs = createCodecs(nestedSetStore);
 
@@ -321,8 +336,7 @@
     NestedSetStorageEndpoint nestedSetStorageEndpoint = mock(NestedSetStorageEndpoint.class);
     NestedSetSerializationCache nestedSetCache =
         spy(new NestedSetSerializationCache(BugReporter.defaultInstance()));
-    NestedSetStore nestedSetStore =
-        new NestedSetStore(nestedSetStorageEndpoint, nestedSetCache, directExecutor());
+    NestedSetStore nestedSetStore = createStoreWithCache(nestedSetStorageEndpoint, nestedSetCache);
     DeserializationContext deserializationContext = mock(DeserializationContext.class);
     ByteString fingerprint = ByteString.copyFromUtf8("fingerprint");
     // Future never completes, so we don't have to exercise that code in NestedSetStore.
@@ -371,8 +385,7 @@
     NestedSetStorageEndpoint nestedSetStorageEndpoint = mock(NestedSetStorageEndpoint.class);
     NestedSetSerializationCache nestedSetCache =
         spy(new NestedSetSerializationCache(BugReporter.defaultInstance()));
-    NestedSetStore nestedSetStore =
-        new NestedSetStore(nestedSetStorageEndpoint, nestedSetCache, directExecutor());
+    NestedSetStore nestedSetStore = createStore(nestedSetStorageEndpoint);
     SerializationContext serializationContext = mock(SerializationContext.class);
     Object[] contents = {new Object()};
     when(serializationContext.getNewMemoizingContext()).thenReturn(serializationContext);
@@ -414,7 +427,7 @@
   @Test
   public void writeFuturesWaitForTransitiveWrites() throws Exception {
     NestedSetStorageEndpoint mockWriter = mock(NestedSetStorageEndpoint.class);
-    NestedSetStore store = new NestedSetStore(mockWriter);
+    NestedSetStore store = createStore(mockWriter);
     SerializationContext mockSerializationContext = mock(SerializationContext.class);
     when(mockSerializationContext.getNewMemoizingContext()).thenReturn(mockSerializationContext);
 
@@ -461,13 +474,137 @@
     assertThat(topWriteFuture.isDone()).isTrue();
   }
 
-  private static ObjectCodecs createCodecs(NestedSetStore store) {
-    return new ObjectCodecs(
+  @AutoValue
+  abstract static class ColorfulThing {
+    abstract String thing();
+
+    abstract Color color();
+
+    static ColorfulThing of(String thing, Color color) {
+      return new AutoValue_NestedSetCodecTest_ColorfulThing(thing, color);
+    }
+  }
+
+  @Test
+  public void cacheContext_disambiguatesIdenticalSerializedRepresentation() throws Exception {
+    // Serializes ColorfulThing without color, reading the color as a deserialization dependency.
+    class BlackAndWhiteCodec implements ObjectCodec<ColorfulThing> {
+      @Override
+      public Class<ColorfulThing> getEncodedClass() {
+        return ColorfulThing.class;
+      }
+
+      @Override
+      public void serialize(
+          SerializationContext context, ColorfulThing obj, CodedOutputStream codedOut)
+          throws SerializationException, IOException {
+        context.serialize(obj.thing(), codedOut);
+      }
+
+      @Override
+      public ColorfulThing deserialize(DeserializationContext context, CodedInputStream codedIn)
+          throws SerializationException, IOException {
+        String thing = context.deserialize(codedIn);
+        Color color = context.getDependency(Color.class);
+        return ColorfulThing.of(thing, color);
+      }
+    }
+
+    ObjectCodecs codecs =
+        createCodecs(
+            createStoreWithCacheContext(
+                new InMemoryNestedSetStorageEndpoint(), ctx -> ctx.getDependency(Color.class)),
+            new BlackAndWhiteCodec());
+
+    List<String> stuff = ImmutableList.of("bird", "paint", "shoes");
+    NestedSet<ColorfulThing> redStuff =
+        NestedSetBuilder.wrap(
+            Order.STABLE_ORDER,
+            Lists.transform(stuff, thing -> ColorfulThing.of(thing, Color.RED)));
+    NestedSet<ColorfulThing> blueStuff =
+        NestedSetBuilder.wrap(
+            Order.STABLE_ORDER,
+            Lists.transform(stuff, thing -> ColorfulThing.of(thing, Color.BLUE)));
+
+    ByteString redSerialized =
+        ObjectCodecs.serialize(
+                redStuff,
+                codecs
+                    .getSerializationContext()
+                    .withDependencyOverrides(ImmutableClassToInstanceMap.of(Color.class, Color.RED))
+                    .getMemoizingAndBlockingOnWriteContext())
+            .getObject();
+    ByteString blueSerialized =
+        ObjectCodecs.serialize(
+                blueStuff,
+                codecs
+                    .getSerializationContext()
+                    .withDependencyOverrides(
+                        ImmutableClassToInstanceMap.of(Color.class, Color.BLUE))
+                    .getMemoizingAndBlockingOnWriteContext())
+            .getObject();
+    assertThat(redSerialized).isEqualTo(blueSerialized);
+
+    Object redDeserialized =
+        ObjectCodecs.deserialize(
+            redSerialized.newCodedInput(),
+            codecs
+                .getDeserializationContext()
+                .withDependencyOverrides(ImmutableClassToInstanceMap.of(Color.class, Color.RED))
+                .getMemoizingContext());
+    Object blueDeserialized =
+        ObjectCodecs.deserialize(
+            blueSerialized.newCodedInput(),
+            codecs
+                .getDeserializationContext()
+                .withDependencyOverrides(ImmutableClassToInstanceMap.of(Color.class, Color.BLUE))
+                .getMemoizingContext());
+    assertThat(redDeserialized).isSameInstanceAs(redStuff);
+    assertThat(blueDeserialized).isSameInstanceAs(blueStuff);
+
+    // Test that we can deserialize in a context that was not previously serialized.
+    Object greenDeserialized =
+        ObjectCodecs.deserialize(
+            redSerialized.newCodedInput(),
+            codecs
+                .getDeserializationContext()
+                .withDependencyOverrides(ImmutableClassToInstanceMap.of(Color.class, Color.GREEN))
+                .getMemoizingContext());
+    assertThat(greenDeserialized).isInstanceOf(NestedSet.class);
+    assertThat(((NestedSet<?>) greenDeserialized).toList())
+        .isEqualTo(Lists.transform(stuff, thing -> ColorfulThing.of(thing, Color.GREEN)));
+  }
+
+  private static NestedSetStore createStore(NestedSetStorageEndpoint endpoint) {
+    return createStoreWithBugReporter(endpoint, BugReporter.defaultInstance());
+  }
+
+  private static NestedSetStore createStoreWithBugReporter(
+      NestedSetStorageEndpoint endpoint, BugReporter bugReporter) {
+    return new NestedSetStore(endpoint, directExecutor(), bugReporter, NestedSetStore.NO_CONTEXT);
+  }
+
+  private static NestedSetStore createStoreWithCache(
+      NestedSetStorageEndpoint endpoint, NestedSetSerializationCache cache) {
+    return new NestedSetStore(endpoint, directExecutor(), cache, NestedSetStore.NO_CONTEXT);
+  }
+
+  private static NestedSetStore createStoreWithCacheContext(
+      NestedSetStorageEndpoint endpoint,
+      Function<SerializationDependencyProvider, ?> cacheContextFn) {
+    return new NestedSetStore(
+        endpoint, directExecutor(), BugReporter.defaultInstance(), cacheContextFn);
+  }
+
+  private static ObjectCodecs createCodecs(NestedSetStore store, ObjectCodec<?>... codecs) {
+    ObjectCodecRegistry.Builder registry =
         AutoRegistry.get()
             .getBuilder()
             .setAllowDefaultCodec(true)
-            .add(new NestedSetCodecWithStore(store))
-            .build(),
-        /*dependencies=*/ ImmutableClassToInstanceMap.of());
+            .add(new NestedSetCodecWithStore(store));
+    for (ObjectCodec<?> codec : codecs) {
+      registry.add(codec);
+    }
+    return new ObjectCodecs(registry.build(), /*dependencies=*/ ImmutableClassToInstanceMap.of());
   }
 }