Allow overriding dependencies in `DeserializationContext`, and pass `DeserializationContext` to `OptionsChecksumCache`.
PiperOrigin-RevId: 402634691
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java b/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
index 35f5cd2..e917279 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
@@ -684,7 +684,7 @@
throws IOException {
String checksum = codedIn.readString();
return checkNotNull(
- context.getDependency(OptionsChecksumCache.class).getOptions(checksum),
+ context.getDependency(OptionsChecksumCache.class).getOptions(checksum, context),
"No options instance for %s",
checksum);
}
@@ -699,7 +699,7 @@
/**
* Called during deserialization to transform a checksum into a {@link BuildOptions} instance.
*/
- BuildOptions getOptions(String checksum);
+ BuildOptions getOptions(String checksum, DeserializationContext context);
/**
* Notifies the cache that it may be necessary to deserialize the given options diff's checksum.
@@ -718,7 +718,7 @@
private final ConcurrentMap<String, BuildOptions> map = new ConcurrentHashMap<>();
@Override
- public BuildOptions getOptions(String checksum) {
+ public BuildOptions getOptions(String checksum, DeserializationContext context) {
return map.get(checksum);
}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
index 917fcf7..195d850 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
@@ -15,10 +15,12 @@
package com.google.devtools.build.lib.skyframe.serialization;
import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ClassToInstanceMap;
import com.google.common.collect.ImmutableClassToInstanceMap;
-import com.google.devtools.build.lib.skyframe.serialization.Memoizer.Deserializer;
+import com.google.common.collect.Maps;
import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec.MemoizationStrategy;
import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecRegistry.CodecDescriptor;
import com.google.protobuf.CodedInputStream;
@@ -35,12 +37,12 @@
public class DeserializationContext {
private final ObjectCodecRegistry registry;
private final ImmutableClassToInstanceMap<Object> dependencies;
- private final Memoizer.Deserializer deserializer;
+ @Nullable private final Memoizer.Deserializer deserializer;
private DeserializationContext(
ObjectCodecRegistry registry,
ImmutableClassToInstanceMap<Object> dependencies,
- Deserializer deserializer) {
+ @Nullable Memoizer.Deserializer deserializer) {
this.registry = registry;
this.dependencies = dependencies;
this.deserializer = deserializer;
@@ -106,10 +108,9 @@
* <p>This is a noop when memoization is disabled.
*/
public <T> void registerInitialValue(T initialValue) {
- if (deserializer == null) {
- return;
+ if (deserializer != null) {
+ deserializer.registerInitialValue(initialValue);
}
- deserializer.registerInitialValue(initialValue);
}
public <T> T getDependency(Class<T> type) {
@@ -134,14 +135,33 @@
}
/**
- * Returns a memoizing {@link DeserializationContext}, as getMemoizingContext above. Unlike
- * getMemoizingContext, this method is not idempotent - the returned context will always be fresh.
+ * Returns a new memoizing {@link DeserializationContext}, as {@link #getMemoizingContext}. Unlike
+ * {@link #getMemoizingContext}, this method is not idempotent - the returned context will always
+ * be fresh.
*/
public DeserializationContext getNewMemoizingContext() {
- return new DeserializationContext(this.registry, this.dependencies, new Deserializer());
+ return new DeserializationContext(registry, dependencies, new Memoizer.Deserializer());
}
- public DeserializationContext getNewNonMemoizingContext() {
- return new DeserializationContext(this.registry, this.dependencies, null);
+ /**
+ * Returns a new {@link DeserializationContext} mostly identical to this one, but with a
+ * dependency map composed by applying overrides to this context's dependencies.
+ *
+ * <p>The given {@code dependencyOverrides} may contain keys already present (in which case the
+ * dependency is replaced) or new keys (in which case the dependency is added).
+ *
+ * <p>Must only be called on a base context (no memoization state), since changing dependencies
+ * may change deserialization semantics.
+ */
+ @CheckReturnValue
+ public DeserializationContext withDependencyOverrides(ClassToInstanceMap<?> dependencyOverrides) {
+ checkState(deserializer == null, "Must only be called on base DeserializationContext");
+ return new DeserializationContext(
+ registry,
+ ImmutableClassToInstanceMap.builder()
+ .putAll(Maps.filterKeys(dependencies, k -> !dependencyOverrides.containsKey(k)))
+ .putAll(dependencyOverrides)
+ .build(),
+ /*deserializer=*/ null);
}
}
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java
index 176e548..b9db2aa 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContextTest.java
@@ -15,7 +15,11 @@
package com.google.devtools.build.lib.skyframe.serialization;
import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableClassToInstanceMap;
@@ -26,45 +30,45 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-import org.mockito.Mockito;
/** Tests for {@link DeserializationContext}. */
@RunWith(JUnit4.class)
-public class DeserializationContextTest {
+public final class DeserializationContextTest {
+
@Test
public void nullDeserialize() throws Exception {
- ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
- CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+ ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
+ CodedInputStream codedInputStream = mock(CodedInputStream.class);
when(codedInputStream.readSInt32()).thenReturn(0);
DeserializationContext deserializationContext =
new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
assertThat((Object) deserializationContext.deserialize(codedInputStream)).isNull();
- Mockito.verify(codedInputStream).readSInt32();
- Mockito.verifyNoInteractions(registry);
+ verify(codedInputStream).readSInt32();
+ verifyNoInteractions(registry);
}
@Test
public void constantDeserialize() throws Exception {
- ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+ ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
Object constant = new Object();
when(registry.maybeGetConstantByTag(1)).thenReturn(constant);
- CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+ CodedInputStream codedInputStream = mock(CodedInputStream.class);
when(codedInputStream.readSInt32()).thenReturn(1);
DeserializationContext deserializationContext =
new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
assertThat((Object) deserializationContext.deserialize(codedInputStream))
.isSameInstanceAs(constant);
- Mockito.verify(codedInputStream).readSInt32();
- Mockito.verify(registry).maybeGetConstantByTag(1);
+ verify(codedInputStream).readSInt32();
+ verify(registry).maybeGetConstantByTag(1);
}
@Test
public void descriptorDeserialize() throws Exception {
ObjectCodecRegistry.CodecDescriptor codecDescriptor =
- Mockito.mock(ObjectCodecRegistry.CodecDescriptor.class);
- ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+ mock(ObjectCodecRegistry.CodecDescriptor.class);
+ ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
when(registry.getCodecDescriptorByTag(1)).thenReturn(codecDescriptor);
- CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+ CodedInputStream codedInputStream = mock(CodedInputStream.class);
when(codedInputStream.readSInt32()).thenReturn(1);
DeserializationContext deserializationContext =
new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
@@ -73,63 +77,122 @@
.thenReturn(returnValue);
assertThat((Object) deserializationContext.deserialize(codedInputStream))
.isSameInstanceAs(returnValue);
- Mockito.verify(codedInputStream).readSInt32();
- Mockito.verify(registry).getCodecDescriptorByTag(1);
- Mockito.verify(codecDescriptor).deserialize(deserializationContext, codedInputStream);
+ verify(codedInputStream).readSInt32();
+ verify(registry).getCodecDescriptorByTag(1);
+ verify(codecDescriptor).deserialize(deserializationContext, codedInputStream);
}
@Test
public void memoizingDeserialize_null() throws SerializationException, IOException {
- ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
- CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+ ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
+ CodedInputStream codedInputStream = mock(CodedInputStream.class);
DeserializationContext deserializationContext =
new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
when(codedInputStream.readSInt32()).thenReturn(0);
assertThat((Object) deserializationContext.getMemoizingContext().deserialize(codedInputStream))
.isEqualTo(null);
- Mockito.verify(codedInputStream).readSInt32();
- Mockito.verifyNoInteractions(registry);
+ verify(codedInputStream).readSInt32();
+ verifyNoInteractions(registry);
}
@Test
public void memoizingDeserialize_constant() throws SerializationException, IOException {
Object constant = new Object();
- ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+ ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
when(registry.maybeGetConstantByTag(1)).thenReturn(constant);
- CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+ CodedInputStream codedInputStream = mock(CodedInputStream.class);
DeserializationContext deserializationContext =
new DeserializationContext(registry, ImmutableClassToInstanceMap.of());
when(codedInputStream.readSInt32()).thenReturn(1);
assertThat((Object) deserializationContext.getMemoizingContext().deserialize(codedInputStream))
.isEqualTo(constant);
- Mockito.verify(codedInputStream).readSInt32();
- Mockito.verify(registry).maybeGetConstantByTag(1);
+ verify(codedInputStream).readSInt32();
+ verify(registry).maybeGetConstantByTag(1);
}
@Test
public void memoizingDeserialize_codec() throws SerializationException, IOException {
Object returned = new Object();
@SuppressWarnings("unchecked")
- ObjectCodec<Object> codec = Mockito.mock(ObjectCodec.class);
+ ObjectCodec<Object> codec = mock(ObjectCodec.class);
when(codec.getStrategy()).thenReturn(MemoizationStrategy.MEMOIZE_AFTER);
when(codec.getEncodedClass()).thenAnswer(unused -> Object.class);
when(codec.additionalEncodedClasses()).thenReturn(ImmutableList.of());
ObjectCodecRegistry.CodecDescriptor codecDescriptor =
- Mockito.mock(ObjectCodecRegistry.CodecDescriptor.class);
+ mock(ObjectCodecRegistry.CodecDescriptor.class);
doReturn(codec).when(codecDescriptor).getCodec();
- ObjectCodecRegistry registry = Mockito.mock(ObjectCodecRegistry.class);
+ ObjectCodecRegistry registry = mock(ObjectCodecRegistry.class);
when(registry.getCodecDescriptorByTag(1)).thenReturn(codecDescriptor);
- CodedInputStream codedInputStream = Mockito.mock(CodedInputStream.class);
+ CodedInputStream codedInputStream = mock(CodedInputStream.class);
DeserializationContext deserializationContext =
new DeserializationContext(registry, ImmutableClassToInstanceMap.of())
.getMemoizingContext();
when(codec.deserialize(deserializationContext, codedInputStream)).thenReturn(returned);
when(codedInputStream.readSInt32()).thenReturn(1);
assertThat((Object) deserializationContext.deserialize(codedInputStream)).isEqualTo(returned);
- Mockito.verify(codedInputStream).readSInt32();
- Mockito.verify(registry).maybeGetConstantByTag(1);
- Mockito.verify(registry).getCodecDescriptorByTag(1);
- Mockito.verify(codecDescriptor).getCodec();
- Mockito.verify(codec).deserialize(deserializationContext, codedInputStream);
+ verify(codedInputStream).readSInt32();
+ verify(registry).maybeGetConstantByTag(1);
+ verify(registry).getCodecDescriptorByTag(1);
+ verify(codecDescriptor).getCodec();
+ verify(codec).deserialize(deserializationContext, codedInputStream);
+ }
+
+ @Test
+ public void getDependency() {
+ DeserializationContext context =
+ new DeserializationContext(
+ mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+ assertThat(context.getDependency(String.class)).isEqualTo("abc");
+ }
+
+ @Test
+ public void getDependency_notPresent() {
+ DeserializationContext context =
+ new DeserializationContext(
+ mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of());
+ Exception e =
+ assertThrows(NullPointerException.class, () -> context.getDependency(String.class));
+ assertThat(e).hasMessageThat().contains("Missing dependency of type " + String.class);
+ }
+
+ @Test
+ public void dependencyOverrides_alreadyPresent() {
+ DeserializationContext context =
+ new DeserializationContext(
+ mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+ DeserializationContext overridden =
+ context.withDependencyOverrides(ImmutableClassToInstanceMap.of(String.class, "xyz"));
+ assertThat(overridden.getDependency(String.class)).isEqualTo("xyz");
+ }
+
+ @Test
+ public void dependencyOverrides_new() {
+ DeserializationContext context =
+ new DeserializationContext(
+ mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+ DeserializationContext overridden =
+ context.withDependencyOverrides(ImmutableClassToInstanceMap.of(Integer.class, 1));
+ assertThat(overridden.getDependency(Integer.class)).isEqualTo(1);
+ }
+
+ @Test
+ public void dependencyOverrides_unchanged() {
+ DeserializationContext context =
+ new DeserializationContext(
+ mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of(String.class, "abc"));
+ DeserializationContext overridden =
+ context.withDependencyOverrides(ImmutableClassToInstanceMap.of(Integer.class, 1));
+ assertThat(overridden.getDependency(String.class)).isEqualTo("abc");
+ }
+
+ @Test
+ public void dependencyOverrides_disallowedOnMemoizingContext() {
+ DeserializationContext context =
+ new DeserializationContext(
+ mock(ObjectCodecRegistry.class), ImmutableClassToInstanceMap.of());
+ DeserializationContext memoizing = context.getMemoizingContext();
+ assertThrows(
+ IllegalStateException.class,
+ () -> memoizing.withDependencyOverrides(ImmutableClassToInstanceMap.of(Integer.class, 1)));
}
}