Moves the decision to enable memoization from codecs to the top-level invocation.

Also makes calling registerInitialValue a no-op (instead of a NullPointerException) when memoization is disabled.

PiperOrigin-RevId: 191338253
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/AspectValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/AspectValue.java
index 2c88115..d58f350 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/AspectValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/AspectValue.java
@@ -44,7 +44,7 @@
 import javax.annotation.Nullable;
 
 /** An aspect in the context of the Skyframe graph. */
-@AutoCodec(memoization = AutoCodec.Memoization.START_MEMOIZING)
+@AutoCodec
 public final class AspectValue extends BasicActionLookupValue {
 
   /**
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java
index f15f7d0..1d9d1a6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/NonRuleConfiguredTargetValue.java
@@ -35,7 +35,7 @@
 /** A non-rule configured target in the context of a Skyframe graph. */
 @Immutable
 @ThreadSafe
-@AutoCodec(memoization = AutoCodec.Memoization.START_MEMOIZING)
+@AutoCodec
 @VisibleForTesting
 public final class NonRuleConfiguredTargetValue extends BasicActionLookupValue
     implements ConfiguredTargetValue {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java
index 565e4f1..ba19623 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/PackageValue.java
@@ -30,7 +30,7 @@
 import java.util.List;
 
 /** A Skyframe value representing a package. */
-@AutoCodec(memoization = AutoCodec.Memoization.START_MEMOIZING)
+@AutoCodec
 @Immutable
 @ThreadSafe
 public class PackageValue implements NotComparableSkyValue {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java
index b48f247..359e638 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/RuleConfiguredTargetValue.java
@@ -26,14 +26,13 @@
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
 import com.google.devtools.build.lib.packages.Package;
 import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
-import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec.Memoization;
 import java.util.ArrayList;
 import javax.annotation.Nullable;
 
 /** A configured target in the context of a Skyframe graph. */
 @Immutable
 @ThreadSafe
-@AutoCodec(memoization = Memoization.START_MEMOIZING)
+@AutoCodec
 @VisibleForTesting
 public final class RuleConfiguredTargetValue extends ActionLookupValue
     implements ConfiguredTargetValue {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
index 86aea04..e9bacf3 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
@@ -19,7 +19,6 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.devtools.build.lib.skyframe.serialization.Memoizer.Deserializer;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecRegistry.CodecDescriptor;
-import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecs.MemoizationPermission;
 import com.google.protobuf.CodedInputStream;
 import java.io.IOException;
 import javax.annotation.CheckReturnValue;
@@ -33,24 +32,21 @@
 public class DeserializationContext {
   private final ObjectCodecRegistry registry;
   private final ImmutableMap<Class<?>, Object> dependencies;
-  private final MemoizationPermission memoizationPermission;
   private final Memoizer.Deserializer deserializer;
 
   private DeserializationContext(
       ObjectCodecRegistry registry,
       ImmutableMap<Class<?>, Object> dependencies,
-      MemoizationPermission memoizationPermission,
       Deserializer deserializer) {
     this.registry = registry;
     this.dependencies = dependencies;
-    this.memoizationPermission = memoizationPermission;
     this.deserializer = deserializer;
   }
 
   @VisibleForTesting
   public DeserializationContext(
       ObjectCodecRegistry registry, ImmutableMap<Class<?>, Object> dependencies) {
-    this(registry, dependencies, MemoizationPermission.ALLOWED, /*deserializer=*/ null);
+    this(registry, dependencies, /*deserializer=*/ null);
   }
 
   @VisibleForTesting
@@ -58,42 +54,6 @@
     this(AutoRegistry.get(), dependencies);
   }
 
-  DeserializationContext disableMemoization() {
-    Preconditions.checkState(
-        memoizationPermission == MemoizationPermission.ALLOWED, "memoization already disabled");
-    Preconditions.checkState(deserializer == null, "deserializer already present");
-    return new DeserializationContext(
-        registry, dependencies, MemoizationPermission.DISABLED, deserializer);
-  }
-
-  /**
-   * Returns a {@link DeserializationContext} that will memoize values it encounters (using
-   * reference equality), the inverse of the memoization performed by a {@link SerializationContext}
-   * returned by {@link SerializationContext#getMemoizingContext}. The context returned here should
-   * be used instead of the original: memoization may only occur when using the returned context.
-   *
-   * <p>This method is idempotent: calling it on an already memoizing context will return the same
-   * context.
-   */
-  @CheckReturnValue
-  public DeserializationContext getMemoizingContext() {
-    Preconditions.checkState(
-        memoizationPermission == MemoizationPermission.ALLOWED, "memoization disabled");
-    if (deserializer != null) {
-      return this;
-    }
-    return new DeserializationContext(
-        this.registry, this.dependencies, memoizationPermission, new Deserializer());
-  }
-
-  /**
-   * Register an initial value for the currently deserializing value, for use by child objects that
-   * may have references to it. Only for use during memoizing deserialization.
-   */
-  public <T> void registerInitialValue(T initialValue) {
-    deserializer.registerInitialValue(initialValue);
-  }
-
   // TODO(shahan): consider making codedIn a member of this class.
   @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"})
   public <T> T deserialize(CodedInputStream codedIn) throws IOException, SerializationException {
@@ -113,9 +73,42 @@
     }
   }
 
+  /**
+   * Register an initial value for the currently deserializing value, for use by child objects that
+   * may have references to it.
+   *
+   * <p>This is a no-op when memoization is disabled.
+   */
+  public <T> void registerInitialValue(T initialValue) {
+    if (deserializer == null) {
+      return;
+    }
+    deserializer.registerInitialValue(initialValue);
+  }
+
   @SuppressWarnings("unchecked")
   public <T> T getDependency(Class<T> type) {
     Preconditions.checkNotNull(type);
     return (T) dependencies.get(type);
   }
+
+  /**
+   * Returns a {@link DeserializationContext} that will memoize values it encounters (using
+   * reference equality), the inverse of the memoization performed by a {@link SerializationContext}
+   * returned by {@link SerializationContext#getMemoizingContext}. The context returned here should
+   * be used instead of the original: memoization may only occur when using the returned context.
+   *
+   * <p>This method is idempotent: calling it on an already memoizing context will return the same
+   * context.
+   *
+   * <p><em>This is public for testing and {@link
+   * com.google.devtools.build.lib.packages.PackageSerializer} only.</em>
+   */
+  @CheckReturnValue
+  public DeserializationContext getMemoizingContext() {
+    if (deserializer != null) {
+      return this;
+    }
+    return new DeserializationContext(this.registry, this.dependencies, new Deserializer());
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
index 4128092..4ec874b 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
@@ -38,18 +38,11 @@
   private final Class<?> type;
   private final Constructor<?> constructor;
   private final ImmutableSortedMap<Field, Long> offsets;
-  private final ObjectCodec.MemoizationStrategy strategy;
 
   public DynamicCodec(Class<?> type) throws ReflectiveOperationException {
-    this(type, ObjectCodec.MemoizationStrategy.MEMOIZE_BEFORE);
-  }
-
-  public DynamicCodec(Class<?> type, ObjectCodec.MemoizationStrategy strategy)
-      throws ReflectiveOperationException {
     this.type = type;
     this.constructor = getConstructor(type);
     this.offsets = getOffsets(type);
-    this.strategy = strategy;
   }
 
   @Override
@@ -59,7 +52,7 @@
 
   @Override
   public MemoizationStrategy getStrategy() {
-    return strategy;
+    return ObjectCodec.MemoizationStrategy.MEMOIZE_BEFORE;
   }
 
   @Override
@@ -141,9 +134,7 @@
     } catch (ReflectiveOperationException e) {
       throw new SerializationException("Could not instantiate object of type: " + type, e);
     }
-    if (strategy.equals(ObjectCodec.MemoizationStrategy.MEMOIZE_BEFORE)) {
-      context.registerInitialValue(instance);
-    }
+    context.registerInitialValue(instance);
     for (Map.Entry<Field, Long> entry : offsets.entrySet()) {
       deserializeField(context, codedIn, instance, entry.getKey().getType(), entry.getValue());
     }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ObjectCodecs.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ObjectCodecs.java
index 70e4505..3674901 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ObjectCodecs.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ObjectCodecs.java
@@ -14,7 +14,6 @@
 
 package com.google.devtools.build.lib.skyframe.serialization;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableMap;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.CodedInputStream;
@@ -26,8 +25,6 @@
  * serving as a layer between the streaming-oriented {@link ObjectCodec} interface and users.
  */
 public class ObjectCodecs {
-  // TODO(shahan): when per-invocation state is needed, for example, memoization, these may
-  // need to be constructed each time.
   private final SerializationContext serializationContext;
   private final DeserializationContext deserializationContext;
 
@@ -42,66 +39,81 @@
   }
 
   public ByteString serialize(Object subject) throws SerializationException {
-    ByteString.Output resultOut = ByteString.newOutput();
-    CodedOutputStream codedOut = CodedOutputStream.newInstance(resultOut);
-    try {
-      serializationContext.serialize(subject, codedOut);
-      codedOut.flush();
-      return resultOut.toByteString();
-    } catch (IOException e) {
-      throw new SerializationException("Failed to serialize " + subject, e);
-    }
+    return serializeToByteString(subject, this::serialize);
   }
 
-  public void serialize(
-      Object subject, CodedOutputStream codedOut, MemoizationPermission memoizationPermission)
+  public void serialize(Object subject, CodedOutputStream codedOut) throws SerializationException {
+    serializeImpl(subject, codedOut, /*memoize=*/ false);
+  }
+
+  public ByteString serializeMemoized(Object subject) throws SerializationException {
+    return serializeToByteString(subject, this::serializeMemoized);
+  }
+
+  public void serializeMemoized(Object subject, CodedOutputStream codedOut)
       throws SerializationException {
-    SerializationContext context = serializationContext;
-    if (memoizationPermission == MemoizationPermission.DISABLED) {
-      context = context.disableMemoization();
-    }
-    try {
-      context.serialize(subject, codedOut);
-    } catch (IOException e) {
-      throw new SerializationException("Failed to serialize " + subject, e);
-    }
-  }
-
-  /**
-   * Controls whether memoization can occur for serialization/deserialization. Should be allowed
-   * unless bit-equivalence is needed.
-   */
-  public enum MemoizationPermission {
-    ALLOWED,
-    DISABLED
+    serializeImpl(subject, codedOut, /*memoize=*/ true);
   }
 
   public Object deserialize(ByteString data) throws SerializationException {
-    return deserialize(data.newCodedInput(), MemoizationPermission.ALLOWED);
+    return deserialize(data.newCodedInput());
   }
 
-  public Object deserialize(CodedInputStream codedIn, MemoizationPermission memoizationPermission)
+  public Object deserialize(CodedInputStream codedIn) throws SerializationException {
+    return deserializeImpl(codedIn, /*memoize=*/ false);
+  }
+
+  public Object deserializeMemoized(ByteString data) throws SerializationException {
+    return deserializeMemoized(data.newCodedInput());
+  }
+
+  public Object deserializeMemoized(CodedInputStream codedIn) throws SerializationException {
+    return deserializeImpl(codedIn, /*memoize=*/ true);
+  }
+
+  private void serializeImpl(Object subject, CodedOutputStream codedOut, boolean memoize)
       throws SerializationException {
-    // Allow access to buffer without copying (although this means buffer may be pinned in memory).
-    codedIn.enableAliasing(true);
-    DeserializationContext context = deserializationContext;
-    if (memoizationPermission == MemoizationPermission.DISABLED) {
-      context = context.disableMemoization();
-    }
     try {
-      return context.deserialize(codedIn);
+      if (memoize) {
+        serializationContext.getMemoizingContext().serialize(subject, codedOut);
+      } else {
+        serializationContext.serialize(subject, codedOut);
+      }
+    } catch (IOException e) {
+      throw new SerializationException("Failed to serialize " + subject, e);
+    }
+  }
+
+  private Object deserializeImpl(CodedInputStream codedIn, boolean memoize)
+      throws SerializationException {
+    // Allows access to buffer without copying (although this means buffer may be pinned in memory).
+    codedIn.enableAliasing(true);
+    try {
+      if (memoize) {
+        return deserializationContext.getMemoizingContext().deserialize(codedIn);
+      } else {
+        return deserializationContext.deserialize(codedIn);
+      }
     } catch (IOException e) {
       throw new SerializationException("Failed to deserialize data", e);
     }
   }
 
-  @VisibleForTesting
-  public SerializationContext getSerializationContextForTesting() {
-    return serializationContext;
+  @FunctionalInterface
+  private static interface SerializeCall {
+    void serialize(Object subject, CodedOutputStream codedOut) throws SerializationException;
   }
 
-  @VisibleForTesting
-  public DeserializationContext getDeserializationContextForTesting() {
-    return deserializationContext;
+  private static ByteString serializeToByteString(Object subject, SerializeCall wrapped)
+      throws SerializationException {
+    ByteString.Output resultOut = ByteString.newOutput();
+    CodedOutputStream codedOut = CodedOutputStream.newInstance(resultOut);
+    wrapped.serialize(subject, codedOut);
+    try {
+      codedOut.flush();
+      return resultOut.toByteString();
+    } catch (IOException e) {
+      throw new SerializationException("Failed to serialize " + subject, e);
+    }
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
index 6e0d681..e06d6db 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
@@ -18,7 +18,6 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import com.google.devtools.build.lib.skyframe.serialization.Memoizer.Serializer;
-import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecs.MemoizationPermission;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationException.NoCodecException;
 import com.google.protobuf.CodedOutputStream;
 import java.io.IOException;
@@ -34,26 +33,21 @@
 public class SerializationContext {
   private final ObjectCodecRegistry registry;
   private final ImmutableMap<Class<?>, Object> dependencies;
-  private final MemoizationPermission memoizationPermission;
   @Nullable private final Memoizer.Serializer serializer;
 
   private SerializationContext(
       ObjectCodecRegistry registry,
       ImmutableMap<Class<?>, Object> dependencies,
-      MemoizationPermission memoizationPermission,
       @Nullable Serializer serializer) {
     this.registry = registry;
     this.dependencies = dependencies;
     this.serializer = serializer;
-    this.memoizationPermission = memoizationPermission;
-    Preconditions.checkState(
-        serializer == null || memoizationPermission == MemoizationPermission.ALLOWED);
   }
 
   @VisibleForTesting
   public SerializationContext(
       ObjectCodecRegistry registry, ImmutableMap<Class<?>, Object> dependencies) {
-    this(registry, dependencies, MemoizationPermission.ALLOWED, /*serializer=*/ null);
+    this(registry, dependencies, /*serializer=*/ null);
   }
 
   @VisibleForTesting
@@ -61,12 +55,26 @@
     this(AutoRegistry.get(), dependencies);
   }
 
-  SerializationContext disableMemoization() {
-    Preconditions.checkState(
-        memoizationPermission == MemoizationPermission.ALLOWED, "memoization already disabled");
-    Preconditions.checkState(serializer == null, "serializer already present");
-    return new SerializationContext(
-        registry, dependencies, MemoizationPermission.DISABLED, serializer);
+  // TODO(shahan): consider making codedOut a member of this class.
+  public void serialize(Object object, CodedOutputStream codedOut)
+      throws IOException, SerializationException {
+    ObjectCodecRegistry.CodecDescriptor descriptor =
+        recordAndGetDescriptorIfNotConstantOrNull(object, codedOut);
+    if (descriptor != null) {
+      if (serializer == null) {
+        descriptor.serialize(this, object, codedOut);
+      } else {
+        @SuppressWarnings("unchecked")
+        ObjectCodec<Object> castCodec = (ObjectCodec<Object>) descriptor.getCodec();
+        serializer.serialize(this, object, castCodec, codedOut);
+      }
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  public <T> T getDependency(Class<T> type) {
+    Preconditions.checkNotNull(type);
+    return (T) dependencies.get(type);
   }
 
   /**
@@ -78,16 +86,16 @@
    *
    * <p>This method is idempotent: calling it on an already memoizing context will return the same
    * context.
+   *
+   * <p><em>This is public for testing and {@link
+   * com.google.devtools.build.lib.packages.PackageSerializer} only.</em>
    */
   @CheckReturnValue
   public SerializationContext getMemoizingContext() {
-    Preconditions.checkState(
-        memoizationPermission == MemoizationPermission.ALLOWED, "memoization disabled");
     if (serializer != null) {
       return this;
     }
-    return new SerializationContext(
-        this.registry, this.dependencies, memoizationPermission, new Memoizer.Serializer());
+    return new SerializationContext(this.registry, this.dependencies, new Memoizer.Serializer());
   }
 
   private boolean writeNullOrConstant(@Nullable Object object, CodedOutputStream codedOut)
@@ -114,26 +122,4 @@
     codedOut.writeSInt32NoTag(descriptor.getTag());
     return descriptor;
   }
-
-  // TODO(shahan): consider making codedOut a member of this class.
-  public void serialize(Object object, CodedOutputStream codedOut)
-      throws IOException, SerializationException {
-    ObjectCodecRegistry.CodecDescriptor descriptor =
-        recordAndGetDescriptorIfNotConstantOrNull(object, codedOut);
-    if (descriptor != null) {
-      if (serializer == null) {
-        descriptor.serialize(this, object, codedOut);
-      } else {
-        @SuppressWarnings("unchecked")
-        ObjectCodec<Object> castCodec = (ObjectCodec<Object>) descriptor.getCodec();
-        serializer.serialize(this, object, castCodec, codedOut);
-      }
-    }
-    }
-
-  @SuppressWarnings("unchecked")
-  public <T> T getDependency(Class<T> type) {
-    Preconditions.checkNotNull(type);
-    return (T) dependencies.get(type);
-  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java
index 6901b92..302f1e9 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodec.java
@@ -35,7 +35,7 @@
  *
  * <p>If applied to a field (which must be static and final), the field is stored as a "constant"
  * allowing for trivial serialization of it as an integer tag (see {@code CodecScanner} and
- * {@code ObjectCodecRegistery}). In order to do that, a trivial associated "RegisteredSingleton"
+ * {@code ObjectCodecRegistry}). In order to do that, a trivial associated "RegisteredSingleton"
  * class is generated.
  */
 @Target({ElementType.TYPE, ElementType.FIELD})
@@ -89,20 +89,6 @@
 
   Strategy strategy() default Strategy.INSTANTIATOR;
 
-  /** Whether to start memoizing values below this codec. */
-  enum Memoization {
-    /** Do not start memoization, but also do not disable memoization if it is already happening. */
-    UNCHANGED,
-    /**
-     * Start memoizing. Memoization is assumed to always need a Skylark "Mutability" object. If this
-     * package does not have access to the {@link com.google.devtools.build.lib.syntax.Mutability}
-     * class, memoization cannot be started here.
-     */
-    START_MEMOIZING
-  }
-
-  Memoization memoization() default Memoization.UNCHANGED;
-
   /**
    * Signals that the annotated element is only visible for use by serialization. It should not be
    * used by other callers.
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java
index fb945a9..395519d 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecProcessor.java
@@ -25,7 +25,6 @@
 import com.google.devtools.build.lib.skyframe.serialization.CodecScanningConstants;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
-import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec.Memoization;
 import com.google.devtools.build.lib.skyframe.serialization.autocodec.SerializationCodeGenerator.Marshaller;
 import com.squareup.javapoet.ClassName;
 import com.squareup.javapoet.FieldSpec;
@@ -104,16 +103,15 @@
       if (element instanceof TypeElement) {
         TypeElement encodedType = (TypeElement) element;
         TypeSpec.Builder codecClassBuilder;
-        boolean startMemoizing = annotation.memoization() == Memoization.START_MEMOIZING;
         switch (annotation.strategy()) {
           case INSTANTIATOR:
-            codecClassBuilder = buildClassWithInstantiatorStrategy(encodedType, startMemoizing);
+            codecClassBuilder = buildClassWithInstantiatorStrategy(encodedType);
             break;
           case PUBLIC_FIELDS:
-            codecClassBuilder = buildClassWithPublicFieldsStrategy(encodedType, startMemoizing);
+            codecClassBuilder = buildClassWithPublicFieldsStrategy(encodedType);
             break;
           case AUTO_VALUE_BUILDER:
-            codecClassBuilder = buildClassWithAutoValueBuilderStrategy(encodedType, startMemoizing);
+            codecClassBuilder = buildClassWithAutoValueBuilderStrategy(encodedType);
             break;
           default:
             throw new IllegalArgumentException("Unknown strategy: " + annotation.strategy());
@@ -173,8 +171,7 @@
         .build();
   }
 
-  private TypeSpec.Builder buildClassWithInstantiatorStrategy(
-      TypeElement encodedType, boolean startMemoizing) {
+  private TypeSpec.Builder buildClassWithInstantiatorStrategy(TypeElement encodedType) {
     ExecutableElement constructor = selectInstantiator(encodedType);
     List<? extends VariableElement> fields = constructor.getParameters();
 
@@ -183,15 +180,14 @@
 
     if (encodedType.getAnnotation(AutoValue.class) == null) {
       initializeUnsafeOffsets(codecClassBuilder, encodedType, fields);
-      codecClassBuilder.addMethod(
-          buildSerializeMethodWithInstantiator(encodedType, fields, startMemoizing));
+      codecClassBuilder.addMethod(buildSerializeMethodWithInstantiator(encodedType, fields));
     } else {
       codecClassBuilder.addMethod(
-          buildSerializeMethodWithInstantiatorForAutoValue(encodedType, fields, startMemoizing));
+          buildSerializeMethodWithInstantiatorForAutoValue(encodedType, fields));
     }
 
     MethodSpec.Builder deserializeBuilder =
-        AutoCodecUtil.initializeDeserializeMethodBuilder(encodedType, startMemoizing, env);
+        AutoCodecUtil.initializeDeserializeMethodBuilder(encodedType, env);
     buildDeserializeBody(deserializeBuilder, fields);
     addReturnNew(deserializeBuilder, encodedType, constructor, /*builderVar=*/ null, env);
     codecClassBuilder.addMethod(deserializeBuilder.build());
@@ -199,8 +195,7 @@
     return codecClassBuilder;
   }
 
-  private TypeSpec.Builder buildClassWithAutoValueBuilderStrategy(
-      TypeElement encodedType, boolean startMemoizing) {
+  private TypeSpec.Builder buildClassWithAutoValueBuilderStrategy(TypeElement encodedType) {
     TypeElement builderType = findBuilderType(encodedType);
     List<ExecutableElement> getters = findGettersFromType(encodedType, builderType);
     ExecutableElement builderCreationMethod = findBuilderCreationMethod(encodedType, builderType);
@@ -208,7 +203,7 @@
     TypeSpec.Builder codecClassBuilder =
         AutoCodecUtil.initializeCodecClassBuilder(encodedType, env);
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, startMemoizing, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
     for (ExecutableElement getter : getters) {
       marshallers.writeSerializationCode(
           new Marshaller.Context(
@@ -218,7 +213,7 @@
     }
     codecClassBuilder.addMethod(serializeBuilder.build());
     MethodSpec.Builder deserializeBuilder =
-        AutoCodecUtil.initializeDeserializeMethodBuilder(encodedType, startMemoizing, env);
+        AutoCodecUtil.initializeDeserializeMethodBuilder(encodedType, env);
     String builderVarName =
         buildDeserializeBodyWithBuilder(
             encodedType, builderType, deserializeBuilder, getters, builderCreationMethod);
@@ -519,9 +514,9 @@
   }
 
   private MethodSpec buildSerializeMethodWithInstantiator(
-      TypeElement encodedType, List<? extends VariableElement> fields, boolean startMemoizing) {
+      TypeElement encodedType, List<? extends VariableElement> fields) {
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, startMemoizing, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
     for (VariableElement parameter : fields) {
       Optional<FieldValueAndClass> hasField =
           getFieldByNameRecursive(encodedType, parameter.getSimpleName().toString());
@@ -624,17 +619,16 @@
   }
 
   private MethodSpec buildSerializeMethodWithInstantiatorForAutoValue(
-      TypeElement encodedType, List<? extends VariableElement> fields, boolean startMemoizing) {
+      TypeElement encodedType, List<? extends VariableElement> fields) {
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, startMemoizing, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
     for (VariableElement parameter : fields) {
       addSerializeParameterWithGetter(encodedType, parameter, serializeBuilder);
     }
     return serializeBuilder.build();
   }
 
-  private TypeSpec.Builder buildClassWithPublicFieldsStrategy(
-      TypeElement encodedType, boolean startMemoizing) {
+  private TypeSpec.Builder buildClassWithPublicFieldsStrategy(TypeElement encodedType) {
     TypeSpec.Builder codecClassBuilder =
         AutoCodecUtil.initializeCodecClassBuilder(encodedType, env);
     ImmutableList<? extends VariableElement> publicFields =
@@ -642,10 +636,9 @@
             .stream()
             .filter(this::isPublicField)
             .collect(toImmutableList());
-    codecClassBuilder.addMethod(
-        buildSerializeMethodWithPublicFields(encodedType, publicFields, startMemoizing));
+    codecClassBuilder.addMethod(buildSerializeMethodWithPublicFields(encodedType, publicFields));
     MethodSpec.Builder deserializeBuilder =
-        AutoCodecUtil.initializeDeserializeMethodBuilder(encodedType, startMemoizing, env);
+        AutoCodecUtil.initializeDeserializeMethodBuilder(encodedType, env);
     buildDeserializeBody(deserializeBuilder, publicFields);
     addInstantiatePopulateFieldsAndReturn(deserializeBuilder, encodedType, publicFields);
     codecClassBuilder.addMethod(deserializeBuilder.build());
@@ -661,9 +654,9 @@
   }
 
   private MethodSpec buildSerializeMethodWithPublicFields(
-      TypeElement encodedType, List<? extends VariableElement> fields, boolean startMemoizing) {
+      TypeElement encodedType, List<? extends VariableElement> fields) {
     MethodSpec.Builder serializeBuilder =
-        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, startMemoizing, env);
+        AutoCodecUtil.initializeSerializeMethodBuilder(encodedType, env);
     for (VariableElement parameter : fields) {
       String paramAccessor = "input." + parameter.getSimpleName();
       marshallers.writeSerializationCode(
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java
index 2ba6d9b..34925ee 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec/AutoCodecUtil.java
@@ -69,10 +69,9 @@
    * Initializes the deserialize method.
    *
    * @param encodedType type being serialized
-   * @param startMemoizing whether memoization should start in this method.
    */
   static MethodSpec.Builder initializeSerializeMethodBuilder(
-      TypeElement encodedType, boolean startMemoizing, ProcessingEnvironment env) {
+      TypeElement encodedType, ProcessingEnvironment env) {
     MethodSpec.Builder builder =
         MethodSpec.methodBuilder("serialize")
             .addModifiers(Modifier.PUBLIC)
@@ -83,9 +82,6 @@
             .addParameter(SerializationContext.class, "context")
             .addParameter(TypeName.get(env.getTypeUtils().erasure(encodedType.asType())), "input")
             .addParameter(CodedOutputStream.class, "codedOut");
-    if (startMemoizing) {
-      builder.addStatement("context = context.getMemoizingContext()");
-    }
     return builder;
   }
 
@@ -93,10 +89,9 @@
    * Initializes the deserialize method.
    *
    * @param encodedType type being serialized
-   * @param startMemoizing whether memoization should start in this method.
    */
   static MethodSpec.Builder initializeDeserializeMethodBuilder(
-      TypeElement encodedType, boolean startMemoizing, ProcessingEnvironment env) {
+      TypeElement encodedType, ProcessingEnvironment env) {
     MethodSpec.Builder builder =
         MethodSpec.methodBuilder("deserialize")
             .addModifiers(Modifier.PUBLIC)
@@ -106,9 +101,6 @@
             .addException(IOException.class)
             .addParameter(DeserializationContext.class, "context")
             .addParameter(CodedInputStream.class, "codedIn");
-    if (startMemoizing) {
-      builder.addStatement("context = context.getMemoizingContext()");
-    }
     return builder;
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/testutils/SerializationTester.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/testutils/SerializationTester.java
index 91926db..b45d443 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/testutils/SerializationTester.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/testutils/SerializationTester.java
@@ -22,13 +22,11 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.devtools.build.lib.skyframe.serialization.AutoRegistry;
-import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecRegistry;
 import com.google.devtools.build.lib.skyframe.serialization.ObjectCodecs;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
 import com.google.protobuf.ByteString;
-import com.google.protobuf.CodedOutputStream;
 import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -126,24 +124,20 @@
 
   private ByteString serialize(Object subject, ObjectCodecs codecs)
       throws SerializationException, IOException {
-    if (!memoize) {
+    if (memoize) {
+      return codecs.serializeMemoized(subject);
+    } else {
       return codecs.serialize(subject);
     }
-    ByteString.Output output = ByteString.newOutput();
-    CodedOutputStream codedOut = CodedOutputStream.newInstance(output);
-    codecs.getSerializationContextForTesting().getMemoizingContext().serialize(subject, codedOut);
-    codedOut.flush();
-    return output.toByteString();
   }
 
   private Object deserialize(ByteString serialized, ObjectCodecs codecs)
       throws SerializationException, IOException {
-    if (!memoize) {
+    if (memoize) {
+      return codecs.deserializeMemoized(serialized);
+    } else {
       return codecs.deserialize(serialized);
     }
-    DeserializationContext context =
-        codecs.getDeserializationContextForTesting().getMemoizingContext();
-    return context.deserialize(serialized.newCodedInput());
   }
 
   /** Runs serialization/deserialization tests. */
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/Environment.java b/src/main/java/com/google/devtools/build/lib/syntax/Environment.java
index 697c317..46253fa 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/Environment.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/Environment.java
@@ -25,7 +25,6 @@
 import com.google.devtools.build.lib.events.EventKind;
 import com.google.devtools.build.lib.events.Location;
 import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
-import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec.Memoization;
 import com.google.devtools.build.lib.skylarkinterface.SkylarkValue;
 import com.google.devtools.build.lib.syntax.Mutability.Freezable;
 import com.google.devtools.build.lib.syntax.Mutability.MutabilityException;
@@ -478,7 +477,7 @@
   @Immutable
   // TODO(janakr,brandjon): Do Extensions actually have to start their own memoization? Or can we
   // have a node higher up in the hierarchy inject the mutability?
-  @AutoCodec(memoization = Memoization.START_MEMOIZING)
+  @AutoCodec
   public static final class Extension {
 
     private final ImmutableMap<String, Object> bindings;