Expose global variables' static types in Module

This lets users of a module (e.g. other modules that load it) rely on the
static types of its globals, since a global's dynamic type may be too
narrow to rely on safely. Globals without a declared type export their
value's dynamic type instead.

Working towards #27370.

PiperOrigin-RevId: 912856638
Change-Id: I94fdd857142283c4b9ba7a47dfd8ba42c0c166e1
diff --git a/src/main/java/net/starlark/java/eval/Eval.java b/src/main/java/net/starlark/java/eval/Eval.java
index 7203c47..e83cd32 100644
--- a/src/main/java/net/starlark/java/eval/Eval.java
+++ b/src/main/java/net/starlark/java/eval/Eval.java
@@ -51,6 +51,7 @@
 import net.starlark.java.syntax.Statement;
 import net.starlark.java.syntax.StringLiteral;
 import net.starlark.java.syntax.TokenKind;
+import net.starlark.java.syntax.TypeTable;
 import net.starlark.java.syntax.Types.CallableType;
 import net.starlark.java.syntax.UnaryOperatorExpression;
 
@@ -373,17 +374,17 @@
   private static void assignIdentifier(StarlarkThread.Frame fr, Identifier id, Object value) {
     Resolver.Binding bind = id.getBinding();
     switch (bind.getScope()) {
-      case LOCAL:
-        fr.locals[bind.getIndex()] = value;
-        break;
-      case CELL:
-        ((StarlarkFunction.Cell) fr.locals[bind.getIndex()]).x = value;
-        break;
-      case GLOBAL:
-        fn(fr).setGlobal(bind.getIndex(), value);
-        break;
-      default:
-        throw new IllegalStateException(bind.getScope().toString());
+      case LOCAL -> fr.locals[bind.getIndex()] = value;
+      case CELL -> ((StarlarkFunction.Cell) fr.locals[bind.getIndex()]).x = value;
+      case GLOBAL -> {
+        StarlarkFunction fn = fn(fr);
+        fn.setGlobal(bind.getIndex(), value);
+        @Nullable TypeTable typeTable = fn.getTypeTable(); // presumably null unless type-checked
+        if (typeTable != null) {
+          fn.setGlobalDeclaredType(bind.getIndex(), typeTable.getGlobalDeclaredType(bind));
+        }
+      }
+      default -> throw new IllegalStateException(bind.getScope().toString());
     }
   }
 
diff --git a/src/main/java/net/starlark/java/eval/Module.java b/src/main/java/net/starlark/java/eval/Module.java
index df0c828..55c8455 100644
--- a/src/main/java/net/starlark/java/eval/Module.java
+++ b/src/main/java/net/starlark/java/eval/Module.java
@@ -14,6 +14,9 @@
 
 package net.starlark.java.eval;
 
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
@@ -60,6 +63,10 @@
   // The module's global variables, in order of creation.
   private final LinkedHashMap<String, Integer> globalIndex = new LinkedHashMap<>();
   private Object[] globals = new Object[8];
+  // Exported types of the module's globals, intended for use by other modules that load this one.
+  // Null until an exported type is first recorded, which normally happens only when type
+  // checking is enabled. Once allocated, it has the same length and index order as globals.
+  @Nullable private StarlarkType[] globalsTypes;
 
   // An optional piece of application-specific metadata associated with the module/file.
   // Its toString appears to Starlark in str(function): "<function f from ...>".
@@ -309,6 +316,22 @@
   }
 
   /**
+   * Returns the exported Starlark type of the specified global variable; intended for use by other
+   * modules that load this module (not by the evaluation of this module itself).
+   *
+   * <p>If type checking was enabled for this module, returns the variable's declared static type if
+   * there is one; or the variable's value's dynamic type otherwise.
+   *
+   * <p>Returns null if type checking was not enabled for this module, if the global variable does
+   * not exist, or if no type has been exported for it (e.g., it was never assigned a value).
+   */
+  @Nullable
+  public StarlarkType getGlobalType(String name) {
+    Integer i = globalIndex.get(name);
+    return i != null ? getGlobalTypeByIndex(i) : null;
+  }
+
+  /**
    * Sets the value of a global variable based on its index in this module ({@see
    * getIndexOfGlobal}).
    */
