Preserve comparator in `ImmutableSortedMap` serialization.

`ImmutableSortedMap` will be deserialized as `ImmutableSortedMap` only if it is ordered by natural comparator, otherwise we will restore it as `ImmutableMap`. Include the comparator in serialization and always restore the map as `ImmutableSortedMap`.

PiperOrigin-RevId: 381577433
diff --git a/src/main/java/com/google/devtools/build/lib/actions/Artifact.java b/src/main/java/com/google/devtools/build/lib/actions/Artifact.java
index f2b3b5c..9d17a22 100644
--- a/src/main/java/com/google/devtools/build/lib/actions/Artifact.java
+++ b/src/main/java/com/google/devtools/build/lib/actions/Artifact.java
@@ -43,6 +43,7 @@
 import com.google.devtools.build.lib.skyframe.serialization.SerializationContext;
 import com.google.devtools.build.lib.skyframe.serialization.SerializationException;
 import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
+import com.google.devtools.build.lib.skyframe.serialization.autocodec.SerializationConstant;
 import com.google.devtools.build.lib.starlarkbuildapi.FileApi;
 import com.google.devtools.build.lib.util.FileType;
 import com.google.devtools.build.lib.util.FileTypeSet;
@@ -130,6 +131,7 @@
   public static final Depset.ElementType TYPE = Depset.ElementType.of(Artifact.class);
 
   /** Compares artifact according to their exec paths. Sorts null values first. */
