Add an IdentityHashMap to the BuildOptions.OptionsDiffForReconstruction codec.

PiperOrigin-RevId: 196310244
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java b/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
index 5612359..87ed80b 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/config/BuildOptions.java
@@ -27,6 +27,10 @@
 import com.google.common.collect.Sets;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.runtime.proto.InvocationPolicyOuterClass.InvocationPolicy;
+import com.google.devtools.build.lib.skyframe.serialization.DeserializationContext;
+import com.google.devtools.build.lib.skyframe.serialization.ObjectCodec;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
 import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
 import com.google.devtools.build.lib.util.Fingerprint;
 import com.google.devtools.build.lib.util.OrderedSetMultimap;
@@ -36,6 +40,10 @@
 import com.google.devtools.common.options.OptionsClassProvider;
 import com.google.devtools.common.options.OptionsParser;
 import com.google.devtools.common.options.OptionsParsingException;
+import com.google.protobuf.CodedInputStream;
+import com.google.protobuf.CodedOutputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -43,6 +51,7 @@
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.IdentityHashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -424,11 +433,11 @@
     if (diff.areSame()) {
       return OptionsDiffForReconstruction.getEmpty(first.fingerprint, second.computeChecksum());
     }
-    HashMap<Class<? extends FragmentOptions>, Map<String, Object>> differingOptions =
-        new HashMap<>(diff.differingOptions.keySet().size());
+    LinkedHashMap<Class<? extends FragmentOptions>, Map<String, Object>> differingOptions =
+        new LinkedHashMap<>(diff.differingOptions.keySet().size());
     for (Class<? extends FragmentOptions> clazz : diff.differingOptions.keySet()) {
       Collection<OptionDefinition> fields = diff.differingOptions.get(clazz);
-      HashMap<String, Object> valueMap = new HashMap<>(fields.size());
+      LinkedHashMap<String, Object> valueMap = new LinkedHashMap<>(fields.size());
       for (OptionDefinition optionDefinition : fields) {
         Object secondValue;
         try {
@@ -543,7 +552,6 @@
    * another: the full fragments of the second one, the fragment classes of the first that should be
    * omitted, and the values of any fields that should be changed.
    */
-  @AutoCodec
   public static class OptionsDiffForReconstruction {
     private final Map<Class<? extends FragmentOptions>, Map<String, Object>> differingOptions;
     private final ImmutableSet<Class<? extends FragmentOptions>> extraFirstFragmentClasses;
@@ -551,7 +559,6 @@
     private final byte[] baseFingerprint;
     private final String checksum;
 
-    @AutoCodec.VisibleForSerialization
     OptionsDiffForReconstruction(
         Map<Class<? extends FragmentOptions>, Map<String, Object>> differingOptions,
         ImmutableSet<Class<? extends FragmentOptions>> extraFirstFragmentClasses,
@@ -613,7 +620,9 @@
       OptionsDiffForReconstruction that = (OptionsDiffForReconstruction) o;
       return differingOptions.equals(that.differingOptions)
           && extraFirstFragmentClasses.equals(that.extraFirstFragmentClasses)
-          && this.extraSecondFragments.equals(that.extraSecondFragments);
+          && this.extraSecondFragments.equals(that.extraSecondFragments)
+          && Arrays.equals(this.baseFingerprint, that.baseFingerprint)
+          && this.checksum.equals(that.checksum);
     }
 
     @Override
@@ -626,7 +635,91 @@
 
     @Override
     public int hashCode() {
-      return Objects.hash(differingOptions, extraFirstFragmentClasses, extraSecondFragments);
+      return Objects.hash(
+          differingOptions,
+          extraFirstFragmentClasses,
+          extraSecondFragments,
+          Arrays.hashCode(baseFingerprint),
+          checksum);
+    }
+  }
+
+  /**
+   * Hand-rolled Codec so we can cache the byte representation of a {@link
+   * BuildOptions.OptionsDiffForReconstruction} object because serialization is expensive.
+   */
+  @VisibleForTesting
+  static class OptionsDiffForReconstructionCodec
+      implements ObjectCodec<OptionsDiffForReconstruction> {
+
+    @Override
+    public void serialize(
+        SerializationContext context,
+        BuildOptions.OptionsDiffForReconstruction input,
+        CodedOutputStream codedOut)
+        throws SerializationException, IOException {
+      context = context.getNewNonMemoizingContext();
+      // We get this cache from our context because there can be different ObjectCodecRegistry's for
+      // SkyKeys and SkyValues.
+      @SuppressWarnings("unchecked")
+      IdentityHashMap<OptionsDiffForReconstruction, byte[]> cache =
+          context.getDependency(IdentityHashMap.class);
+      if (cache.containsKey(input)) {
+        byte[] rawBytes = cache.get(input);
+        codedOut.writeInt32NoTag(rawBytes.length);
+        codedOut.writeRawBytes(cache.get(input));
+      } else {
+        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+        CodedOutputStream codedOutputStream = CodedOutputStream.newInstance(byteArrayOutputStream);
+        context.serialize(input.differingOptions, codedOutputStream);
+        context.serialize(input.extraFirstFragmentClasses, codedOutputStream);
+        context.serialize(input.extraSecondFragments, codedOutputStream);
+        if (input.baseFingerprint != null) {
+          codedOutputStream.writeBoolNoTag(true);
+          codedOutputStream.writeInt32NoTag(input.baseFingerprint.length);
+          codedOutputStream.writeRawBytes(input.baseFingerprint);
+        } else {
+          codedOutputStream.writeBoolNoTag(false);
+        }
+        context.serialize(input.checksum, codedOutputStream);
+        codedOutputStream.flush();
+        byteArrayOutputStream.flush();
+        byte[] serializedBytes = byteArrayOutputStream.toByteArray();
+        cache.put(input, serializedBytes);
+        codedOut.writeInt32NoTag(serializedBytes.length);
+        codedOut.writeRawBytes(serializedBytes);
+        codedOut.flush();
+      }
+    }
+
+    @Override
+    public BuildOptions.OptionsDiffForReconstruction deserialize(
+        DeserializationContext context, CodedInputStream codedIn)
+        throws SerializationException, IOException {
+      byte[] serializedBytes = codedIn.readRawBytes(codedIn.readInt32());
+      CodedInputStream codedInputStream = CodedInputStream.newInstance(serializedBytes);
+      context = context.getNewNonMemoizingContext();
+      Map<Class<? extends FragmentOptions>, Map<String, Object>> differingOptions =
+          context.deserialize(codedInputStream);
+      ImmutableSet<Class<? extends FragmentOptions>> extraFirstFragmentClasses =
+          context.deserialize(codedInputStream);
+      ImmutableList<FragmentOptions> extraSecondFragments = context.deserialize(codedInputStream);
+      byte[] baseFingerprint = null;
+      if (codedInputStream.readBool()) {
+        baseFingerprint = codedInputStream.readRawBytes(codedInputStream.readInt32());
+      }
+      String checksum = context.deserialize(codedInputStream);
+      return new OptionsDiffForReconstruction(
+          differingOptions,
+          extraFirstFragmentClasses,
+          extraSecondFragments,
+          baseFingerprint,
+          checksum);
+    }
+
+    @Override
+    public Class<BuildOptions.OptionsDiffForReconstruction> getEncodedClass() {
+      return BuildOptions.OptionsDiffForReconstruction.class;
     }
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
index 801dfc5..45bde75 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/DeserializationContext.java
@@ -120,4 +120,8 @@
   public DeserializationContext getNewMemoizingContext() {
     return new DeserializationContext(this.registry, this.dependencies, new Deserializer());
   }
+
+  public DeserializationContext getNewNonMemoizingContext() {
+    return new DeserializationContext(this.registry, this.dependencies, null);
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
index 468a362..c16c647 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/SerializationContext.java
@@ -131,6 +131,11 @@
         this.registry, this.dependencies, new Memoizer.Serializer(), allowFuturesToBlockWritingOn);
   }
 
+  public SerializationContext getNewNonMemoizingContext() {
+    return new SerializationContext(
+        this.registry, this.dependencies, null, this.allowFuturesToBlockWritingOn);
+  }
+
   /**
    * Register a {@link ListenableFuture} that must complete successfully before the serialized bytes
    * generated using this context can be written remotely. Failure of the future implies a bug or
diff --git a/src/test/java/com/google/devtools/build/lib/BUILD b/src/test/java/com/google/devtools/build/lib/BUILD
index fb3d52c..1502f25 100644
--- a/src/test/java/com/google/devtools/build/lib/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/BUILD
@@ -704,6 +704,7 @@
         "//src/main/java/com/google/devtools/build/lib:java-compilation",
         "//src/main/java/com/google/devtools/build/lib:java-rules",
         "//src/main/java/com/google/devtools/build/lib:packages",
+        "//src/main/java/com/google/devtools/build/lib:proto-rules",
         "//src/main/java/com/google/devtools/build/lib:python-rules",
         "//src/main/java/com/google/devtools/build/lib:util",
         "//src/main/java/com/google/devtools/build/lib/rules/cpp",
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/config/BuildConfigurationTest.java b/src/test/java/com/google/devtools/build/lib/analysis/config/BuildConfigurationTest.java
index aa36ebf..13d4e7a 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/config/BuildConfigurationTest.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/config/BuildConfigurationTest.java
@@ -32,6 +32,7 @@
 import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
 import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.common.options.Options;
+import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.regex.Pattern;
 import org.junit.Test;
@@ -452,6 +453,9 @@
                 "--define",
                 "#a=pounda"))
         .addDependency(FileSystem.class, getScratch().getFileSystem())
+        .addDependency(
+            IdentityHashMap.class,
+            new IdentityHashMap<BuildOptions.OptionsDiffForReconstruction, byte[]>())
         .setVerificationFunction(BuildConfigurationTest::verifyDeserialized)
         .runTests();
   }
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/config/BuildOptionsTest.java b/src/test/java/com/google/devtools/build/lib/analysis/config/BuildOptionsTest.java
index 421db8e..7783782 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/config/BuildOptionsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/config/BuildOptionsTest.java
@@ -20,7 +20,10 @@
 import com.google.devtools.build.lib.analysis.config.BuildOptions.OptionsDiff;
 import com.google.devtools.build.lib.analysis.config.BuildOptions.OptionsDiffForReconstruction;
 import com.google.devtools.build.lib.rules.cpp.CppOptions;
+import com.google.devtools.build.lib.rules.proto.ProtoConfiguration;
+import com.google.devtools.build.lib.skyframe.serialization.testutils.ObjectCodecTester;
 import com.google.devtools.common.options.OptionsParser;
+import java.util.IdentityHashMap;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -31,7 +34,8 @@
 @RunWith(JUnit4.class)
 public class BuildOptionsTest {
   private static final ImmutableList<Class<? extends FragmentOptions>> TEST_OPTIONS =
-      ImmutableList.<Class<? extends FragmentOptions>>of(BuildConfiguration.Options.class);
+      ImmutableList.<Class<? extends FragmentOptions>>of(
+          BuildConfiguration.Options.class, ProtoConfiguration.Options.class);
 
   @Test
   public void optionSetCaching() {
@@ -165,4 +169,32 @@
     assertThat(otherFragment.applyDiff(BuildOptions.diffForReconstruction(otherFragment, one)))
         .isEqualTo(one);
   }
+
+  @Test
+  public void testCodec() throws Exception {
+    BuildOptions one =
+        BuildOptions.of(
+            TEST_OPTIONS,
+            "--compilation_mode=opt",
+            "cpu=k8",
+            "--proto_compiler=//net/proto2/compiler/public:protocol_compiler",
+            "--proto_toolchain_for_java=//tools/proto/toolchains:java");
+    BuildOptions two =
+        BuildOptions.of(
+            TEST_OPTIONS,
+            "--compilation_mode=dbg",
+            "cpu=k8",
+            "--proto_compiler=@com_google_protobuf//:protoc");
+    ObjectCodecTester.newBuilder(new BuildOptions.OptionsDiffForReconstructionCodec())
+        .addSubjects(BuildOptions.diffForReconstruction(one, two))
+        .addDependency(
+            IdentityHashMap.class,
+            new IdentityHashMap<BuildOptions.OptionsDiffForReconstruction, byte[]>())
+        .skipBadDataTest() // Bad data doesn't make sense with our caching.
+        .verificationFunction(
+            ((original, deserialized) -> {
+              assertThat(original).isEqualTo(deserialized);
+            }))
+        .buildAndRunTests();
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/exec/MiddlemanActionTest.java b/src/test/java/com/google/devtools/build/lib/exec/MiddlemanActionTest.java
index c49d5e5..dade507 100644
--- a/src/test/java/com/google/devtools/build/lib/exec/MiddlemanActionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/exec/MiddlemanActionTest.java
@@ -23,14 +23,11 @@
 import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.actions.MiddlemanAction;
 import com.google.devtools.build.lib.actions.MiddlemanFactory;
-import com.google.devtools.build.lib.actions.OutputBaseSupplier;
 import com.google.devtools.build.lib.analysis.util.AnalysisTestUtil;
 import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
 import com.google.devtools.build.lib.cmdline.RepositoryName;
-import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
 import com.google.devtools.build.lib.testutil.Suite;
 import com.google.devtools.build.lib.testutil.TestSpec;
-import com.google.devtools.build.lib.vfs.FileSystem;
 import java.util.ArrayList;
 import java.util.Arrays;
 import org.junit.Before;
@@ -138,20 +135,4 @@
     assertThat(Sets.newHashSet(middlemanActionForD.getOutputs()))
         .isNotEqualTo(Sets.newHashSet(middlemanActionForC.getOutputs()));
   }
-
-  @Test
-  public void testCodec() throws Exception {
-    new SerializationTester(getGeneratingAction(middle))
-        .addDependency(FileSystem.class, scratch.getFileSystem())
-        .addDependency(OutputBaseSupplier.class, () -> outputBase)
-        .setVerificationFunction(MiddlemanActionTest::verifyEquivalent)
-        .runTests();
-  }
-
-  private static void verifyEquivalent(MiddlemanAction first, MiddlemanAction second) {
-    assertThat(first.getActionType()).isEqualTo(second.getActionType());
-    assertThat(first.getInputs()).isEqualTo(second.getInputs());
-    assertThat(first.getOutputs()).isEqualTo(second.getOutputs());
-    assertThat(first.getOwner()).isEqualTo(second.getOwner());
-  }
 }