Add type-safety check when deserializing in DynamicCodec, suggested by adonovan@.
PiperOrigin-RevId: 313327501
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
index 5973bc3..50bca8e 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
@@ -164,42 +164,43 @@
* Deserializes a field directly into the supplied object.
*
* @param obj the object containing the field to deserialize. Can be an array or a plain object.
- * @param type class of the field to deserialize
+ * @param fieldType class of the field to deserialize
* @param offset unsafe offset into obj where the field should be written
*/
- private static void deserializeField(
+ private void deserializeField(
DeserializationContext context,
CodedInputStream codedIn,
Object obj,
- Class<?> type,
+ Class<?> fieldType,
long offset)
throws SerializationException, IOException {
- if (type.isPrimitive()) {
- if (type.equals(boolean.class)) {
+ if (fieldType.isPrimitive()) {
+ if (fieldType.equals(boolean.class)) {
UnsafeProvider.getInstance().putBoolean(obj, offset, codedIn.readBool());
- } else if (type.equals(byte.class)) {
+ } else if (fieldType.equals(byte.class)) {
UnsafeProvider.getInstance().putByte(obj, offset, codedIn.readRawByte());
- } else if (type.equals(short.class)) {
+ } else if (fieldType.equals(short.class)) {
ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2));
UnsafeProvider.getInstance().putShort(obj, offset, buffer.getShort(0));
- } else if (type.equals(char.class)) {
+ } else if (fieldType.equals(char.class)) {
ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2));
UnsafeProvider.getInstance().putChar(obj, offset, buffer.getChar(0));
- } else if (type.equals(int.class)) {
+ } else if (fieldType.equals(int.class)) {
UnsafeProvider.getInstance().putInt(obj, offset, codedIn.readInt32());
- } else if (type.equals(long.class)) {
+ } else if (fieldType.equals(long.class)) {
UnsafeProvider.getInstance().putLong(obj, offset, codedIn.readInt64());
- } else if (type.equals(float.class)) {
+ } else if (fieldType.equals(float.class)) {
UnsafeProvider.getInstance().putFloat(obj, offset, codedIn.readFloat());
- } else if (type.equals(double.class)) {
+ } else if (fieldType.equals(double.class)) {
UnsafeProvider.getInstance().putDouble(obj, offset, codedIn.readDouble());
- } else if (type.equals(void.class)) {
+ } else if (fieldType.equals(void.class)) {
// Does nothing for void type.
} else {
- throw new UnsupportedOperationException("Unknown primitive type: " + type);
+ throw new UnsupportedOperationException(
+ "Unknown primitive field type " + fieldType + " for " + type);
}
- } else if (type.isArray()) {
- if (type.getComponentType().equals(byte.class)) {
+ } else if (fieldType.isArray()) {
+ if (fieldType.getComponentType().equals(byte.class)) {
boolean isNonNull = codedIn.readBool();
UnsafeProvider.getInstance()
.putObject(obj, offset, isNonNull ? codedIn.readByteArray() : null);
@@ -210,19 +211,32 @@
UnsafeProvider.getInstance().putObject(obj, offset, null);
return;
}
- Object arr = Array.newInstance(type.getComponentType(), length);
+ Object arr = Array.newInstance(fieldType.getComponentType(), length);
UnsafeProvider.getInstance().putObject(obj, offset, arr);
- int base = UnsafeProvider.getInstance().arrayBaseOffset(type);
- int scale = UnsafeProvider.getInstance().arrayIndexScale(type);
+ int base = UnsafeProvider.getInstance().arrayBaseOffset(fieldType);
+ int scale = UnsafeProvider.getInstance().arrayIndexScale(fieldType);
if (scale == 0) {
- throw new SerializationException("Failed to get index scale for type: " + type);
+ throw new SerializationException(
+ "Failed to get index scale for field type " + fieldType + " for " + type);
}
for (int i = 0; i < length; ++i) {
// Deserializes type directly into array memory.
- deserializeField(context, codedIn, arr, type.getComponentType(), base + scale * i);
+ deserializeField(context, codedIn, arr, fieldType.getComponentType(), base + scale * i);
}
} else {
- UnsafeProvider.getInstance().putObject(obj, offset, context.deserialize(codedIn));
+ Object fieldValue = context.deserialize(codedIn);
+ if (fieldValue != null && !fieldType.isInstance(fieldValue)) {
+ throw new SerializationException(
+ "Field "
+ + fieldValue
+ + " was not instance of "
+ + fieldType
+ + " (was "
+ + fieldValue.getClass()
+ + ") for "
+ + type);
+ }
+ UnsafeProvider.getInstance().putObject(obj, offset, fieldValue);
}
}
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
index 17149e1..a26e18d 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
@@ -19,6 +19,9 @@
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.CodedInputStream;
+import com.google.protobuf.CodedOutputStream;
import java.io.BufferedInputStream;
import java.util.Arrays;
import java.util.Objects;
@@ -386,7 +389,7 @@
}
@Test
- public void testNoCodecExample() throws Exception {
+ public void testNoCodecExample() {
ObjectCodecs codecs = new ObjectCodecs(AutoRegistry.get(), ImmutableMap.of());
SerializationException.NoCodecException expected =
assertThrows(
@@ -402,4 +405,78 @@
+ "com.google.devtools.build.lib.skyframe.serialization."
+ "DynamicCodecTest$NoCodecExample1]");
}
+
+ private static class SpecificObject {}
+
+ private static class SpecificObjectWrapper {
+ @SuppressWarnings("unused")
+ private final SpecificObject field;
+
+ SpecificObjectWrapper(SpecificObject field) {
+ this.field = field;
+ }
+ }
+
+ @Test
+ public void overGeneralCodec() throws Exception {
+ // Class must be hidden from other tests.
+ class OverGeneralCodec implements ObjectCodec<Object> {
+ @Override
+ public Class<?> getEncodedClass() {
+ return Object.class;
+ }
+
+ @Override
+ public void serialize(SerializationContext context, Object obj, CodedOutputStream codedOut) {}
+
+ @Override
+ public Object deserialize(DeserializationContext context, CodedInputStream codedIn) {
+ return new Object();
+ }
+ }
+ ObjectCodecRegistry registry =
+ ObjectCodecRegistry.newBuilder()
+ .add(new DynamicCodec(SpecificObjectWrapper.class))
+ .add(new OverGeneralCodec())
+ .build();
+ ObjectCodecs codecs = new ObjectCodecs(registry);
+ ByteString bytes = codecs.serializeMemoized(new SpecificObjectWrapper(new SpecificObject()));
+ SerializationException expected =
+ assertThrows(SerializationException.class, () -> codecs.deserializeMemoized(bytes));
+ assertThat(expected)
+ .hasMessageThat()
+ .contains(
+ "was not instance of class "
+ + "com.google.devtools.build.lib.skyframe.serialization."
+ + "DynamicCodecTest$SpecificObject");
+ }
+
+ @Test
+ public void overGeneralCodecOkWhenNull() throws Exception {
+ // Class must be hidden from other tests.
+ class OverGeneralCodec implements ObjectCodec<Object> {
+ @Override
+ public Class<?> getEncodedClass() {
+ return Object.class;
+ }
+
+ @Override
+ public void serialize(SerializationContext context, Object obj, CodedOutputStream codedOut) {}
+
+ @Override
+ public Object deserialize(DeserializationContext context, CodedInputStream codedIn) {
+ return new Object();
+ }
+ }
+ ObjectCodecRegistry registry =
+ ObjectCodecRegistry.newBuilder()
+ .add(new DynamicCodec(SpecificObjectWrapper.class))
+ .add(new OverGeneralCodec())
+ .build();
+ ObjectCodecs codecs = new ObjectCodecs(registry);
+ ByteString bytes = codecs.serializeMemoized(new SpecificObjectWrapper(null));
+ Object deserialized = codecs.deserializeMemoized(bytes);
+ assertThat(deserialized).isInstanceOf(SpecificObjectWrapper.class);
+ assertThat(((SpecificObjectWrapper) deserialized).field).isNull();
+ }
}