Minor memoizing serialization context cleanups.

* For MEMOIZE_BEFORE, uses the size of the memoization table to determine the
  tag value instead of adding them to the serialized representation.
* Adds a constant NO_VALUE instead of using -1 everywhere.
* Avoids double-lookup in MemoizingSerializationContext.memoize.
* Factors out a helper method to select a serialization memo table.

PiperOrigin-RevId: 628045013
Change-Id: Ib1510c616997ec6228f8629f48898b607d2d643d
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java
index 0f3f1d7..563f0d6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java
@@ -17,7 +17,6 @@
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableClassToInstanceMap;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.devtools.build.lib.skyframe.serialization.DeferredObjectCodec.DeferredValue;
@@ -35,6 +34,12 @@
  * MemoizingSerializationContext} for the protocol description.
  */
 abstract class MemoizingDeserializationContext extends DeserializationContext {
+  /**
+   * A placeholder that keeps the size of {@link #memoTable} consistent with the numbering of its
+   * contents.
+   */
+  private static final PlaceholderValue INITIAL_VALUE_PLACEHOLDER = new PlaceholderValue();
+
   private final Int2ObjectOpenHashMap<Object> memoTable = new Int2ObjectOpenHashMap<>();
   private int tagForMemoizedBefore = -1;
   private final Deque<Object> memoizedBeforeStackForSanityChecking = new ArrayDeque<>();
@@ -72,23 +77,28 @@
 
   @Override
   public final void registerInitialValue(Object initialValue) {
-    Preconditions.checkState(
-        tagForMemoizedBefore != -1, "Not called with memoize before: %s", initialValue);
+    checkState(tagForMemoizedBefore != -1, "Not called with memoize before: %s", initialValue);
     int tag = tagForMemoizedBefore;
     tagForMemoizedBefore = -1;
-    memoize(tag, initialValue);
+    // Replaces the INITIAL_VALUE_PLACEHOLDER with the actual initial value.
+    checkState(memoTable.put(tag, initialValue) == INITIAL_VALUE_PLACEHOLDER);
     memoizedBeforeStackForSanityChecking.addLast(initialValue);
   }
 
   @Override
   final Object getMemoizedBackReference(int memoIndex) {
-    return checkNotNull(memoTable.get(memoIndex), memoIndex);
+    Object value = checkNotNull(memoTable.get(memoIndex), memoIndex);
+    checkState(
+        value != INITIAL_VALUE_PLACEHOLDER,
+        "Backreference prior to registerInitialValue: %s",
+        memoIndex);
+    return value;
   }
 
   @Override
   final Object deserializeAndMaybeMemoize(ObjectCodec<?> codec, CodedInputStream codedIn)
       throws SerializationException, IOException {
-    Preconditions.checkState(
+    checkState(
         tagForMemoizedBefore == -1,
         "non-null memoized-before tag %s (%s)",
         tagForMemoizedBefore,
@@ -123,7 +133,13 @@
    */
   private final Object deserializeMemoBeforeContent(ObjectCodec<?> codec, CodedInputStream codedIn)
       throws SerializationException, IOException {
-    int tag = codedIn.readInt32();
+    int tag = memoTable.size();
+    // During serialization, the top-level object is the first object to be memoized regardless of
+    // the codec implementation. During deserialization, the top-level object only becomes
+    // available after `registerInitialValue` is called and some codecs may perform deserialization
+    // operations prior to `registerInitialValue`. To keep the tags in sync with the size of
+    // the `memoTable`, adds a placeholder for the top-level object.
+    memoTable.put(tag, INITIAL_VALUE_PLACEHOLDER);
     this.tagForMemoizedBefore = tag;
     // `codec` is never a `DeferredObjectCodec` because those are `MEMOIZE_AFTER` so this is always
     // the deserialized value instance and never a `DeferredValue`.
@@ -231,4 +247,8 @@
       return value;
     }
   }
+
+  private static final class PlaceholderValue {
+    private PlaceholderValue() {}
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java
index 202a0d1..c178748 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java
@@ -13,9 +13,11 @@
 // limitations under the License.
 package com.google.devtools.build.lib.skyframe.serialization;
 
+import static com.google.common.base.Preconditions.checkState;
+
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableClassToInstanceMap;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.CodedOutputStream;
 import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
@@ -133,6 +135,8 @@
 // SerializationException. This requires just a little extra memo tracking for the MEMOIZE_AFTER
 // case.
 abstract class MemoizingSerializationContext extends SerializationContext {
+  private static final int NO_VALUE = -1;
+
   private final Reference2IntOpenHashMap<Object> table = new Reference2IntOpenHashMap<>();
 
   /** Table for types memoized using values equality, currently only {@link String}. */
@@ -149,8 +153,8 @@
   MemoizingSerializationContext(
       ObjectCodecRegistry codecRegistry, ImmutableClassToInstanceMap<Object> dependencies) {
     super(codecRegistry, dependencies);
-    table.defaultReturnValue(-1);
-    valuesTable.defaultReturnValue(-1);
+    table.defaultReturnValue(NO_VALUE);
+    valuesTable.defaultReturnValue(NO_VALUE);
   }
 
   static byte[] serializeToBytes(
@@ -201,8 +205,8 @@
     switch (codec.getStrategy()) {
       case MEMOIZE_BEFORE:
         {
-          int id = memoize(obj);
-          codedOut.writeInt32NoTag(id);
+          // Deserialization determines the value of this tag based on the size of its memo table.
+          memoize(obj);
           codec.serialize(this, obj, codedOut);
           break;
         }
@@ -213,7 +217,7 @@
           // cycle, then there's now a memo entry for the parent. Don't overwrite it with a new
           // id.
           int cylicallyCreatedId = getMemoizedIndex(obj);
-          int id = (cylicallyCreatedId != -1) ? cylicallyCreatedId : memoize(obj);
+          int id = (cylicallyCreatedId != NO_VALUE) ? cylicallyCreatedId : memoize(obj);
           codedOut.writeInt32NoTag(id);
           break;
         }
@@ -224,7 +228,7 @@
   final boolean writeBackReferenceIfMemoized(Object obj, CodedOutputStream codedOut)
       throws IOException {
     int memoizedIndex = getMemoizedIndex(obj);
-    if (memoizedIndex == -1) {
+    if (memoizedIndex == NO_VALUE) {
       return false;
     }
     // Subtracts 1 so it will be negative and not collide with null.
@@ -237,12 +241,12 @@
     return true;
   }
 
-  /** If the value is already memoized, return its on-the-wire id; otherwise returns {@code -1}. */
+  /**
+   * If the value is already memoized, return its on-the-wire id; otherwise returns {@link
+   * #NO_VALUE}.
+   */
   private int getMemoizedIndex(Object value) {
-    if (value instanceof String) {
-      return valuesTable.getInt(value);
-    }
-    return table.getInt(value);
+    return isValueType(value) ? valuesTable.getInt(value) : table.getInt(value);
   }
 
   /**
@@ -250,19 +254,20 @@
    *
    * <p>{@code value} must not already be present.
    */
+  @CanIgnoreReturnValue // may be called for side effect
   private int memoize(Object value) {
-    Preconditions.checkArgument(
-        getMemoizedIndex(value) == -1, "Tried to memoize object '%s' multiple times", value);
     // Ids count sequentially from 0.
     int newId = table.size() + valuesTable.size();
-    if (value instanceof String) {
-      valuesTable.put(value, newId);
-    } else {
-      table.put(value, newId);
-    }
+    int maybePrevious =
+        isValueType(value) ? valuesTable.put(value, newId) : table.put(value, newId);
+    checkState(maybePrevious == NO_VALUE, "Memoized object '%s' multiple times", value);
     return newId;
   }
 
+  private boolean isValueType(Object value) {
+    return value instanceof String;
+  }
+
   private static void serializeToStream(
       ObjectCodecRegistry codecRegistry,
       ImmutableClassToInstanceMap<Object> dependencies,