Add type-safety check when deserializing in DynamicCodec, suggested by adonovan@.

PiperOrigin-RevId: 313327501
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
index 5973bc3..50bca8e 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodec.java
@@ -164,42 +164,43 @@
    * Deserializes a field directly into the supplied object.
    *
    * @param obj the object containing the field to deserialize. Can be an array or a plain object.
-   * @param type class of the field to deserialize
+   * @param fieldType class of the field to deserialize
    * @param offset unsafe offset into obj where the field should be written
    */
-  private static void deserializeField(
+  private void deserializeField(
       DeserializationContext context,
       CodedInputStream codedIn,
       Object obj,
-      Class<?> type,
+      Class<?> fieldType,
       long offset)
       throws SerializationException, IOException {
-    if (type.isPrimitive()) {
-      if (type.equals(boolean.class)) {
+    if (fieldType.isPrimitive()) {
+      if (fieldType.equals(boolean.class)) {
         UnsafeProvider.getInstance().putBoolean(obj, offset, codedIn.readBool());
-      } else if (type.equals(byte.class)) {
+      } else if (fieldType.equals(byte.class)) {
         UnsafeProvider.getInstance().putByte(obj, offset, codedIn.readRawByte());
-      } else if (type.equals(short.class)) {
+      } else if (fieldType.equals(short.class)) {
         ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2));
         UnsafeProvider.getInstance().putShort(obj, offset, buffer.getShort(0));
-      } else if (type.equals(char.class)) {
+      } else if (fieldType.equals(char.class)) {
         ByteBuffer buffer = ByteBuffer.allocate(2).put(codedIn.readRawBytes(2));
         UnsafeProvider.getInstance().putChar(obj, offset, buffer.getChar(0));
-      } else if (type.equals(int.class)) {
+      } else if (fieldType.equals(int.class)) {
         UnsafeProvider.getInstance().putInt(obj, offset, codedIn.readInt32());
-      } else if (type.equals(long.class)) {
+      } else if (fieldType.equals(long.class)) {
         UnsafeProvider.getInstance().putLong(obj, offset, codedIn.readInt64());
-      } else if (type.equals(float.class)) {
+      } else if (fieldType.equals(float.class)) {
         UnsafeProvider.getInstance().putFloat(obj, offset, codedIn.readFloat());
-      } else if (type.equals(double.class)) {
+      } else if (fieldType.equals(double.class)) {
         UnsafeProvider.getInstance().putDouble(obj, offset, codedIn.readDouble());
-      } else if (type.equals(void.class)) {
+      } else if (fieldType.equals(void.class)) {
         // Does nothing for void type.
       } else {
-        throw new UnsupportedOperationException("Unknown primitive type: " + type);
+        throw new UnsupportedOperationException(
+            "Unknown primitive field type " + fieldType + " for " + type);
       }
-    } else if (type.isArray()) {
-      if (type.getComponentType().equals(byte.class)) {
+    } else if (fieldType.isArray()) {
+      if (fieldType.getComponentType().equals(byte.class)) {
         boolean isNonNull = codedIn.readBool();
         UnsafeProvider.getInstance()
             .putObject(obj, offset, isNonNull ? codedIn.readByteArray() : null);
@@ -210,19 +211,32 @@
         UnsafeProvider.getInstance().putObject(obj, offset, null);
         return;
       }
-      Object arr = Array.newInstance(type.getComponentType(), length);
+      Object arr = Array.newInstance(fieldType.getComponentType(), length);
       UnsafeProvider.getInstance().putObject(obj, offset, arr);
-      int base = UnsafeProvider.getInstance().arrayBaseOffset(type);
-      int scale = UnsafeProvider.getInstance().arrayIndexScale(type);
+      int base = UnsafeProvider.getInstance().arrayBaseOffset(fieldType);
+      int scale = UnsafeProvider.getInstance().arrayIndexScale(fieldType);
       if (scale == 0) {
-        throw new SerializationException("Failed to get index scale for type: " + type);
+        throw new SerializationException(
+            "Failed to get index scale for field type " + fieldType + " for " + type);
       }
       for (int i = 0; i < length; ++i) {
         // Deserializes type directly into array memory.
-        deserializeField(context, codedIn, arr, type.getComponentType(), base + scale * i);
+        deserializeField(context, codedIn, arr, fieldType.getComponentType(), base + scale * i);
       }
     } else {
-      UnsafeProvider.getInstance().putObject(obj, offset, context.deserialize(codedIn));
+      Object fieldValue = context.deserialize(codedIn);
+      if (fieldValue != null && !fieldType.isInstance(fieldValue)) {
+        throw new SerializationException(
+            "Field "
+                + fieldValue
+                + " was not instance of "
+                + fieldType
+                + " (was "
+                + fieldValue.getClass()
+                + ") for "
+                + type);
+      }
+      UnsafeProvider.getInstance().putObject(obj, offset, fieldValue);
     }
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
index 17149e1..a26e18d 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/DynamicCodecTest.java
@@ -19,6 +19,9 @@
 
 import com.google.common.collect.ImmutableMap;
 import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.CodedInputStream;
+import com.google.protobuf.CodedOutputStream;
 import java.io.BufferedInputStream;
 import java.util.Arrays;
 import java.util.Objects;
@@ -386,7 +389,7 @@
   }
 
   @Test
-  public void testNoCodecExample() throws Exception {
+  public void testNoCodecExample() {
     ObjectCodecs codecs = new ObjectCodecs(AutoRegistry.get(), ImmutableMap.of());
     SerializationException.NoCodecException expected =
         assertThrows(
@@ -402,4 +405,78 @@
                 + "com.google.devtools.build.lib.skyframe.serialization."
                 + "DynamicCodecTest$NoCodecExample1]");
   }
+
+  private static class SpecificObject {}
+
+  private static class SpecificObjectWrapper {
+    @SuppressWarnings("unused")
+    private final SpecificObject field;
+
+    SpecificObjectWrapper(SpecificObject field) {
+      this.field = field;
+    }
+  }
+
+  @Test
+  public void overGeneralCodec() throws Exception {
+    // Class must be hidden from other tests.
+    class OverGeneralCodec implements ObjectCodec<Object> {
+      @Override
+      public Class<?> getEncodedClass() {
+        return Object.class;
+      }
+
+      @Override
+      public void serialize(SerializationContext context, Object obj, CodedOutputStream codedOut) {}
+
+      @Override
+      public Object deserialize(DeserializationContext context, CodedInputStream codedIn) {
+        return new Object();
+      }
+    }
+    ObjectCodecRegistry registry =
+        ObjectCodecRegistry.newBuilder()
+            .add(new DynamicCodec(SpecificObjectWrapper.class))
+            .add(new OverGeneralCodec())
+            .build();
+    ObjectCodecs codecs = new ObjectCodecs(registry);
+    ByteString bytes = codecs.serializeMemoized(new SpecificObjectWrapper(new SpecificObject()));
+    SerializationException expected =
+        assertThrows(SerializationException.class, () -> codecs.deserializeMemoized(bytes));
+    assertThat(expected)
+        .hasMessageThat()
+        .contains(
+            "was not instance of class "
+                + "com.google.devtools.build.lib.skyframe.serialization."
+                + "DynamicCodecTest$SpecificObject");
+  }
+
+  @Test
+  public void overGeneralCodecOkWhenNull() throws Exception {
+    // Class must be hidden from other tests.
+    class OverGeneralCodec implements ObjectCodec<Object> {
+      @Override
+      public Class<?> getEncodedClass() {
+        return Object.class;
+      }
+
+      @Override
+      public void serialize(SerializationContext context, Object obj, CodedOutputStream codedOut) {}
+
+      @Override
+      public Object deserialize(DeserializationContext context, CodedInputStream codedIn) {
+        return new Object();
+      }
+    }
+    ObjectCodecRegistry registry =
+        ObjectCodecRegistry.newBuilder()
+            .add(new DynamicCodec(SpecificObjectWrapper.class))
+            .add(new OverGeneralCodec())
+            .build();
+    ObjectCodecs codecs = new ObjectCodecs(registry);
+    ByteString bytes = codecs.serializeMemoized(new SpecificObjectWrapper(null));
+    Object deserialized = codecs.deserializeMemoized(bytes);
+    assertThat(deserialized).isInstanceOf(SpecificObjectWrapper.class);
+    assertThat(((SpecificObjectWrapper) deserialized).field).isNull();
+  }
 }