+  @SerializationConstant
   @SuppressWarnings("ReferenceEquality") // "a == b" is an optimization
   public static final Comparator<Artifact> EXEC_PATH_COMPARATOR =
       (a, b) -> {
diff --git a/src/main/java/com/google/devtools/build/lib/actions/BUILD b/src/main/java/com/google/devtools/build/lib/actions/BUILD
index a753641..7ca98f1 100644
--- a/src/main/java/com/google/devtools/build/lib/actions/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/actions/BUILD
@@ -210,6 +210,7 @@
         "//src/main/java/com/google/devtools/build/lib/skyframe:sky_functions",
         "//src/main/java/com/google/devtools/build/lib/skyframe/serialization",
         "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec",
+        "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec:serialization-constant",
         "//src/main/java/com/google/devtools/build/lib/starlarkbuildapi",
         "//src/main/java/com/google/devtools/build/lib/util",
         "//src/main/java/com/google/devtools/build/lib/util:detailed_exit_code",
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkAction.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkAction.java
index 15b4965..7433410 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkAction.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkAction.java
@@ -14,8 +14,9 @@
 
 package com.google.devtools.build.lib.rules.cpp;
 
+import static com.google.common.base.Preconditions.checkArgument;
+
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.ImmutableSortedMap;
 import com.google.devtools.build.lib.actions.AbstractAction;
 import com.google.devtools.build.lib.actions.ActionExecutionContext;
@@ -56,12 +57,13 @@
    * {@code includePath}.
    */
   public CreateIncSymlinkAction(
-      ActionOwner owner, Map<Artifact, Artifact> symlinks, Path includePath) {
-    super(
-        owner,
-        NestedSetBuilder.wrap(Order.STABLE_ORDER, symlinks.values()),
-        ImmutableSet.copyOf(symlinks.keySet()));
-    this.symlinks = ImmutableSortedMap.copyOf(symlinks, Artifact.EXEC_PATH_COMPARATOR);
+      ActionOwner owner, ImmutableSortedMap<Artifact, Artifact> symlinks, Path includePath) {
+    super(owner, NestedSetBuilder.wrap(Order.STABLE_ORDER, symlinks.values()), symlinks.keySet());
+    checkArgument(
+        symlinks.comparator().equals(Artifact.EXEC_PATH_COMPARATOR),
+        "Symlinks uses an incorrect comparator: %s",
+        symlinks.comparator());
+    this.symlinks = symlinks;
     this.includePath = includePath;
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/BUILD b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
index 47362d8..b28c5e3 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
@@ -28,6 +28,7 @@
     deps = [
         ":codec-scanning-constants",
         "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec:registered-singleton",
+        "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec:serialization-constant",
         "//src/main/java/com/google/devtools/build/lib/unsafe:string",
         "//src/main/java/com/google/devtools/build/lib/unsafe:unsafe-provider",
         "//third_party:flogger",
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodec.java b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodec.java
index 07e337f..d5f4631 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodec.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodec.java
@@ -17,6 +17,7 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSortedMap;
 import com.google.common.collect.Ordering;
+import com.google.devtools.build.lib.skyframe.serialization.autocodec.SerializationConstant;
 import com.google.protobuf.CodedInputStream;
 import com.google.protobuf.CodedOutputStream;
 import java.io.IOException;
@@ -45,6 +46,16 @@
  * ImmutableSortedMap}, arbitrary otherwise, we avoid specifying the key type as a parameter.
  */
 class ImmutableMapCodec<V> implements ObjectCodec<ImmutableMap<?, V>> {
+
+  @SuppressWarnings("unused")
+  @SerializationConstant
+  static final Comparator<?> ORDERING_NATURAL = Ordering.natural();
+
+  // In practice, the natural comparator seems to always be Ordering.natural(), but be flexible.
+  @SuppressWarnings("unused")
+  @SerializationConstant
+  static final Comparator<?> COMPARATOR_NATURAL_ORDER = Comparator.naturalOrder();
+
   @SuppressWarnings("unchecked")
   @Override
   public Class<ImmutableMap<?, V>> getEncodedClass() {
@@ -58,14 +69,9 @@
       SerializationContext context, ImmutableMap<?, V> map, CodedOutputStream codedOut)
       throws SerializationException, IOException {
     codedOut.writeInt32NoTag(map.size());
-    boolean serializeAsSortedMap = false;
-    if (map instanceof ImmutableSortedMap) {
-      Comparator<?> comparator = ((ImmutableSortedMap<?, ?>) map).comparator();
-      // In practice the comparator seems to always be Ordering.natural(), but be flexible.
-      serializeAsSortedMap =
-          comparator.equals(Ordering.natural()) || comparator.equals(Comparator.naturalOrder());
-    }
-    codedOut.writeBoolNoTag(serializeAsSortedMap);
+    Comparator<?> comparator =
+        map instanceof ImmutableSortedMap ? ((ImmutableSortedMap<?, ?>) map).comparator() : null;
+    context.serialize(comparator, codedOut);
     serializeEntries(context, map.entrySet(), codedOut);
   }
 
@@ -96,8 +102,10 @@
       throw new SerializationException("Expected non-negative length: " + length);
     }
     ImmutableMap.Builder<?, V> builder;
-    if (codedIn.readBool()) {
-      builder = deserializeEntries(ImmutableSortedMap.naturalOrder(), length, context, codedIn);
+    Comparator<?> comparator = context.deserialize(codedIn);
+    if (comparator != null) {
+      builder =
+          deserializeEntries(ImmutableSortedMap.orderedBy(comparator), length, context, codedIn);
     } else {
       builder =
           deserializeEntries(
diff --git a/src/test/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkActionTest.java b/src/test/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkActionTest.java
index 6a7cc19..dd8279a 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkActionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/cpp/CreateIncSymlinkActionTest.java
@@ -14,10 +14,12 @@
 
 package com.google.devtools.build.lib.rules.cpp;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.devtools.build.lib.actions.util.ActionsTestUtil.NULL_ACTION_OWNER;
 
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSortedMap;
 import com.google.devtools.build.lib.actions.ActionExecutionContext;
 import com.google.devtools.build.lib.actions.ActionExecutionContext.LostInputsCheck;
 import com.google.devtools.build.lib.actions.ActionKeyContext;
@@ -57,14 +59,14 @@
     Artifact c = ActionsTestUtil.createArtifact(root, "c");
     Artifact d = ActionsTestUtil.createArtifact(root, "d");
     CreateIncSymlinkAction action1 =
-        new CreateIncSymlinkAction(NULL_ACTION_OWNER, ImmutableMap.of(a, b, c, d), includePath);
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(a, b, c, d), includePath);
     // Can't reuse the artifacts here; that would lead to DuplicateArtifactException.
     a = ActionsTestUtil.createArtifact(root, "a");
     b = ActionsTestUtil.createArtifact(root, "b");
     c = ActionsTestUtil.createArtifact(root, "c");
     d = ActionsTestUtil.createArtifact(root, "d");
     CreateIncSymlinkAction action2 =
-        new CreateIncSymlinkAction(NULL_ACTION_OWNER, ImmutableMap.of(c, d, a, b), includePath);
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(c, d, a, b), includePath);
 
     assertThat(computeKey(action2)).isEqualTo(computeKey(action1));
   }
@@ -77,12 +79,12 @@
     Artifact a = ActionsTestUtil.createArtifact(root, "a");
     Artifact b = ActionsTestUtil.createArtifact(root, "b");
     CreateIncSymlinkAction action1 =
-        new CreateIncSymlinkAction(NULL_ACTION_OWNER, ImmutableMap.of(a, b), includePath);
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(a, b), includePath);
     // Can't reuse the artifacts here; that would lead to DuplicateArtifactException.
     a = ActionsTestUtil.createArtifact(root, "a");
     b = ActionsTestUtil.createArtifact(root, "c");
     CreateIncSymlinkAction action2 =
-        new CreateIncSymlinkAction(NULL_ACTION_OWNER, ImmutableMap.of(a, b), includePath);
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(a, b), includePath);
 
     assertThat(computeKey(action2)).isNotEqualTo(computeKey(action1));
   }