@@ -318,8 +341,8 @@
   }
 
   /**
-   * Returns the value of a global variable based on its index in this module ({@see
-   * getIndexOfGlobal}.) Returns null if the variable has not been assigned a value.
+   * Returns the value of a global variable based on its index in this module (see {@link
+   * #getIndexOfGlobal}.) Returns null if the variable has not been assigned a value.
    */
   @Nullable
   Object getGlobalByIndex(int i) {
@@ -328,6 +351,29 @@
   }
 
   /**
+   * Returns the exported type of a global variable based on its index in this module (see {@link
+   * #getIndexOfGlobal}.) Returns null if the variable has not been assigned an exported type (in
+   * particular, if type checking is not enabled).
+   */
+  @Nullable
+  StarlarkType getGlobalTypeByIndex(int i) {
+    Preconditions.checkArgument(i < globalIndex.size());
+    return globalsTypes != null ? globalsTypes[i] : null;
+  }
+
+  /**
+   * Sets (or, if {@code type} is null, clears) the exported type of a global variable based on its
+   * index in this module (see {@link #getIndexOfGlobal}.)
+   */
+  void setGlobalTypeByIndex(int i, @Nullable StarlarkType type) {
+    Preconditions.checkArgument(i < globalIndex.size());
+    if (globalsTypes == null) {
+      globalsTypes = new StarlarkType[globals.length];
+    }
+    globalsTypes[i] = type;
+  }
+
+  /**
    * Returns the index within this Module of a global variable, given its name, creating a new slot
    * for it if needed. The numbering of globals used by these functions is not the same as the
    * numbering within any compiled Program. Thus each StarlarkFunction must contain a secondary
@@ -340,7 +386,12 @@
       return prev;
     }
     if (i == globals.length) {
-      globals = Arrays.copyOf(globals, globals.length << 1); // grow by doubling
+      // grow by doubling
+      checkState(globalsTypes == null || globals.length == globalsTypes.length);
+      globals = Arrays.copyOf(globals, globals.length << 1);
+      if (globalsTypes != null) {
+        globalsTypes = Arrays.copyOf(globalsTypes, globalsTypes.length << 1);
+      }
     }
     return i;
   }
@@ -360,10 +411,31 @@
     return array;
   }
 
-  /** Updates a global binding in the module environment. */
-  public void setGlobal(String name, Object value) {
+  /**
+   * Updates a global binding and (optionally) its declared type in the module environment.
+   *
+   * <p>Intended only for use by tests.
+   *
+   * @param declaredType if non-null, the declared type to set for the global; ignored if null.
+   */
+  @VisibleForTesting
+  public void setGlobal(String name, Object value, @Nullable StarlarkType declaredType) {
     Preconditions.checkNotNull(value, "Module.setGlobal(%s, null)", name);
-    setGlobalByIndex(getIndexOfGlobal(name), value);
+    int index = getIndexOfGlobal(name);
+    setGlobalByIndex(index, value);
+    if (declaredType != null) {
+      setGlobalTypeByIndex(index, declaredType);
+    }
+  }
+
+  /**
+   * Updates a global binding in the module environment, without altering its exported type.
+   *
+   * <p>Intended only for use by tests.
+   */
+  @VisibleForTesting
+  public void setGlobal(String name, Object value) {
+    setGlobal(name, value, null);
   }
 
   @Override
diff --git a/src/main/java/net/starlark/java/eval/Starlark.java b/src/main/java/net/starlark/java/eval/Starlark.java
index 708013a..910ffb8 100644
--- a/src/main/java/net/starlark/java/eval/Starlark.java
+++ b/src/main/java/net/starlark/java/eval/Starlark.java
@@ -1242,8 +1242,8 @@
    * in which case its value is returned.
    *
    * <p>This method does not perform type tagging or static type checking. If type tagging or type
-   * checking is needed, first use {@link #typeTagAndStaticTypeCheck} to obtain a
-   * type-tagged/checked version of {@code prog}.
+   * checking is needed, first use {@link #withTypeInfo} to obtain a type-tagged/checked version of
+   * {@code prog}.
    *
    * @throws EvalException if there was a (dynamic) evaluation error.
    * @throws InterruptedException if the Java thread was interrupted during evaluation.
@@ -1280,7 +1280,30 @@
             /* defaultValues= */ Tuple.empty(),
             /* freevars= */ Tuple.empty(),
             thread.getNextIdentityToken());
-    return Starlark.positionalOnlyCall(thread, toplevel);
+    Object result = Starlark.positionalOnlyCall(thread, toplevel);
+    if (prog.getTypeTable() != null) {
+      // For globals that don't have a declared static type, we export the value's dynamic type.
+      // We export the dynamic type of the value (rather than the inferred static type) because it's
+      // likely to be more useful to users who load() this module; they would want to type-check
+      // on the real set of fields of a Bazel struct or provider, or the real named args to a rule
+      // or macro. A module can annotate a global with a wider type to avoid exposing the dynamic
+      // type as part of its API.
+      //
+      // Exporting the dynamic type does result in one wart: the exported type might not be a
+      // subtype of the inferred static type, due to the invariance rule for mutable collections.
+      // For example, we might statically infer global X to be list[int|float] and export its
+      // value's dynamic type as list[int] - but list[int] is not a subtype of list[int|float].
+      // Since the exported values are frozen, it may be possible to fix this wart by introducing
+      // frozenlist, frozendict, etc.
+      // TODO: #27370 - Ensure this mechanism works for REPL.
+      for (int i : globalIndex) {
+        Object value = module.getGlobalByIndex(i);
+        if (value != null && module.getGlobalTypeByIndex(i) == null) {
+          module.setGlobalTypeByIndex(i, Starlark.getStarlarkType(value));
+        }
+      }
+    }
+    return result;
   }
 
   /**
diff --git a/src/main/java/net/starlark/java/eval/StarlarkFunction.java b/src/main/java/net/starlark/java/eval/StarlarkFunction.java
index 1b25422..8d581a3 100644
--- a/src/main/java/net/starlark/java/eval/StarlarkFunction.java
+++ b/src/main/java/net/starlark/java/eval/StarlarkFunction.java
@@ -83,6 +83,12 @@
     module.setGlobalByIndex(globalIndex[progIndex], value);
   }
 
+  // Sets the exported type of a global variable, given its index in this function's compiled
+  // Program. A null type (no declared type) clears any type previously recorded for the global.
+  void setGlobalDeclaredType(int progIndex, @Nullable StarlarkType type) {
+    module.setGlobalTypeByIndex(globalIndex[progIndex], type);
+  }
+
   // Gets the value of a global variable, given its index in this function's compiled Program.
   @Nullable
   Object getGlobal(int progIndex) {
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/BUILD b/src/test/java/com/google/devtools/build/lib/skyframe/BUILD
index c646cd1..c3455dc 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/BUILD
@@ -2109,6 +2109,7 @@
         "//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
         "//src/main/java/com/google/devtools/common/options",
         "//src/main/java/net/starlark/java/eval",
+        "//src/main/java/net/starlark/java/syntax",
         "//src/test/java/com/google/devtools/build/lib/analysis/util",
         "//src/test/java/com/google/devtools/build/lib/bazel/bzlmod:util",
         "//src/test/java/com/google/devtools/build/lib/skyframe/util:SkyframeExecutorTestUtils",
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/BzlLoadFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/BzlLoadFunctionTest.java
index 3e209ab..9d67865 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/BzlLoadFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/BzlLoadFunctionTest.java
@@ -52,6 +52,7 @@
 import java.util.UUID;
 import javax.annotation.Nullable;
 import net.starlark.java.eval.StarlarkInt;
+import net.starlark.java.syntax.Types;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -1202,13 +1203,15 @@
     setBuildLanguageOptions(
         "--experimental_starlark_type_syntax", "--experimental_starlark_static_type_checking");
     scratch.file("a/BUILD");
-    scratch.file("a/foo.bzl", "x: list[int] = [1, 2, 3]");
+    scratch.file("a/foo.bzl", "x: list[int]|list[str] = [1, 2, 3]");
     SkyKey key = key("//a:foo.bzl");
 
     EvaluationResult<BzlLoadValue> result =
         SkyframeExecutorTestUtils.evaluate(
             getSkyframeExecutor(), key, /* keepGoing= */ false, reporter);
     assertThatEvaluationResult(result).hasNoError();
