Add functionality to SerializationContext and @AutoCodec to check that a class is allowed to be serialized in the current context. A codec can now register a class as explicitly allowed to be serialized underneath it (via SerializationContext#addExplicitlyAllowedClass), and that class's codec can check that it has been explicitly allowed (via SerializationContext#checkClassExplicitlyAllowed). If a codec performs this check and finds that its class was not explicitly allowed, serialization crashes at runtime. Thus, if PackageCodec is invoked without Package having been explicitly allowed, we will crash, preventing Package from sneaking into a value it shouldn't be in.

This is only supported when the serialization context is memoizing; both new SerializationContext methods fail a precondition check otherwise.

PiperOrigin-RevId: 199317936
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/configuredtargets/RuleConfiguredTarget.java b/src/main/java/com/google/devtools/build/lib/analysis/configuredtargets/RuleConfiguredTarget.java
index 54f3368..f3f03a8 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/configuredtargets/RuleConfiguredTarget.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/configuredtargets/RuleConfiguredTarget.java
@@ -60,7 +60,7 @@
  * analyzed rule. For more information about how analysis works, see {@link
  * com.google.devtools.build.lib.analysis.RuleConfiguredTargetFactory}.
  */
-@AutoCodec
+@AutoCodec(checkClassExplicitlyAllowed = true)
 public final class RuleConfiguredTarget extends AbstractConfiguredTarget {
   private static final String ACTIONS_FIELD_NAME = "actions";
 
diff --git a/src/main/java/com/google/devtools/build/lib/packages/Package.java b/src/main/java/com/google/devtools/build/lib/packages/Package.java
index f0f79d4..1ea31a1 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/Package.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/Package.java
@@ -1645,6 +1645,7 @@
         Package input,
         CodedOutputStream codedOut)
         throws IOException, SerializationException {
+      context.checkClassExplicitlyAllowed(Package.class);
       PackageCodecDependencies codecDeps = context.getDependency(PackageCodecDependencies.class);
       codecDeps.getPackageSerializer().serialize(context, input, codedOut);
     }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java
index 1d9d1a6..d2bc256 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java
@@ -21,6 +21,7 @@
 import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.actions.BasicActionLookupValue;
 import com.google.devtools.build.lib.analysis.ConfiguredTarget;
+import com.google.devtools.build.lib.analysis.configuredtargets.RuleConfiguredTarget;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.collect.nestedset.NestedSet;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.Immutable;
@@ -35,7 +36,8 @@
 /** A non-rule configured target in the context of a Skyframe graph. */
 @Immutable
 @ThreadSafe
-@AutoCodec
+// Reached via OutputFileConfiguredTarget.
+@AutoCodec(explicitlyAllowClass = RuleConfiguredTarget.class)
 @VisibleForTesting
 public final class NonRuleConfiguredTargetValue extends BasicActionLookupValue
     implements ConfiguredTargetValue {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java
index ba19623..babdce6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java
@@ -30,7 +30,7 @@
 import java.util.List;
 
 /** A Skyframe value representing a package. */
-@AutoCodec
+@AutoCodec(explicitlyAllowClass = Package.class)
 @Immutable
 @ThreadSafe
 public class PackageValue implements NotComparableSkyValue {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/PostConfiguredTargetValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/PostConfiguredTargetValue.java
index 827bb8a..4e80bd6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/PostConfiguredTargetValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/PostConfiguredTargetValue.java
@@ -17,6 +17,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Interner;
 import com.google.devtools.build.lib.analysis.ConfiguredTarget;
+import com.google.devtools.build.lib.analysis.configuredtargets.RuleConfiguredTarget;
 import com.google.devtools.build.lib.concurrent.BlazeInterners;
 import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
 import com.google.devtools.build.skyframe.AbstractSkyKey;
@@ -28,11 +29,12 @@
  * A post-processed ConfiguredTarget which is known to be transitively error-free from action
  * conflict issues.
  */
+@AutoCodec(explicitlyAllowClass = RuleConfiguredTarget.class)
 class PostConfiguredTargetValue implements SkyValue {
 
   private final ConfiguredTarget ct;
 
-  public PostConfiguredTargetValue(ConfiguredTarget ct) {
+  PostConfiguredTargetValue(ConfiguredTarget ct) {
     this.ct = Preconditions.checkNotNull(ct);
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java
index 359e638..5edb986 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java
@@ -32,7 +32,7 @@
 /** A configured target in the context of a Skyframe graph. */
 @Immutable
 @ThreadSafe
-@AutoCodec
+@AutoCodec(explicitlyAllowClass = RuleConfiguredTarget.class)
 @VisibleForTesting
 public final class RuleConfiguredTargetValue extends ActionLookupValue
     implements ConfiguredTargetValue {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/TargetCompletionValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/TargetCompletionValue.java
index dbbb6a7..c0d64c6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/TargetCompletionValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/TargetCompletionValue.java
@@ -24,9 +24,7 @@
 import java.util.Collection;
 import java.util.Set;
 
-/**
- * The value of a TargetCompletion. Currently this just stores a ConfiguredTarget.
- */
+/** The value of a TargetCompletion. Just a sentinel. */
 public class TargetCompletionValue implements SkyValue {
   @AutoCodec static final TargetCompletionValue INSTANCE = new TargetCompletionValue();
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
index dedbf2d..787b6c8 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
@@ -17,6 +17,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
@@ -27,7 +28,9 @@
 import com.google.protobuf.CodedOutputStream;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import javax.annotation.CheckReturnValue;
 import javax.annotation.Nullable;
 
@@ -42,6 +45,7 @@
   private final ObjectCodecRegistry registry;
   private final ImmutableMap<Class<?>, Object> dependencies;
   @Nullable private final Memoizer.Serializer serializer;
+  private final Set<Class<?>> explicitlyAllowedClasses;
   /** Initialized lazily. */
   @Nullable private List<ListenableFuture<Void>> futuresToBlockWritingOn;
 
@@ -56,6 +60,7 @@
     this.dependencies = dependencies;
     this.serializer = serializer;
     this.allowFuturesToBlockWritingOn = allowFuturesToBlockWritingOn;
+    explicitlyAllowedClasses = serializer != null ? new HashSet<>() : ImmutableSet.of();
   }
 
   @VisibleForTesting
@@ -175,6 +180,44 @@
         : null;
   }
 
+  /**
+   * Asserts during serialization that the encoded class of this codec has been explicitly
+   * whitelisted for serialization (using {@link #addExplicitlyAllowedClass}). Codecs for objects
+   * that are expensive to serialize and that should only be encountered in a limited number of
+   * types of {@link com.google.devtools.build.skyframe.SkyValue}s should call this method to check
+   * that the object is being serialized as part of an expected {@link
+   * com.google.devtools.build.skyframe.SkyValue}, like {@link
+   * com.google.devtools.build.lib.packages.Package} inside {@link
+   * com.google.devtools.build.lib.skyframe.PackageValue}.
+   */
+  public void checkClassExplicitlyAllowed(Class<?> allowedClass) {
+    Preconditions.checkNotNull(
+        serializer, "Cannot check explicitly allowed class %s without memoization", allowedClass);
+    Preconditions.checkState(explicitlyAllowedClasses.contains(allowedClass), allowedClass);
+  }
+
+  /**
+   * Adds an explicitly allowed class for this serialization context, which must be a memoizing
+   * context. Must be called by any codec that transitively serializes an object whose codec calls
+   * {@link #checkClassExplicitlyAllowed}.
+   *
+   * <p>Normally called by codecs for {@link com.google.devtools.build.skyframe.SkyValue} subclasses
+   * that know they may encounter an object that is expensive to serialize, like {@link
+   * com.google.devtools.build.lib.skyframe.PackageValue} and {@link
+   * com.google.devtools.build.lib.packages.Package} or {@link
+   * com.google.devtools.build.lib.skyframe.ConfiguredTargetValue} and {@link
+   * com.google.devtools.build.lib.analysis.configuredtargets.RuleConfiguredTarget}.
+   *
+   * <p>In case of an unexpected failure from {@link #checkClassExplicitlyAllowed}, it should first
+   * be determined if the inclusion of the expensive object is legitimate, before it is whitelisted
+   * using this method.
+   */
+  public void addExplicitlyAllowedClass(Class<?> allowedClass) {
+    Preconditions.checkNotNull(
+        serializer, "Cannot add explicitly allowed class %s without memoization", allowedClass);
+    explicitlyAllowedClasses.add(allowedClass);
+  }
+
   private boolean writeNullOrConstant(@Nullable Object object, CodedOutputStream codedOut)
       throws IOException {
     if (object == null) {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java
index 302f1e9..840e7df 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java
@@ -90,6 +90,18 @@
   Strategy strategy() default Strategy.INSTANTIATOR;
 
   /**
+   * Checks whether or not this class is allowed to be serialized. See {@link
+   * com.google.devtools.build.lib.skyframe.serialization.SerializationContext#checkClassExplicitlyAllowed}.
+   */
+  boolean checkClassExplicitlyAllowed() default false;
+
+  /**
+   * Adds an explicitly allowed class for this serialization session. See {@link
+   * com.google.devtools.build.lib.skyframe.serialization.SerializationContext#addExplicitlyAllowedClass}.
+   */
+  Class<?>[] explicitlyAllowClass() default {};
+
+  /**
    * Signals that the annotated element is only visible for use by serialization. It should not be
    * used by other callers.
    *
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java
index 395519d..cb38568 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java
@@ -105,13 +105,13 @@
         TypeSpec.Builder codecClassBuilder;
         switch (annotation.strategy()) {
           case INSTANTIATOR:
-            codecClassBuilder = buildClassWithInstantiatorStrategy(encodedType);
+            codecClassBuilder = buildClassWithInstantiatorStrategy(encodedType, annotation);
             break;
           case PUBLIC_FIELDS:
-            codecClassBuilder = buildClassWithPublicFieldsStrategy(encodedType);
+            codecClassBuilder = buildClassWithPublicFieldsStrategy(encodedType, annotation);
             break;
           case AUTO_VALUE_BUILDER:
-            codecClassBuilder = buildClassWithAutoValueBuilderStrategy(encodedType);
+            codecClassBuilder = buildClassWithAutoValueBuilderStrategy(encodedType, annotation);
             break;
           default:
             throw new IllegalArgumentException("Unknown strategy: " + annotation.strategy());
@@ -171,7 +171,8 @@
         .build();
   }
 
-  private TypeSpec.Builder buildClassWithInstantiatorStrategy(TypeElement encodedType) {
+  private TypeSpec.Builder buildClassWithInstantiatorStrategy(
+      TypeElement encodedType, AutoCodec annotation) {
     ExecutableElement constructor = selectInstantiator(encodedType);
     List<? extends VariableElement> fields = constructor.getParameters();
 
@@ -180,10 +181,11 @@
 
     if (encodedType.getAnnotation(AutoValue.class) == null) {
       initializeUnsafeOffsets(codecClassBuilder, encodedType, fields);
-      codecClassBuilder.addMethod(buildSerializeMethodWithInstantiator(encodedType, fields));
+      codecClassBuilder.addMethod(
+          buildSerializeMethodWithInstantiator(encodedType, fields, annotation));
     } else {
       codecClassBuilder.addMethod(
-          buildSerializeMethodWithInstantiatorForAutoValue(encodedType, fields));
+          buildSerializeMethodWithInstantiatorForAutoValue(encodedType, fields, annotation));
     }
 
     MethodSpec.Builder deserializeBuilder =
@@ -195,7 +197,8 @@
     return codecClassBuilder;
   }
 
-  private TypeSpec.Builder buildClassWithAutoValueBuilderStrategy(TypeElement encodedType) {
+  private TypeSpec.Builder buildClassWithAutoValueBuilderStrategy(
+      TypeElement encodedType, AutoCodec annotation) {
     TypeElement builderType = findBuilderType(encodedType);
     List<ExecutableElement> getters = findGettersFromType(encodedType, builderType);
     ExecutableElement builderCreationMethod = findBuilderCreationMethod(encodedType, builderType);
@@ -203,7 +206,7 @@
     TypeSpec.Builder codecClassBuilder =
         AutoCodecUtil.initializeCodecClassBuilder(encodedType, env);
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, annotation, env);
     for (ExecutableElement getter : getters) {
       marshallers.writeSerializationCode(
           new Marshaller.Context(
@@ -514,9 +517,9 @@
   }
 
   private MethodSpec buildSerializeMethodWithInstantiator(
-      TypeElement encodedType, List<? extends VariableElement> fields) {
+      TypeElement encodedType, List<? extends VariableElement> fields, AutoCodec annotation) {
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, annotation, env);
     for (VariableElement parameter : fields) {
       Optional<FieldValueAndClass> hasField =
           getFieldByNameRecursive(encodedType, parameter.getSimpleName().toString());
@@ -619,16 +622,17 @@
   }
 
   private MethodSpec buildSerializeMethodWithInstantiatorForAutoValue(
-      TypeElement encodedType, List<? extends VariableElement> fields) {
+      TypeElement encodedType, List<? extends VariableElement> fields, AutoCodec annotation) {
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, annotation, env);
     for (VariableElement parameter : fields) {
       addSerializeParameterWithGetter(encodedType, parameter, serializeBuilder);
     }
     return serializeBuilder.build();
   }
 
-  private TypeSpec.Builder buildClassWithPublicFieldsStrategy(TypeElement encodedType) {
+  private TypeSpec.Builder buildClassWithPublicFieldsStrategy(
+      TypeElement encodedType, AutoCodec annotation) {
     TypeSpec.Builder codecClassBuilder =
         AutoCodecUtil.initializeCodecClassBuilder(encodedType, env);
     ImmutableList<? extends VariableElement> publicFields =
@@ -636,7 +640,8 @@
             .stream()
             .filter(this::isPublicField)
             .collect(toImmutableList());
-    codecClassBuilder.addMethod(buildSerializeMethodWithPublicFields(encodedType, publicFields));
+    codecClassBuilder.addMethod(
+        buildSerializeMethodWithPublicFields(encodedType, publicFields, annotation));
     MethodSpec.Builder deserializeBuilder =
         AutoCodecUtil.initializeDeserializeMethodBuilder(encodedType, env);
     buildDeserializeBody(deserializeBuilder, publicFields);
@@ -654,9 +659,9 @@
   }
 
   private MethodSpec buildSerializeMethodWithPublicFields(
-      TypeElement encodedType, List<? extends VariableElement> fields) {
+      TypeElement encodedType, List<? extends VariableElement> fields, AutoCodec annotation) {
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, annotation, env);
     for (VariableElement parameter : fields) {
       String paramAccessor = "input." + parameter.getSimpleName();
       marshallers.writeSerializationCode(
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java
index 34925ee..f546dd8 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java
@@ -27,11 +27,14 @@
 import com.squareup.javapoet.TypeName;
 import com.squareup.javapoet.TypeSpec;
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
 import java.util.stream.Collectors;
 import javax.annotation.processing.ProcessingEnvironment;
 import javax.lang.model.element.Element;
 import javax.lang.model.element.Modifier;
 import javax.lang.model.element.TypeElement;
+import javax.lang.model.type.MirroredTypesException;
 import javax.lang.model.type.TypeMirror;
 
 /** Static utilities for AutoCodec processors. */
@@ -71,7 +74,7 @@
    * @param encodedType type being serialized
    */
   static MethodSpec.Builder initializeSerializeMethodBuilder(
-      TypeElement encodedType, ProcessingEnvironment env) {
+      TypeElement encodedType, AutoCodec annotation, ProcessingEnvironment env) {
     MethodSpec.Builder builder =
         MethodSpec.methodBuilder("serialize")
             .addModifiers(Modifier.PUBLIC)
@@ -82,6 +85,21 @@
             .addParameter(SerializationContext.class, "context")
             .addParameter(TypeName.get(env.getTypeUtils().erasure(encodedType.asType())), "input")
             .addParameter(CodedOutputStream.class, "codedOut");
+    if (annotation.checkClassExplicitlyAllowed()) {
+      builder.addStatement("context.checkClassExplicitlyAllowed(getEncodedClass())");
+    }
+    List<? extends TypeMirror> explicitlyAllowedClasses;
+    try {
+      explicitlyAllowedClasses =
+          Arrays.stream(annotation.explicitlyAllowClass())
+              .map((clazz) -> getType(clazz, env))
+              .collect(Collectors.toList());
+    } catch (MirroredTypesException e) {
+      explicitlyAllowedClasses = e.getTypeMirrors();
+    }
+    for (TypeMirror explicitlyAllowedClass : explicitlyAllowedClasses) {
+      builder.addStatement("context.addExplicitlyAllowedClass($T.class)", explicitlyAllowedClass);
+    }
     return builder;
   }