@@ -95,12 +97,12 @@
     Artifact a = ActionsTestUtil.createArtifact(root, "a");
     Artifact b = ActionsTestUtil.createArtifact(root, "b");
     CreateIncSymlinkAction action1 =
-        new CreateIncSymlinkAction(NULL_ACTION_OWNER, ImmutableMap.of(a, b), includePath);
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(a, b), includePath);
     // Can't reuse the artifacts here; that would lead to DuplicateArtifactException.
     a = ActionsTestUtil.createArtifact(root, "c");
     b = ActionsTestUtil.createArtifact(root, "b");
     CreateIncSymlinkAction action2 =
-        new CreateIncSymlinkAction(NULL_ACTION_OWNER, ImmutableMap.of(a, b), includePath);
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(a, b), includePath);
 
     assertThat(computeKey(action2)).isNotEqualTo(computeKey(action1));
   }
@@ -114,8 +116,8 @@
     Path symlink = rootDirectory.getRelative("out/a");
     Artifact a = ActionsTestUtil.createArtifact(root, symlink);
     Artifact b = ActionsTestUtil.createArtifact(root, "b");
-    CreateIncSymlinkAction action = new CreateIncSymlinkAction(NULL_ACTION_OWNER,
-        ImmutableMap.of(a, b), outputDir);
+    CreateIncSymlinkAction action =
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(a, b), outputDir);
     action.execute(makeDummyContext());
     symlink.stat(Symlinks.NOFOLLOW);
     assertThat(symlink.isSymbolicLink()).isTrue();
@@ -154,7 +156,7 @@
     Artifact a = ActionsTestUtil.createArtifact(root, symlink);
     Artifact b = ActionsTestUtil.createArtifact(root, "b");
     CreateIncSymlinkAction action =
-        new CreateIncSymlinkAction(NULL_ACTION_OWNER, ImmutableMap.of(a, b), outputDir);
+        new CreateIncSymlinkAction(NULL_ACTION_OWNER, symlinksMap(a, b), outputDir);
     Path extra = rootDirectory.getRelative("out/extra");
     FileSystemUtils.createEmptyFile(extra);
     assertThat(extra.exists()).isTrue();
@@ -171,4 +173,14 @@
     action.computeKey(actionKeyContext, /*artifactExpander=*/ null, fp);
     return fp.hexDigestAndReset();
   }