+    assertThat(result.get(key).getModule().getGlobalType("x"))
+        .isEqualTo(Types.union(Types.list(Types.INT), Types.list(Types.STR)));
   }
 
   @Test
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
index 68c064f..80410c0 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/BUILD
@@ -317,6 +317,7 @@
         "//src/main/java/com/google/devtools/build/lib/skyframe/serialization/testutils:round-tripping",
         "//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
         "//src/main/java/net/starlark/java/eval",
+        "//src/main/java/net/starlark/java/syntax",
         "//src/test/java/com/google/devtools/build/lib/analysis/util",
         "//third_party:guava",
         "//third_party:jsr305",
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ModuleCodecTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ModuleCodecTest.java
index 188c47d..e655ccd 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ModuleCodecTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/serialization/ModuleCodecTest.java
@@ -34,6 +34,7 @@
 import javax.annotation.Nullable;
 import net.starlark.java.eval.Module;
 import net.starlark.java.eval.StarlarkSemantics;
+import net.starlark.java.syntax.Types;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -51,7 +52,13 @@
     subject2.setGlobal("x", 1);
     subject2.setGlobal("y", 2);
 
-    new SerializationTester(subject1, subject2)
+    Module subject3 =
+        Module.withPredeclaredAndData(
+            StarlarkSemantics.DEFAULT, ImmutableMap.of(), Label.parseCanonical("//foo:bar"));
+    subject3.setGlobal("x", 1, Types.INT);
+    subject3.setGlobal("y", 2, Types.ANY);
+
+    new SerializationTester(subject1, subject2, subject3)
         .makeMemoizing()
         .setVerificationFunction(ModuleCodecTest::verifyDeserialization)
         .runTestsWithoutStableSerializationCheck();