Optimise Depsets in Starlark providers

Add unit tests for provider optimisation

On schemaful providers store predicted element classes of a Depset next to the schema and encode the actual Depset with a NestedSet.

The prediction is based on the first non-empty Depset in a provider. The optimisation either works or if it doesn't it causes no harm.

PiperOrigin-RevId: 511732456
Change-Id: I1ce34d9d117645429759d49a7fb461becdb71aa7
diff --git a/src/main/java/com/google/devtools/build/lib/collect/nestedset/Depset.java b/src/main/java/com/google/devtools/build/lib/collect/nestedset/Depset.java
index 54326e9..1d62db5 100644
--- a/src/main/java/com/google/devtools/build/lib/collect/nestedset/Depset.java
+++ b/src/main/java/com/google/devtools/build/lib/collect/nestedset/Depset.java
@@ -334,6 +334,11 @@
     return ElementType.of(elemClass);
   }
 
+  @Nullable
+  public Class<?> getElementClass() {
+    return elemClass;
+  }
+
   @Override
   public String toString() {
     return Starlark.repr(this);
diff --git a/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfoWithSchema.java b/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfoWithSchema.java
index 2df831b..b5a810d 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfoWithSchema.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfoWithSchema.java
@@ -28,14 +28,21 @@
 import net.starlark.java.syntax.Location;
 import net.starlark.java.syntax.TokenKind;
 
-/** A struct-like Info (provider instance) for providers defined in Starlark that have a schema. */
+/**
+ * A struct-like Info (provider instance) for providers defined in Starlark that have a schema.
+ *
+ * <p>Maintainer's note: This class is memory-optimized in a way that can cause profiling
+ * instability in some pathological cases. See {@link StarlarkProvider#optimizeField} for more
+ * information.
+ */
 public class StarlarkInfoWithSchema extends StarlarkInfo {
   private final StarlarkProvider provider;
 
-  // For each field in provider.getFields the table contains on corresponding position either null
-  // or a legal Starlark value
+  // For each field in provider.getFields the table contains on corresponding position either null,
+  // a legal Starlark value, or an optimized value (see StarlarkProvider#optimizeField).
   private final Object[] table;
 
+  // `table` elements should already be optimized by caller, see StarlarkProvider#optimizeField
   private StarlarkInfoWithSchema(
       StarlarkProvider provider, Object[] table, @Nullable Location loc) {
     super(loc);
@@ -68,7 +75,7 @@
               "got multiple values for parameter %s in call to instantiate provider %s",
               table[i], provider.getPrintableName());
         }
-        valueTable[pos] = table[i + 1];
+        valueTable[pos] = provider.optimizeField(pos, table[i + 1]);
       } else {
         if (unexpected == null) {
           unexpected = new ArrayList<>();
@@ -106,7 +113,9 @@
       return false;
     }
     for (int i = 0; i < table.length; i++) {
-      if (table[i] != null && !Starlark.isImmutable(table[i])) {
+      if (table[i] != null
+          && !(provider.isOptimised(i, table[i]) // optimised fields might not be Starlark values
+              || Starlark.isImmutable(table[i]))) {
         return false;
       }
     }
@@ -118,7 +127,7 @@
   public Object getValue(String name) {
     ImmutableList<String> fields = provider.getFields();
     int i = Collections.binarySearch(fields, name);
-    return i >= 0 ? table[i] : null;
+    return i >= 0 ? provider.retrieveOptimizedField(i, table[i]) : null;
   }
 
   @Nullable
diff --git a/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java b/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
index 17e915b..2e5f31e 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
@@ -18,12 +18,16 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.collect.nestedset.Depset;
+import com.google.devtools.build.lib.collect.nestedset.Depset.ElementType;
+import com.google.devtools.build.lib.collect.nestedset.NestedSet;
 import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.util.Fingerprint;
 import com.google.errorprone.annotations.CanIgnoreReturnValue;
 import java.util.Collection;
 import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import javax.annotation.Nullable;
 import net.starlark.java.eval.Dict;
 import net.starlark.java.eval.EvalException;
@@ -70,6 +74,32 @@
   @Nullable private Key key;
 
   /**
+   * For schemaful providers, an array of metadata concerning depset optimization.
+   *
+   * <p>Each index in the array holds an optional (nullable) depset element type. The value at that
+   * index is initialized to be the element type of the first non-empty Depset to ever be stored in
+   * the corresponding field from {@link #schema} on any instance of this provider, globally. If no
+   * depsets (or only empty depsets) are ever stored in a field, the value at its index in this
+   * array will remain null.
+   *
+   * <p>Whenever a field is stored in an instance of this provider type, if the value is a depset
+   * whose element type matches the one stored in this array, it is optimized by unwrapping it down
+   * to its {@code NestedSet}. Upon retrieval, the depset wrapper is reconstructed using this saved
+   * element type.
+   *
+   * <p>The optimization may (harmlessly) fail to apply for provider fields that are not strongly
+   * typed across all instances.
+   *
+   * <p>For large builds, this optimization has been observed to save half a percent in retained
+   * heap.
+   *
+   * <p>In the future, the ad hoc heuristic of examining the first stored non-empty depset might be
+   * replaced by stronger type information in the provider's Starlark declaration. However, this
+   * optimization would remain relevant for provider declarations that do not supply such type info.
+   */
+  @Nullable private transient AtomicReferenceArray<Class<?>> depsetTypePredictor;
+
+  /**
    * Returns a new empty builder.
    *
    * <p>By default (unless {@link Builder#setExported} is called), the builder will build a provider
@@ -157,6 +187,9 @@
     this.schema = schema;
     this.init = init;
     this.key = key;
+    if (schema != null) {
+      depsetTypePredictor = new AtomicReferenceArray<>(schema.size());
+    }
   }
 
   private static Object[] toNamedArgs(Object value, String descriptionForError)
@@ -324,6 +357,67 @@
   }
 
   /**
+   * For schemaful providers, given a value to store in the field identified by {@code index},
+   * returns a possibly optimized version of the value. The result (optimized or not) should be
+   * decoded by {@link #retrieveOptimizedField}.
+   *
+   * <p>Mutable values are never optimized.
+   */
+  Object optimizeField(int index, Object value) {
+    if (value instanceof Depset) {
+      Preconditions.checkArgument(depsetTypePredictor != null);
+      Depset depset = (Depset) value;
+      if (depset.isEmpty()) {
+        // Most empty depsets have the empty (null) type. We can't store this type because it
+        // would clash with whatever the actual element type is for non-empty depsets in that
+        // field. So instead just store the optimized (unwrapped) NestedSet without any type
+        // information, and assume it's the empty type upon retrieval.
+        //
+        // This only loses information in the relatively rare case of a native-constructed empty
+        // depset with a type restriction (e.g. empty set of artifacts). In that scenario, an
+        // empty depset retrieved from the provider may "incorrectly" allow itself to participate
+        // in a union with depsets of other types, whereas the original depset would trigger a
+        // Starlark eval error. This is a user-observable difference but a very minor one; the
+        // hazard would be logical errors that are masked by the provider machinery but triggered
+        // by a refactoring of Starlark code. See TODO in Depset#of(Class, NestedSet) for notes
+        // about eliminating this semantic confusion.
+        //
+        // This problem shouldn't arise for non-empty depsets since distinct non-empty element
+        // types are not compatible with one another (i.e. there's no Depset<Any> schema).
+        return depset.getSet();
+      }
+      Class<?> elementClass = depset.getElementClass();
+      if (depsetTypePredictor.compareAndExchange(index, null, elementClass) == elementClass) {
+        return depset.getSet();
+      }
+    }
+    return value;
+  }
+
+  Object retrieveOptimizedField(int index, Object value) {
+    if (value instanceof NestedSet<?>) {
+      // We subvert Depset.of()'s static type checking for consistency between the type token and
+      // NestedSet type. This is safe because these values came from a previous Depset, so we
+      // already know they're consistent.
+      @SuppressWarnings("unchecked")
+      NestedSet<Object> nestedSet = (NestedSet<Object>) value;
+      if (nestedSet.isEmpty()) {
+        // This matches empty depsets created in Starlark with `depset()`. For natively created
+        // empty depsets it may change elementClass to null.
+        return Depset.of(ElementType.EMPTY, nestedSet);
+      }
+      @SuppressWarnings("unchecked") // can't parametrize Class literal by a non-raw type
+      Depset depset = Depset.of((Class<Object>) depsetTypePredictor.get(index), nestedSet);
+      return depset;
+    }
+    return value;
+  }
+
+  boolean isOptimised(int index, Object value) {
+    return value instanceof NestedSet<?>;
+  }
+
+  /**
    * A serializable representation of Starlark-defined {@link StarlarkProvider} that uniquely
    * identifies all {@link StarlarkProvider}s that are exposed to SkyFrame.
    */
diff --git a/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java b/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
index de704a4..18aac04 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
@@ -15,6 +15,7 @@
 package com.google.devtools.build.lib.packages;
 
 import static com.google.common.truth.Truth.assertThat;
+import static com.google.devtools.build.lib.collect.nestedset.Order.STABLE_ORDER;
 import static org.junit.Assert.assertThrows;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -23,12 +24,15 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.testing.EqualsTester;
 import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.collect.nestedset.Depset;
+import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
 import net.starlark.java.eval.Dict;
 import net.starlark.java.eval.EvalException;
 import net.starlark.java.eval.Mutability;
 import net.starlark.java.eval.Starlark;
 import net.starlark.java.eval.StarlarkCallable;
 import net.starlark.java.eval.StarlarkInt;
+import net.starlark.java.eval.StarlarkList;
 import net.starlark.java.eval.StarlarkSemantics;
 import net.starlark.java.eval.StarlarkThread;
 import net.starlark.java.eval.Tuple;
@@ -239,6 +243,242 @@
     assertThat(provider.getFields()).containsExactly("a", "b", "c").inOrder();
   }
 
+  /**
+   * Tests the safe storage and retrieval of depsets, which may be optimized to nested sets in the
+   * internal representation.
+   */
+  @Test
+  public void schemafulProvider_withDepset() throws Exception {
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN).setSchema(ImmutableList.of("field")).build();
+    StarlarkInfo instance1;
+    StarlarkInfo instance2;
+    StarlarkInfo instance3;
+    StarlarkInfo instance4;
+    StarlarkInfo instance5;
+    StarlarkInfo instance6;
+    try (Mutability mu = Mutability.create()) {
+      StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+      // Instantiates provider with values of different types all in the same field.
+      // Instance with an empty depset of string
+      instance1 =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of(
+                      "field", Depset.of(String.class, NestedSetBuilder.emptySet(STABLE_ORDER))));
+      // Instance with a non-empty depset of string
+      instance2 =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of(
+                      "field",
+                      Depset.of(String.class, NestedSetBuilder.create(STABLE_ORDER, "foo"))));
+      // Instance with a non-empty depset of int
+      instance3 =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of(
+                      "field",
+                      Depset.of(
+                          StarlarkInt.class,
+                          NestedSetBuilder.create(STABLE_ORDER, StarlarkInt.of(1)))));
+      // Instance with a string (not a depset)
+      instance4 =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of("field", "foo"));
+      // Instance with a None
+      instance5 =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of("field", Starlark.NONE));
+      // Instance with the field not set
+      instance6 =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of());
+    }
+
+    assertThat(instance1.getValue("field")).isInstanceOf(Depset.class);
+    assertThat(((Depset) instance1.getValue("field")).isEmpty()).isTrue();
+    assertThat(instance2.getValue("field")).isInstanceOf(Depset.class);
+    assertThat(((Depset) instance2.getValue("field")).getElementClass()).isEqualTo(String.class);
+    assertThat(((Depset) instance2.getValue("field")).toList()).containsExactly("foo");
+    assertThat(instance3.getValue("field")).isInstanceOf(Depset.class);
+    assertThat(((Depset) instance3.getValue("field")).getElementClass())
+        .isEqualTo(StarlarkInt.class);
+    assertThat(((Depset) instance3.getValue("field")).toList()).containsExactly(StarlarkInt.of(1));
+    assertThat(instance4.getValue("field")).isEqualTo("foo");
+    assertThat(instance5.getValue("field")).isEqualTo(Starlark.NONE);
+    assertThat(instance6.getValue("field")).isNull();
+  }
+
+  @Test
+  public void schemafulProvider_mutable() throws Exception {
+    StarlarkProvider.Key key =
+        new StarlarkProvider.Key(Label.parseCanonical("//foo:bar.bzl"), "prov");
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN)
+            .setSchema(ImmutableList.of("a"))
+            .setExported(key)
+            .build();
+    StarlarkInfo instance;
+    try (Mutability mu = Mutability.create()) {
+      StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+      instance =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of("a", StarlarkList.of(mu, "x")));
+      @SuppressWarnings("unchecked")
+      StarlarkList<String> list = (StarlarkList<String>) instance.getValue("a");
+
+      list.addElement("y"); // verifies the fields of the provider instance are mutable
+      assertThat(instance.isImmutable()).isFalse();
+    }
+
+    @SuppressWarnings("unchecked")
+    StarlarkList<String> list = (StarlarkList<String>) instance.getValue("a");
+    assertThat((Iterable<?>) list).containsExactly("x", "y");
+  }
+
+  @Test
+  public void schemafulProvider_immutable() throws Exception {
+    StarlarkProvider.Key key =
+        new StarlarkProvider.Key(Label.parseCanonical("//foo:bar.bzl"), "prov");
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN)
+            .setSchema(ImmutableList.of("a"))
+            .setExported(key)
+            .build();
+    StarlarkInfo instance;
+    try (Mutability mu = Mutability.create()) {
+      StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+      instance =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of("a", StarlarkList.of(mu, "x")));
+    }
+
+    assertThat(instance.isImmutable()).isTrue();
+    @SuppressWarnings("unchecked")
+    StarlarkList<String> list = (StarlarkList<String>) instance.getValue("a");
+    assertThat((Iterable<?>) list).containsExactly("x");
+    // verifies the fields of the frozen provider instance are immutable
+    EvalException e = assertThrows(EvalException.class, () -> list.addElement("y"));
+    assertThat(e).hasMessageThat().contains("trying to mutate a frozen list value");
+  }
+
+  @Test
+  public void schemafulProviderWithDepset_isImmutable() throws Exception {
+    StarlarkProvider.Key key =
+        new StarlarkProvider.Key(Label.parseCanonical("//foo:bar.bzl"), "prov");
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN)
+            .setSchema(ImmutableList.of("a"))
+            .setExported(key)
+            .build();
+    StarlarkInfo instance;
+    try (Mutability mu = Mutability.create()) {
+      StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+      instance =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of(
+                      "a", Depset.of(String.class, NestedSetBuilder.create(STABLE_ORDER, "foo"))));
+
+      assertThat(instance.isImmutable()).isTrue();
+    }
+  }
+
+  @Test
+  public void schemafulProviderWithDepset_becomesImmutable() throws Exception {
+    StarlarkProvider.Key key =
+        new StarlarkProvider.Key(Label.parseCanonical("//foo:bar.bzl"), "prov");
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN)
+            .setSchema(ImmutableList.of("a", "b"))
+            .setExported(key)
+            .build();
+    StarlarkInfo instance;
+    try (Mutability mu = Mutability.create()) {
+      StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+      instance =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of(
+                      "a",
+                      Depset.of(String.class, NestedSetBuilder.create(STABLE_ORDER, "foo")),
+                      "b",
+                      StarlarkList.of(mu, "x")));
+
+      assertThat(instance.isImmutable()).isFalse();
+    }
+
+    assertThat(instance.isImmutable()).isTrue();
+  }
+
+  @Test
+  public void schemafulProvider_optimisedImmutable() throws Exception {
+    StarlarkProvider.Key key =
+        new StarlarkProvider.Key(Label.parseCanonical("//foo:bar.bzl"), "prov");
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN)
+            .setSchema(ImmutableList.of("a"))
+            .setExported(key)
+            .build();
+    StarlarkInfo instance;
+    try (Mutability mu = Mutability.create()) {
+      StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+      instance =
+          (StarlarkInfo)
+              Starlark.call(
+                  thread,
+                  provider,
+                  /* args= */ ImmutableList.of(),
+                  /* kwargs= */ ImmutableMap.of("a", StarlarkList.of(mu, "x")));
+    }
+    instance = instance.unsafeOptimizeMemoryLayout();
+
+    assertThat(instance.isImmutable()).isTrue();
+    @SuppressWarnings("unchecked")
+    StarlarkList<String> list = (StarlarkList<String>) instance.getValue("a");
+    assertThat((Iterable<?>) list).containsExactly("x");
+
+    // verifies the fields of the frozen and optimised provider instance are immutable
+    EvalException e = assertThrows(EvalException.class, () -> list.addElement("y"));
+    assertThat(e).hasMessageThat().contains("trying to mutate a frozen list value");
+  }
+
   @Test
   public void providerEquals() throws Exception {
     // All permutations of differing label and differing name.