+
+  private static ImmutableSortedMap<Artifact, Artifact> symlinksMap(Artifact... artifacts) {
+    checkArgument(artifacts.length % 2 == 0, "Odd number of arguments: %s", artifacts.length);
+    ImmutableSortedMap.Builder<Artifact, Artifact> symlinks =
+        ImmutableSortedMap.orderedBy(Artifact.EXEC_PATH_COMPARATOR);
+    for (int i = 0; i < artifacts.length; i += 2) {
+      symlinks.put(artifacts[i], artifacts[i + 1]);
+    }
+    return symlinks.build();
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
index 5fc3970..d4ebcb9 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
@@ -18,6 +18,8 @@
         "//src/main/java/com/google/devtools/build/lib/events",
         "//src/main/java/com/google/devtools/build/lib/skyframe:precomputed_value",
         "//src/main/java/com/google/devtools/build/lib/skyframe/serialization",
+        "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec",
+        "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/autocodec:serialization-constant",
         "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/testutils",
         "//src/main/java/com/google/devtools/build/lib/vfs",
         "//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodecTest.java
index 3bf9b44..57f41a5 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodecTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ImmutableMapCodecTest.java
@@ -21,12 +21,17 @@
 import com.google.common.collect.ImmutableClassToInstanceMap;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSortedMap;
+import com.google.common.collect.Ordering;
+import com.google.devtools.build.lib.skyframe.serialization.SerializationException.NoCodecException;
+import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec.VisibleForSerialization;
+import com.google.devtools.build.lib.skyframe.serialization.autocodec.SerializationConstant;
 import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
 import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester.VerificationFunction;
 import com.google.devtools.build.lib.skyframe.serialization.testutils.TestUtils;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.CodedInputStream;
 import com.google.protobuf.CodedOutputStream;
+import java.util.Comparator;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -34,6 +39,15 @@
 /** Tests for {@link ImmutableMapCodec}. */
 @RunWith(JUnit4.class)
 public class ImmutableMapCodecTest {
+
+  @SuppressWarnings("unused")
+  @SerializationConstant
+  @VisibleForSerialization
+  static final Comparator<?> ORDERING_REVERSE_NATURAL = Ordering.natural().reverse();
+
+  @SerializationConstant @VisibleForSerialization
+  static final Comparator<String> HELLO_FIRST_COMPARATOR = selectedFirstComparator("hello");
+
   @Test
   public void smoke() throws Exception {
     new SerializationTester(
@@ -55,12 +69,34 @@
   }
 
   @Test
-  public void unnaturallySortedMapComesBackUnsortedInCorrectOrder() throws Exception {
-    ImmutableMap<?, ?> deserialized =
-        TestUtils.roundTrip(ImmutableSortedMap.reverseOrder().put("a", "b").put("c", "d").build());
-    assertThat(deserialized).isInstanceOf(ImmutableMap.class);
-    assertThat(deserialized).isNotInstanceOf(ImmutableSortedMap.class);
-    assertThat(deserialized).containsExactly("c", "d", "a", "b").inOrder();
+  public void immutableSortedMapRoundTripsWithTheSameComparator() throws Exception {
+    ImmutableSortedMap<?, ?> deserialized =
+        TestUtils.roundTrip(
+            ImmutableSortedMap.orderedBy(HELLO_FIRST_COMPARATOR)
+                .put("a", "b")
+                .put("hello", "there")
+                .build());
+
+    assertThat(deserialized).containsExactly("hello", "there", "a", "b");
+    assertThat(deserialized.comparator()).isSameInstanceAs(HELLO_FIRST_COMPARATOR);
+  }
+
+  @Test
+  public void immutableSortedMapUnserializableComparatorFails() {
+    Comparator<String> comparator = selectedFirstComparator("c");
+
+    NoCodecException thrown =
+        assertThrows(
+            NoCodecException.class,
+            () ->
+                TestUtils.roundTrip(
+                    ImmutableSortedMap.<String, String>orderedBy(comparator)
+                        .put("a", "b")
+                        .put("c", "d")
+                        .build()));
+    assertThat(thrown)
+        .hasMessageThat()
+        .startsWith("No default codec available for " + comparator.getClass().getCanonicalName());
   }
 
   @Test
@@ -102,6 +138,21 @@
         .contains("Exception while deserializing value for key 'a'");
   }
 
+  private static Comparator<String> selectedFirstComparator(String first) {
+    return (a, b) -> {
+      if (a.equals(b)) {
+        return 0;
+      }
+      if (a.equals(first)) {
+        return -1;
+      }
+      if (b.equals(first)) {
+        return 1;
+      }
+      return a.compareTo(b);
+    };
+  }
+
   private static class Dummy {}
 
   private static class DummyThrowingCodec implements ObjectCodec<Dummy> {