Minor memoizing serialization context cleanups. * For MEMOIZE_BEFORE, uses the size of the memoization table to determine the tag value instead of adding them to the serialized representation. * Adds a constant NO_VALUE instead of using -1 everywhere. * Avoids double-lookup in MemoizingSerializationContext.memoize. * Factors out a helper method to select a serialization memo table. PiperOrigin-RevId: 628045013 Change-Id: Ib1510c616997ec6228f8629f48898b607d2d643d
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java index 0f3f1d7..563f0d6 100644 --- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java +++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java
@@ -17,7 +17,6 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableClassToInstanceMap; import com.google.common.util.concurrent.ListenableFuture; import com.google.devtools.build.lib.skyframe.serialization.DeferredObjectCodec.DeferredValue; @@ -35,6 +34,12 @@ * MemoizingSerializationContext} for the protocol description. */ abstract class MemoizingDeserializationContext extends DeserializationContext { + /** + * A placeholder that keeps the size of {@link #memoTable} consistent with the numbering of its + * contents. + */ + private static final PlaceholderValue INITIAL_VALUE_PLACEHOLDER = new PlaceholderValue(); + private final Int2ObjectOpenHashMap<Object> memoTable = new Int2ObjectOpenHashMap<>(); private int tagForMemoizedBefore = -1; private final Deque<Object> memoizedBeforeStackForSanityChecking = new ArrayDeque<>(); @@ -72,23 +77,28 @@ @Override public final void registerInitialValue(Object initialValue) { - Preconditions.checkState( - tagForMemoizedBefore != -1, "Not called with memoize before: %s", initialValue); + checkState(tagForMemoizedBefore != -1, "Not called with memoize before: %s", initialValue); int tag = tagForMemoizedBefore; tagForMemoizedBefore = -1; - memoize(tag, initialValue); + // Replaces the INITIAL_VALUE_PLACEHOLDER with the actual initial value. + checkState(memoTable.put(tag, initialValue) == INITIAL_VALUE_PLACEHOLDER); memoizedBeforeStackForSanityChecking.addLast(initialValue); } @Override final Object getMemoizedBackReference(int memoIndex) { - return checkNotNull(memoTable.get(memoIndex), memoIndex); + Object value = checkNotNull(memoTable.get(memoIndex), memoIndex); + checkState( + value != INITIAL_VALUE_PLACEHOLDER, + "Backreference prior to registerInitialValue: %s", + memoIndex); + return value; } @Override final Object deserializeAndMaybeMemoize(ObjectCodec<?> codec, CodedInputStream codedIn) throws SerializationException, IOException { - Preconditions.checkState( + checkState( tagForMemoizedBefore == -1, "non-null memoized-before tag %s (%s)", tagForMemoizedBefore, @@ -123,7 +133,13 @@ */ private final Object deserializeMemoBeforeContent(ObjectCodec<?> codec, CodedInputStream codedIn) throws SerializationException, IOException { - int tag = codedIn.readInt32(); + int tag = memoTable.size(); + // During serialization, the top-level object is the first object to be memoized regardless of + // the codec implementation. During deserialization, the top-level object only becomes + // available after `registerInitialValue` is called and some codecs may perform deserialization + // operations prior to `registerInitialValue`. To keep the tags in sync with the size of + // the `memoTable`, adds a placeholder for the top-level object. + memoTable.put(tag, INITIAL_VALUE_PLACEHOLDER); this.tagForMemoizedBefore = tag; // `codec` is never a `DeferredObjectCodec` because those are `MEMOIZE_AFTER` so this is always // the deserialized value instance and never a `DeferredValue`. @@ -231,4 +247,8 @@ return value; } } + + private static final class PlaceholderValue { + private PlaceholderValue() {} + } }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java index 202a0d1..c178748 100644 --- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java +++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java
@@ -13,9 +13,11 @@ // limitations under the License. package com.google.devtools.build.lib.skyframe.serialization; +import static com.google.common.base.Preconditions.checkState; + import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableClassToInstanceMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.protobuf.ByteString; import com.google.protobuf.CodedOutputStream; import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; @@ -133,6 +135,8 @@ // SerializationException. This requires just a little extra memo tracking for the MEMOIZE_AFTER // case. abstract class MemoizingSerializationContext extends SerializationContext { + private static final int NO_VALUE = -1; + private final Reference2IntOpenHashMap<Object> table = new Reference2IntOpenHashMap<>(); /** Table for types memoized using values equality, currently only {@link String}. */ @@ -149,8 +153,8 @@ MemoizingSerializationContext( ObjectCodecRegistry codecRegistry, ImmutableClassToInstanceMap<Object> dependencies) { super(codecRegistry, dependencies); - table.defaultReturnValue(-1); - valuesTable.defaultReturnValue(-1); + table.defaultReturnValue(NO_VALUE); + valuesTable.defaultReturnValue(NO_VALUE); } static byte[] serializeToBytes( @@ -201,8 +205,8 @@ switch (codec.getStrategy()) { case MEMOIZE_BEFORE: { - int id = memoize(obj); - codedOut.writeInt32NoTag(id); + // Deserialization determines the value of this tag based on the size of its memo table. + memoize(obj); codec.serialize(this, obj, codedOut); break; } @@ -213,7 +217,7 @@ // cycle, then there's now a memo entry for the parent. Don't overwrite it with a new // id. int cylicallyCreatedId = getMemoizedIndex(obj); - int id = (cylicallyCreatedId != -1) ? cylicallyCreatedId : memoize(obj); + int id = (cylicallyCreatedId != NO_VALUE) ? cylicallyCreatedId : memoize(obj); codedOut.writeInt32NoTag(id); break; } @@ -224,7 +228,7 @@ final boolean writeBackReferenceIfMemoized(Object obj, CodedOutputStream codedOut) throws IOException { int memoizedIndex = getMemoizedIndex(obj); - if (memoizedIndex == -1) { + if (memoizedIndex == NO_VALUE) { return false; } // Subtracts 1 so it will be negative and not collide with null. @@ -237,12 +241,12 @@ return true; } - /** If the value is already memoized, return its on-the-wire id; otherwise returns {@code -1}. */ + /** + * If the value is already memoized, return its on-the-wire id; otherwise returns {@link + * #NO_VALUE}. + */ private int getMemoizedIndex(Object value) { - if (value instanceof String) { - return valuesTable.getInt(value); - } - return table.getInt(value); + return isValueType(value) ? valuesTable.getInt(value) : table.getInt(value); } /** @@ -250,19 +254,20 @@ * * <p>{@code value} must not already be present. */ + @CanIgnoreReturnValue // may be called for side effect private int memoize(Object value) { - Preconditions.checkArgument( - getMemoizedIndex(value) == -1, "Tried to memoize object '%s' multiple times", value); // Ids count sequentially from 0. int newId = table.size() + valuesTable.size(); - if (value instanceof String) { - valuesTable.put(value, newId); - } else { - table.put(value, newId); - } + int maybePrevious = + isValueType(value) ? valuesTable.put(value, newId) : table.put(value, newId); + checkState(maybePrevious == NO_VALUE, "Memoized object '%s' multiple times", value); return newId; } + private boolean isValueType(Object value) { + return value instanceof String; + } + private static void serializeToStream( ObjectCodecRegistry codecRegistry, ImmutableClassToInstanceMap<Object> dependencies,