Minor memoizing serialization context cleanups.
* For MEMOIZE_BEFORE, determines the tag value from the size of the
memoization table instead of writing it to the serialized representation.
* Adds a constant NO_VALUE instead of using -1 everywhere.
* Avoids double-lookup in MemoizingSerializationContext.memoize.
* Factors out a helper method (isValueType) that selects which serialization
memo table to use.
PiperOrigin-RevId: 628045013
Change-Id: Ib1510c616997ec6228f8629f48898b607d2d643d
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java
index 0f3f1d7..563f0d6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingDeserializationContext.java
@@ -17,7 +17,6 @@
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableClassToInstanceMap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.devtools.build.lib.skyframe.serialization.DeferredObjectCodec.DeferredValue;
@@ -35,6 +34,12 @@
* MemoizingSerializationContext} for the protocol description.
*/
abstract class MemoizingDeserializationContext extends DeserializationContext {
+ /**
+ * A placeholder that keeps the size of {@link #memoTable} consistent with the numbering of its
+ * contents.
+ */
+ private static final PlaceholderValue INITIAL_VALUE_PLACEHOLDER = new PlaceholderValue();
+
private final Int2ObjectOpenHashMap<Object> memoTable = new Int2ObjectOpenHashMap<>();
private int tagForMemoizedBefore = -1;
private final Deque<Object> memoizedBeforeStackForSanityChecking = new ArrayDeque<>();
@@ -72,23 +77,28 @@
@Override
public final void registerInitialValue(Object initialValue) {
- Preconditions.checkState(
- tagForMemoizedBefore != -1, "Not called with memoize before: %s", initialValue);
+ checkState(tagForMemoizedBefore != -1, "Not called with memoize before: %s", initialValue);
int tag = tagForMemoizedBefore;
tagForMemoizedBefore = -1;
- memoize(tag, initialValue);
+ // Replaces the INITIAL_VALUE_PLACEHOLDER with the actual initial value.
+ checkState(memoTable.put(tag, initialValue) == INITIAL_VALUE_PLACEHOLDER);
memoizedBeforeStackForSanityChecking.addLast(initialValue);
}
@Override
final Object getMemoizedBackReference(int memoIndex) {
- return checkNotNull(memoTable.get(memoIndex), memoIndex);
+ Object value = checkNotNull(memoTable.get(memoIndex), memoIndex);
+ checkState(
+ value != INITIAL_VALUE_PLACEHOLDER,
+ "Backreference prior to registerInitialValue: %s",
+ memoIndex);
+ return value;
}
@Override
final Object deserializeAndMaybeMemoize(ObjectCodec<?> codec, CodedInputStream codedIn)
throws SerializationException, IOException {
- Preconditions.checkState(
+ checkState(
tagForMemoizedBefore == -1,
"non-null memoized-before tag %s (%s)",
tagForMemoizedBefore,
@@ -123,7 +133,13 @@
*/
private final Object deserializeMemoBeforeContent(ObjectCodec<?> codec, CodedInputStream codedIn)
throws SerializationException, IOException {
- int tag = codedIn.readInt32();
+ int tag = memoTable.size();
+ // During serialization, the top-level object is the first object to be memoized regardless of
+ // the codec implementation. During deserialization, the top-level object only becomes
+ // available after `registerInitialValue` is called and some codecs may perform deserialization
+ // operations prior to `registerInitialValue`. To keep the tags in sync with the size of
+ // the `memoTable`, adds a placeholder for the top-level object.
+ memoTable.put(tag, INITIAL_VALUE_PLACEHOLDER);
this.tagForMemoizedBefore = tag;
// `codec` is never a `DeferredObjectCodec` because those are `MEMOIZE_AFTER` so this is always
// the deserialized value instance and never a `DeferredValue`.
@@ -231,4 +247,8 @@
return value;
}
}
+
+ private static final class PlaceholderValue {
+ private PlaceholderValue() {}
+ }
}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java
index 202a0d1..c178748 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/MemoizingSerializationContext.java
@@ -13,9 +13,11 @@
// limitations under the License.
package com.google.devtools.build.lib.skyframe.serialization;
+import static com.google.common.base.Preconditions.checkState;
+
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableClassToInstanceMap;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedOutputStream;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
@@ -133,6 +135,8 @@
// SerializationException. This requires just a little extra memo tracking for the MEMOIZE_AFTER
// case.
abstract class MemoizingSerializationContext extends SerializationContext {
+ private static final int NO_VALUE = -1;
+
private final Reference2IntOpenHashMap<Object> table = new Reference2IntOpenHashMap<>();
/** Table for types memoized using values equality, currently only {@link String}. */
@@ -149,8 +153,8 @@
MemoizingSerializationContext(
ObjectCodecRegistry codecRegistry, ImmutableClassToInstanceMap<Object> dependencies) {
super(codecRegistry, dependencies);
- table.defaultReturnValue(-1);
- valuesTable.defaultReturnValue(-1);
+ table.defaultReturnValue(NO_VALUE);
+ valuesTable.defaultReturnValue(NO_VALUE);
}
static byte[] serializeToBytes(
@@ -201,8 +205,8 @@
switch (codec.getStrategy()) {
case MEMOIZE_BEFORE:
{
- int id = memoize(obj);
- codedOut.writeInt32NoTag(id);
+ // Deserialization determines the value of this tag based on the size of its memo table.
+ memoize(obj);
codec.serialize(this, obj, codedOut);
break;
}
@@ -213,7 +217,7 @@
// cycle, then there's now a memo entry for the parent. Don't overwrite it with a new
// id.
int cylicallyCreatedId = getMemoizedIndex(obj);
- int id = (cylicallyCreatedId != -1) ? cylicallyCreatedId : memoize(obj);
+ int id = (cylicallyCreatedId != NO_VALUE) ? cylicallyCreatedId : memoize(obj);
codedOut.writeInt32NoTag(id);
break;
}
@@ -224,7 +228,7 @@
final boolean writeBackReferenceIfMemoized(Object obj, CodedOutputStream codedOut)
throws IOException {
int memoizedIndex = getMemoizedIndex(obj);
- if (memoizedIndex == -1) {
+ if (memoizedIndex == NO_VALUE) {
return false;
}
// Subtracts 1 so it will be negative and not collide with null.
@@ -237,12 +241,12 @@
return true;
}
- /** If the value is already memoized, return its on-the-wire id; otherwise returns {@code -1}. */
+ /**
+ * If the value is already memoized, return its on-the-wire id; otherwise returns {@link
+ * #NO_VALUE}.
+ */
private int getMemoizedIndex(Object value) {
- if (value instanceof String) {
- return valuesTable.getInt(value);
- }
- return table.getInt(value);
+ return isValueType(value) ? valuesTable.getInt(value) : table.getInt(value);
}
/**
@@ -250,19 +254,20 @@
*
* <p>{@code value} must not already be present.
*/
+ @CanIgnoreReturnValue // may be called for side effect
private int memoize(Object value) {
- Preconditions.checkArgument(
- getMemoizedIndex(value) == -1, "Tried to memoize object '%s' multiple times", value);
// Ids count sequentially from 0.
int newId = table.size() + valuesTable.size();
- if (value instanceof String) {
- valuesTable.put(value, newId);
- } else {
- table.put(value, newId);
- }
+ int maybePrevious =
+ isValueType(value) ? valuesTable.put(value, newId) : table.put(value, newId);
+ checkState(maybePrevious == NO_VALUE, "Memoized object '%s' multiple times", value);
return newId;
}
+ private boolean isValueType(Object value) {
+ return value instanceof String;
+ }
+
private static void serializeToStream(
ObjectCodecRegistry codecRegistry,
ImmutableClassToInstanceMap<Object> dependencies,