Allow overriding dependencies in `DeserializationContext`, and pass `DeserializationContext` to `OptionsChecksumCache`.

PiperOrigin-RevId: 402634691
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java b/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
index 35f5cd2..e917279 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
@@ -684,7 +684,7 @@
         throws IOException {
       String checksum = codedIn.readString();
       return checkNotNull(
-          context.getDependency(OptionsChecksumCache.class).getOptions(checksum),
+          context.getDependency(OptionsChecksumCache.class).getOptions(checksum, context),
           "No options instance for %s",
           checksum);
     }
@@ -699,7 +699,7 @@
     /**
      * Called during deserialization to transform a checksum into a {@link BuildOptions} instance.
      */
-    BuildOptions getOptions(String checksum);
+    BuildOptions getOptions(String checksum, DeserializationContext context);
 
     /**
      * Notifies the cache that it may be necessary to deserialize the given options diff's checksum.
@@ -718,7 +718,7 @@
     private final ConcurrentMap<String, BuildOptions> map = new ConcurrentHashMap<>();
 
     @Override
-    public BuildOptions getOptions(String checksum) {
+    public BuildOptions getOptions(String checksum, DeserializationContext context) {
       return map.get(checksum);
     }
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
index 917fcf7..195d850 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
@@ -15,10 +15,12 @@
 package com.google.devtools.build.lib.skyframe.serialization;
 
 import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ClassToInstanceMap;
 import com.google.common.collect.ImmutableClassToInstanceMap;
-import com.google.devtools.build.lib.skyframe.serialization.Memoizer.Deserializer;
+import com.google.common.collect.Maps;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec.MemoizationStrategy;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecRegistry.CodecDescriptor;
 import com.google.protobuf.CodedInputStream;
@@ -35,12 +37,12 @@
 public class DeserializationContext {
   private final ObjectCodecRegistry registry;
   private final ImmutableClassToInstanceMap<Object> dependencies;
-  private final Memoizer.Deserializer deserializer;
+  @Nullable private final Memoizer.Deserializer deserializer;
 
   private DeserializationContext(
       ObjectCodecRegistry registry,
       ImmutableClassToInstanceMap<Object> dependencies,
-      Deserializer deserializer) {
+      @Nullable Memoizer.Deserializer deserializer) {
     this.registry = registry;
     this.dependencies = dependencies;
     this.deserializer = deserializer;
@@ -106,10 +108,9 @@
    * <p>This is a noop when memoization is disabled.
    */
   public <T> void registerInitialValue(T initialValue) {
-    if (deserializer == null) {
-      return;
+    if (deserializer != null) {
+      deserializer.registerInitialValue(initialValue);
     }
-    deserializer.registerInitialValue(initialValue);
   }
 
   public <T> T getDependency(Class<T> type) {
@@ -134,14 +135,33 @@
   }
 
   /**
-   * Returns a memoizing {@link DeserializationContext}, as getMemoizingContext above. Unlike
-   * getMemoizingContext, this method is not idempotent - the returned context will always be fresh.
+   * Returns a new memoizing {@link DeserializationContext}, as {@link #getMemoizingContext}. Unlike
+   * {@link #getMemoizingContext}, this method is not idempotent - the returned context will always
+   * be fresh.
    */
   public DeserializationContext getNewMemoizingContext() {
-    return new DeserializationContext(this.registry, this.dependencies, new Deserializer());
+    return new DeserializationContext(registry, dependencies, new Memoizer.Deserializer());
   }
 
-  public DeserializationContext getNewNonMemoizingContext() {
-    return new DeserializationContext(this.registry, this.dependencies, null);
+  /**
+   * Returns a new {@link DeserializationContext} mostly identical to this one, but with a
+   * dependency map composed by applying overrides to this context's dependencies.
+   *
+   * <p>The given {@code dependencyOverrides} may contain keys already present (in which case the
+   * dependency is replaced) or new keys (in which case the dependency is added).
+   *
+   * <p>Must only be called on a base context (no memoization state), since changing dependencies
+   * may change deserialization semantics.
+   */
+  @CheckReturnValue
+  public DeserializationContext withDependencyOverrides(ClassToInstanceMap<?> dependencyOverrides) {
+    checkState(deserializer == null, "Must only be called on base DeserializationContext");
+    return new DeserializationContext(
+        registry,
+        ImmutableClassToInstanceMap.builder()
+            .putAll(Maps.filterKeys(dependencies, k -> !dependencyOverrides.containsKey(k)))
+            .putAll(dependencyOverrides)
+            .build(),
+        /*deserializer=*/ null);
   }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java
index 176e548..b9db2aa 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java
@@ -15,7 +15,11 @@
 package com.google.devtools.build.lib.skyframe.serialization;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 
 import com.google.common.collect.ImmutableClassToInstanceMap;
@@ -26,45 +30,45 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
-import org.mockito.Mockito;
 
 /** Tests for {@link DeserializationContext}. */
 @RunWith(JUnit4.class)
-public class DeserializationContextTest {
+public final class DeserializationContextTest {
+
   @Test
   public void nullDeserialize() throws Exception {
-    ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
-    CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+    ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
+    CodedInputStream codedInputStream = mock(CodedInputStream.class);
     when(codedInputStream.readSInt32()).thenReturn(0);
     DeserializationContext deserializationContext =
         new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
     assertThat((Object) deserializationContext.deserialize(codedInputStream)).isNull();
-    Mockito.verify(codedInputStream).readSInt32();
-    Mockito.verifyNoInteractions(registry);
+    verify(codedInputStream).readSInt32();
+    verifyNoInteractions(registry);
   }
 
   @Test
   public void constantDeserialize() throws Exception {
-    ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+    ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
     Object constant = new Object();
     when(registry.maybeGetConstantByTag(1)).thenReturn(constant);
-    CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+    CodedInputStream codedInputStream = mock(CodedInputStream.class);
     when(codedInputStream.readSInt32()).thenReturn(1);
     DeserializationContext deserializationContext =
         new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
     assertThat((Object) deserializationContext.deserialize(codedInputStream))
         .isSameInstanceAs(constant);
-    Mockito.verify(codedInputStream).readSInt32();
-    Mockito.verify(registry).maybeGetConstantByTag(1);
+    verify(codedInputStream).readSInt32();
+    verify(registry).maybeGetConstantByTag(1);
   }
 
   @Test
   public void descriptorDeserialize() throws Exception {
     ObjectCodecRegistry.CodecDescriptor codecDescriptor =
-        Mockito.mock(ObjectCodecRegistry.CodecDescriptor.class);
-    ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+        mock(ObjectCodecRegistry.CodecDescriptor.class);
+    ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
     when(registry.getCodecDescriptorByTag(1)).thenReturn(codecDescriptor);
-    CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+    CodedInputStream codedInputStream = mock(CodedInputStream.class);
     when(codedInputStream.readSInt32()).thenReturn(1);
     DeserializationContext deserializationContext =
         new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
@@ -73,63 +77,122 @@
         .thenReturn(returnValue);
     assertThat((Object) deserializationContext.deserialize(codedInputStream))
         .isSameInstanceAs(returnValue);
-    Mockito.verify(codedInputStream).readSInt32();
-    Mockito.verify(registry).getCodecDescriptorByTag(1);
-    Mockito.verify(codecDescriptor).deserialize(deserializationContext, codedInputStream);
+    verify(codedInputStream).readSInt32();
+    verify(registry).getCodecDescriptorByTag(1);
+    verify(codecDescriptor).deserialize(deserializationContext, codedInputStream);
   }
 
   @Test
   public void memoizingDeserialize_null() throws SerializationException, IOException {
-    ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
-    CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+    ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
+    CodedInputStream codedInputStream = mock(CodedInputStream.class);
     DeserializationContext deserializationContext =
         new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
     when(codedInputStream.readSInt32()).thenReturn(0);
     assertThat((Object) deserializationContext.getMemoizingContext().deserialize(codedInputStream))
         .isEqualTo(null);
-    Mockito.verify(codedInputStream).readSInt32();
-    Mockito.verifyNoInteractions(registry);
+    verify(codedInputStream).readSInt32();
+    verifyNoInteractions(registry);
   }
 
   @Test
   public void memoizingDeserialize_constant() throws SerializationException, IOException {
     Object constant = new Object();
-    ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+    ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
     when(registry.maybeGetConstantByTag(1)).thenReturn(constant);
-    CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+    CodedInputStream codedInputStream = mock(CodedInputStream.class);
     DeserializationContext deserializationContext =
         new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
     when(codedInputStream.readSInt32()).thenReturn(1);
     assertThat((Object) deserializationContext.getMemoizingContext().deserialize(codedInputStream))
         .isEqualTo(constant);
-    Mockito.verify(codedInputStream).readSInt32();
-    Mockito.verify(registry).maybeGetConstantByTag(1);
+    verify(codedInputStream).readSInt32();
+    verify(registry).maybeGetConstantByTag(1);
   }
 
   @Test
   public void memoizingDeserialize_codec() throws SerializationException, IOException {
     Object returned = new Object();
     @SuppressWarnings("unchecked")
-    ObjectCodec<Object> codec = Mockito.mock(ObjectCodec.class);
+    ObjectCodec<Object> codec = mock(ObjectCodec.class);
     when(codec.getStrategy()).thenReturn(MemoizationStrategy.MEMOIZE_AFTER);
     when(codec.getEncodedClass()).thenAnswer(unused -> Object.class);
     when(codec.additionalEncodedClasses()).thenReturn(ImmutableList.of());
     ObjectCodecRegistry.CodecDescriptor codecDescriptor =
-        Mockito.mock(ObjectCodecRegistry.CodecDescriptor.class);
+        mock(ObjectCodecRegistry.CodecDescriptor.class);
     doReturn(codec).when(codecDescriptor).getCodec();
-    ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+    ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
     when(registry.getCodecDescriptorByTag(1)).thenReturn(codecDescriptor);
-    CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+    CodedInputStream codedInputStream = mock(CodedInputStream.class);
     DeserializationContext deserializationContext =
         new DeserializationContext(registry, ImmutableClassToInstanceMap.of())
             .getMemoizingContext();
     when(codec.deserialize(deserializationContext, codedInputStream)).thenReturn(returned);
     when(codedInputStream.readSInt32()).thenReturn(1);
     assertThat((Object) deserializationContext.deserialize(codedInputStream)).isEqualTo(returned);
-    Mockito.verify(codedInputStream).readSInt32();
-    Mockito.verify(registry).maybeGetConstantByTag(1);
-    Mockito.verify(registry).getCodecDescriptorByTag(1);
-    Mockito.verify(codecDescriptor).getCodec();
-    Mockito.verify(codec).deserialize(deserializationContext, codedInputStream);
+    verify(codedInputStream).readSInt32();
+    verify(registry).maybeGetConstantByTag(1);
+    verify(registry).getCodecDescriptorByTag(1);
+    verify(codecDescriptor).getCodec();
+    verify(codec).deserialize(deserializationContext, codedInputStream);
+  }
+
+  @Test
+  public void getDependency() {
+    DeserializationContext context =
+        new DeserializationContext(
+            mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+    assertThat(context.getDependency(String.class)).isEqualTo("abc");
+  }
+
+  @Test
+  public void getDependency_notPresent() {
+    DeserializationContext context =
+        new DeserializationContext(
+            mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of());
+    Exception e =
+        assertThrows(NullPointerException.class, () -> context.getDependency(String.class));
+    assertThat(e).hasMessageThat().contains("Missing dependency of type " + String.class);
+  }
+
+  @Test
+  public void dependencyOverrides_alreadyPresent() {
+    DeserializationContext context =
+        new DeserializationContext(
+            mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+    DeserializationContext overridden =
+        context.withDependencyOverrides(ImmutableClassToInstanceMap.of(String.class, "xyz"));
+    assertThat(overridden.getDependency(String.class)).isEqualTo("xyz");
+  }
+
+  @Test
+  public void dependencyOverrides_new() {
+    DeserializationContext context =
+        new DeserializationContext(
+            mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+    DeserializationContext overridden =
+        context.withDependencyOverrides(ImmutableClassToInstanceMap.of(Integer.class, 1));
+    assertThat(overridden.getDependency(Integer.class)).isEqualTo(1);
+  }
+
+  @Test
+  public void dependencyOverrides_unchanged() {
+    DeserializationContext context =
+        new DeserializationContext(
+            mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+    DeserializationContext overridden =
+        context.withDependencyOverrides(ImmutableClassToInstanceMap.of(Integer.class, 1));
+    assertThat(overridden.getDependency(String.class)).isEqualTo("abc");
+  }
+
+  @Test
+  public void dependencyOverrides_disallowedOnMemoizingContext() {
+    DeserializationContext context =
+        new DeserializationContext(
+            mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of());
+    DeserializationContext memoizing = context.getMemoizingContext();
+    assertThrows(
+        IllegalStateException.class,
+        () -> memoizing.withDependencyOverrides(ImmutableClassToInstanceMap.of(Integer.class, 1)));
   }
 }