starlark: implement "flat frames" optimization

The resolver now statically maps each local variable to a small
integer, and the set of local bindings (including from
comprehensions) is accumulated into each Resolver.Function.

The evaluator uses an array instead of a LinkedHashMap to
represent the environment, and no longer needs a hack to
save/restore the values of comprehensions.

Also, added a test of the resolver.

PiperOrigin-RevId: 343529575
diff --git a/src/main/java/net/starlark/java/eval/Eval.java b/src/main/java/net/starlark/java/eval/Eval.java
index e5fcc83..d650715 100644
--- a/src/main/java/net/starlark/java/eval/Eval.java
+++ b/src/main/java/net/starlark/java/eval/Eval.java
@@ -322,16 +322,17 @@
   private static void assignIdentifier(StarlarkThread.Frame fr, Identifier id, Object value)
       throws EvalException {
     Resolver.Binding bind = id.getBinding();
-    String name = id.getName();
     switch (bind.getScope()) {
       case LOCAL:
-        fr.locals.put(name, value);
+        fr.locals[bind.getIndex()] = value;
         break;
       case GLOBAL:
         // Updates a module binding and sets its 'exported' flag.
         // (Only load bindings are not exported.
         // But exportedGlobals does at run time what should be done in the resolver.)
+        // TODO(adonovan): use a flat array for Module.globals too.
         Module module = fn(fr).getModule();
+        String name = id.getName();
         module.setGlobal(name, value);
         module.exportedGlobals.add(name);
         break;
@@ -631,52 +632,26 @@
 
   private static Object evalIdentifier(StarlarkThread.Frame fr, Identifier id)
       throws EvalException, InterruptedException {
-    String name = id.getName();
     Resolver.Binding bind = id.getBinding();
-    if (bind == null) {
-      // Legacy behavior, to be removed.
-      Object result = fr.locals.get(name);
-      if (result != null) {
-        return result;
-      }
-      result = fn(fr).getModule().get(name);
-      if (result != null) {
-        return result;
-      }
-
-      // Assuming resolution was successfully applied before execution
-      // (which is not yet true for copybara, but will be soon),
-      // then the identifier must have been resolved but the
-      // resolution was not annotated onto the syntax tree---because
-      // it's a BUILD file that may share trees with the prelude.
-      // So this error does not mean "undefined variable" (morally a
-      // static error), but "variable was (dynamically) referenced
-      // before being bound", as in 'print(x); x=1'.
-      fr.setErrorLocation(id.getStartLocation());
-      throw Starlark.errorf("variable '%s' is referenced before assignment", name);
-    }
-
     Object result;
     switch (bind.getScope()) {
       case LOCAL:
-        result = fr.locals.get(name);
+        result = fr.locals[bind.getIndex()];
         break;
       case GLOBAL:
-        result = fn(fr).getModule().getGlobal(name);
+        result = fn(fr).getModule().getGlobal(id.getName());
         break;
       case PREDECLARED:
         // TODO(adonovan): call getPredeclared
-        result = fn(fr).getModule().get(name);
+        result = fn(fr).getModule().get(id.getName());
         break;
       default:
         throw new IllegalStateException(bind.toString());
     }
     if (result == null) {
-      // Since Scope was set, we know that the local/global variable is defined,
-      // but its assignment was not yet executed.
       fr.setErrorLocation(id.getStartLocation());
       throw Starlark.errorf(
-          "%s variable '%s' is referenced before assignment.", bind.getScope(), name);
+          "%s variable '%s' is referenced before assignment.", bind.getScope(), id.getName());
     }
     return result;
   }
@@ -734,22 +709,6 @@
     final StarlarkList<Object> list =
         comp.isDict() ? null : StarlarkList.of(fr.thread.mutability());
 
-    // Save previous value (if any) of local variables bound in a 'for' clause
-    // so we can restore them later.
-    // TODO(adonovan): throw all this away when we implement flat environments.
-    List<Object> saved = new ArrayList<>(); // alternating keys and values
-    for (Comprehension.Clause clause : comp.getClauses()) {
-      if (clause instanceof Comprehension.For) {
-        for (Identifier ident :
-            Identifier.boundIdentifiers(((Comprehension.For) clause).getVars())) {
-          String name = ident.getName();
-          Object value = fr.locals.get(ident.getName()); // may be null
-          saved.add(name);
-          saved.add(value);
-        }
-      }
-    }
-
     // The Lambda class serves as a recursive lambda closure.
     class Lambda {
       // execClauses(index) recursively executes the clauses starting at index,
@@ -806,18 +765,6 @@
     }
     new Lambda().execClauses(0);
 
-    // Restore outer scope variables.
-    // This loop implicitly undefines comprehension variables.
-    for (int i = 0; i != saved.size(); ) {
-      String name = (String) saved.get(i++);
-      Object value = saved.get(i++);
-      if (value != null) {
-        fr.locals.put(name, value);
-      } else {
-        fr.locals.remove(name);
-      }
-    }
-
     return comp.isDict() ? dict : list;
   }
 
diff --git a/src/main/java/net/starlark/java/eval/StarlarkFunction.java b/src/main/java/net/starlark/java/eval/StarlarkFunction.java
index c67f8e5..afd5077 100644
--- a/src/main/java/net/starlark/java/eval/StarlarkFunction.java
+++ b/src/main/java/net/starlark/java/eval/StarlarkFunction.java
@@ -35,7 +35,7 @@
     doc = "The type of functions declared in Starlark.")
 public final class StarlarkFunction implements StarlarkCallable {
 
-  private final Resolver.Function rfn;
+  final Resolver.Function rfn;
   private final Module module; // a function closes over its defining module
   private final Tuple defaultValues;
 
@@ -152,11 +152,8 @@
     Object[] arguments = processArgs(thread.mutability(), positional, named);
 
     StarlarkThread.Frame fr = thread.frame(0);
-    ImmutableList<String> names = rfn.getParameterNames();
-    for (int i = 0; i < names.size(); ++i) {
-      fr.locals.put(names.get(i), arguments[i]);
-    }
-
+    fr.locals = new Object[rfn.getLocals().size()];
+    System.arraycopy(arguments, 0, fr.locals, 0, rfn.getParameterNames().size());
     return Eval.execFunctionBody(fr, rfn.getBody());
   }
 
diff --git a/src/main/java/net/starlark/java/eval/StarlarkThread.java b/src/main/java/net/starlark/java/eval/StarlarkThread.java
index e859ec5..5019d00 100644
--- a/src/main/java/net/starlark/java/eval/StarlarkThread.java
+++ b/src/main/java/net/starlark/java/eval/StarlarkThread.java
@@ -17,7 +17,6 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Maps;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Map;
@@ -136,8 +135,9 @@
     // location (loc) should not be overrwritten.
     private boolean errorLocationSet;
 
-    // The locals of this frame, if fn is a StarlarkFunction, otherwise empty.
-    Map<String, Object> locals;
+    // The locals of this frame, if fn is a StarlarkFunction, otherwise null.
+    // Set by StarlarkFunction.fastcall.
+    @Nullable Object[] locals;
 
     @Nullable private Object profileSpan; // current span of walltime call profiler
 
@@ -177,7 +177,16 @@
 
     @Override
     public ImmutableMap<String, Object> getLocals() {
-      return ImmutableMap.copyOf(this.locals);
+      // TODO(adonovan): provide a more efficient API.
+      ImmutableMap.Builder<String, Object> env = ImmutableMap.builder();
+      if (fn instanceof StarlarkFunction) {
+        for (int i = 0; i < locals.length; i++) {
+          if (locals[i] != null) {
+            env.put(((StarlarkFunction) fn).rfn.getLocals().get(i).getName(), locals[i]);
+          }
+        }
+      }
+      return env.build();
     }
 
     @Override
@@ -214,14 +223,6 @@
       Debug.threadHook.onPushFirst(this);
     }
 
-    if (fn instanceof StarlarkFunction) {
-      StarlarkFunction sfn = (StarlarkFunction) fn;
-      fr.locals = Maps.newLinkedHashMapWithExpectedSize(sfn.getParameterNames().size());
-    } else {
-      // built-in function
-      fr.locals = ImmutableMap.of();
-    }
-
     fr.loc = fn.getLocation();
 
     // Start wall-time call profile span.
diff --git a/src/main/java/net/starlark/java/syntax/Resolver.java b/src/main/java/net/starlark/java/syntax/Resolver.java
index 6aa7d63..fe3c634 100644
--- a/src/main/java/net/starlark/java/syntax/Resolver.java
+++ b/src/main/java/net/starlark/java/syntax/Resolver.java
@@ -44,19 +44,30 @@
  */
 public final class Resolver extends NodeVisitor {
 
-  // TODO(adonovan): use "keyword" (not "named") and "required" (not "mandatory") terminology
-  // everywhere, including the spec.
+  // TODO(adonovan):
+  // - use "keyword" (not "named") and "required" (not "mandatory") terminology everywhere,
+  //   including the spec.
+  // - move the "no if statements at top level" check to bazel's check{Build,*}Syntax
+  //   (that's a spec change), or put it behind a FileOptions flag (no spec change).
+  // - remove restriction on nested def:
+  //   1. use FREE for scope of references to outer LOCALs, which become CELLs.
+  //   2. implement closures in eval/.
+  // - make loads bind locals by default.
 
   /** Scope discriminates the scope of a binding: global, local, etc. */
   public enum Scope {
-    // TODO(adonovan): Add UNIVERSAL, FREE, CELL.
-    // (PREDECLARED vs UNIVERSAL allows us to represent the app-dependent and fixed parts of the
-    // predeclared environment separately, reducing the amount of copying.)
+    // TODO(adonovan): Add UNIVERSAL. Separating PREDECLARED from UNIVERSAL allows
+    // us to represent the app-dependent and fixed parts of the predeclared
+    // environment separately, reducing the amount of copying.
 
     /** Binding is local to a function, comprehension, or file (e.g. load). */
     LOCAL,
     /** Binding occurs outside any function or comprehension. */
     GLOBAL,
+    /** Binding is local to a function, comprehension, or file, plus its nested functions. */
+    CELL, // TODO(adonovan): implement nested def
+    /** Binding is a CELL of some enclosing function. */
+    FREE, // TODO(adonovan): implement nested def
     /** Binding is predeclared by the core or application. */
     PREDECLARED;
 
@@ -74,7 +85,7 @@
 
     private final Scope scope;
     @Nullable private final Identifier first; // first binding use, if syntactic
-    private final int index; // within its block (currently unused)
+    final int index; // index within function (LOCAL) or module (GLOBAL)
 
     private Binding(Scope scope, @Nullable Identifier first, int index) {
       this.scope = scope;
@@ -82,11 +93,22 @@
       this.index = index;
     }
 
+    /** Returns the name of this binding's identifier. */
+    @Nullable
+    public String getName() {
+      return first != null ? first.getName() : null;
+    }
+
     /** Returns the scope of the binding. */
     public Scope getScope() {
       return scope;
     }
 
+    /** Returns the index of a binding within its function (LOCAL) or module (GLOBAL). */
+    public int getIndex() {
+      return index;
+    }
+
     @Override
     public String toString() {
       return first == null
@@ -108,6 +130,7 @@
     private final int numKeywordOnlyParams;
     private final ImmutableList<String> parameterNames;
     private final boolean isToplevel;
+    private final ImmutableList<Binding> locals;
 
     private Function(
         String name,
@@ -116,7 +139,8 @@
         ImmutableList<Statement> body,
         boolean hasVarargs,
         boolean hasKwargs,
-        int numKeywordOnlyParams) {
+        int numKeywordOnlyParams,
+        List<Binding> locals) {
       this.name = name;
       this.location = loc;
       this.params = params;
@@ -132,6 +156,7 @@
       this.parameterNames = names.build();
 
       this.isToplevel = name.equals("<toplevel>");
+      this.locals = ImmutableList.copyOf(locals);
     }
 
     /**
@@ -143,6 +168,11 @@
       return name;
     }
 
+    /** Returns the function's local bindings, parameters first. */
+    public ImmutableList<Binding> getLocals() {
+      return locals;
+    }
+
     /** Returns the location of the function's identifier. */
     public Location getLocation() {
       return location;
@@ -230,12 +260,14 @@
 
   private static class Block {
     private final Map<String, Binding> bindings = new HashMap<>();
+    private final ArrayList<Binding> locals; // of enclosing function
     private final Scope scope;
     @Nullable private final Block parent;
 
-    Block(Scope scope, @Nullable Block parent) {
+    Block(Scope scope, @Nullable Block parent, ArrayList<Binding> locals) {
       this.scope = scope;
       this.parent = parent;
+      this.locals = locals;
     }
   }
 
@@ -253,7 +285,7 @@
     this.module = module;
     this.options = options;
 
-    this.block = new Block(Scope.PREDECLARED, null);
+    this.block = new Block(Scope.PREDECLARED, /*parent=*/ null, /*locals=*/ null);
     for (String name : module.getNames()) {
       block.bindings.put(name, PREDECLARED);
     }
@@ -277,7 +309,7 @@
    * in order).
    */
   // TODO(adonovan): eliminate this first pass by using go.starlark.net one-pass approach.
-  private void createBindings(Iterable<Statement> stmts) {
+  private void createBindingsForBlock(Iterable<Statement> stmts) {
     for (Statement stmt : stmts) {
       createBindings(stmt);
     }
@@ -286,19 +318,19 @@
   private void createBindings(Statement stmt) {
     switch (stmt.kind()) {
       case ASSIGNMENT:
-        createBindings(((AssignmentStatement) stmt).getLHS());
+        createBindingsForLHS(((AssignmentStatement) stmt).getLHS());
         break;
       case IF:
         IfStatement ifStmt = (IfStatement) stmt;
-        createBindings(ifStmt.getThenBlock());
+        createBindingsForBlock(ifStmt.getThenBlock());
         if (ifStmt.getElseBlock() != null) {
-          createBindings(ifStmt.getElseBlock());
+          createBindingsForBlock(ifStmt.getElseBlock());
         }
         break;
       case FOR:
         ForStatement forStmt = (ForStatement) stmt;
-        createBindings(forStmt.getVars());
-        createBindings(forStmt.getBody());
+        createBindingsForLHS(forStmt.getVars());
+        createBindingsForBlock(forStmt.getBody());
         break;
       case DEF:
         DefStatement def = (DefStatement) stmt;
@@ -341,7 +373,7 @@
     }
   }
 
-  private void createBindings(Expression lhs) {
+  private void createBindingsForLHS(Expression lhs) {
     for (Identifier id : Identifier.boundIdentifiers(lhs)) {
       bind(id);
     }
@@ -502,11 +534,13 @@
     Comprehension.For for0 = (Comprehension.For) clauses.get(0);
     visit(for0.getIterable());
 
-    openBlock(Scope.LOCAL);
+    // A comprehension defines a distinct lexical block in the same function.
+    openBlock(Scope.LOCAL, this.block.locals);
+
     for (Comprehension.Clause clause : clauses) {
       if (clause instanceof Comprehension.For) {
         Comprehension.For forClause = (Comprehension.For) clause;
-        createBindings(forClause.getVars());
+        createBindingsForLHS(forClause.getVars());
       }
     }
     for (int i = 0; i < clauses.size(); i++) {
@@ -553,7 +587,8 @@
     }
 
     // Enter function block.
-    openBlock(Scope.LOCAL);
+    ArrayList<Binding> locals = new ArrayList<>();
+    openBlock(Scope.LOCAL, locals);
 
     // Check parameter order and convert to run-time order:
     // positionals, keyword-only, *args, **kwargs.
@@ -627,7 +662,7 @@
       bindParam(params, starStar);
     }
 
-    createBindings(body);
+    createBindingsForBlock(body);
     visitAll(body);
     closeBlock();
 
@@ -638,7 +673,8 @@
         body,
         star != null && star.getIdentifier() != null,
         starStar != null,
-        numKeywordOnlyParams);
+        numKeywordOnlyParams,
+        locals);
   }
 
   private void bindParam(ImmutableList.Builder<Parameter> params, Parameter param) {
@@ -700,8 +736,13 @@
     }
 
     // new binding
-    // TODO(adonovan): accumulate locals in the enclosing function/file block.
-    bind = new Binding(block.scope, id, block.bindings.size());
+    if (block.scope == Scope.LOCAL) {
+      // Accumulate local bindings in the enclosing function.
+      bind = new Binding(block.scope, id, block.locals.size());
+      block.locals.add(bind);
+    } else { // GLOBAL
+      bind = new Binding(block.scope, id, block.bindings.size());
+    }
     block.bindings.put(id.getName(), bind);
     id.setBinding(bind);
     return false;
@@ -741,35 +782,30 @@
     }
   }
 
-  private void resolveToplevelStatements(List<Statement> statements) {
-    // Check that load() statements are on top.
-    if (options.requireLoadStatementsFirst()) {
-      checkLoadAfterStatement(statements);
-    }
-
-    openBlock(Scope.GLOBAL);
-
-    // Add a binding for each variable defined by statements, not including definitions that appear
-    // in sub-scopes of the given statements (function bodies and comprehensions).
-    createBindings(statements);
-
-    // Second pass: ensure that all symbols have been defined.
-    visitAll(statements);
-    closeBlock();
-  }
-
   /**
    * Performs static checks, including resolution of identifiers in {@code file} in the environment
    * defined by {@code module}. The StarlarkFile is mutated. Errors are appended to {@link
    * StarlarkFile#errors}.
    */
   public static void resolveFile(StarlarkFile file, Module module) {
+    Resolver r = new Resolver(file.errors, module, file.getOptions());
+
+    ArrayList<Binding> locals = new ArrayList<>();
+    r.openBlock(Scope.GLOBAL, locals);
+
     ImmutableList<Statement> stmts = file.getStatements();
 
-    Resolver r = new Resolver(file.errors, module, file.getOptions());
-    r.resolveToplevelStatements(stmts);
-    // Check that no closeBlock was forgotten.
-    Preconditions.checkState(r.block.parent == null);
+    // Check that load() statements are on top.
+    if (r.options.requireLoadStatementsFirst()) {
+      r.checkLoadAfterStatement(stmts);
+    }
+
+    // First pass: creating bindings for statements in this block.
+    r.createBindingsForBlock(stmts);
+
+    // Second pass: visit all references.
+    r.visitAll(stmts);
+    r.closeBlock();
 
     // If the final statement is an expression, synthesize a return statement.
     int n = stmts.size();
@@ -791,7 +827,8 @@
             /*body=*/ stmts,
             /*hasVarargs=*/ false,
             /*hasKwargs=*/ false,
-            /*numKeywordOnlyParams=*/ 0));
+            /*numKeywordOnlyParams=*/ 0,
+            locals));
   }
 
   /**
@@ -804,7 +841,10 @@
     List<SyntaxError> errors = new ArrayList<>();
     Resolver r = new Resolver(errors, module, options);
 
+    ArrayList<Binding> locals = new ArrayList<>();
+    r.openBlock(Scope.LOCAL, locals); // for bindings in list comprehensions
     r.visit(expr);
+    r.closeBlock();
 
     if (!errors.isEmpty()) {
       throw new SyntaxError.Exception(errors);
@@ -818,12 +858,13 @@
         ImmutableList.of(ReturnStatement.make(expr)),
         /*hasVarargs=*/ false,
         /*hasKwargs=*/ false,
-        /*numKeywordOnlyParams=*/ 0);
+        /*numKeywordOnlyParams=*/ 0,
+        locals);
   }
 
-  /** Open a new lexical block that will contain the future declarations. */
-  private void openBlock(Scope scope) {
-    block = new Block(scope, block);
+  /** Open a new lexical block for future bindings. Local bindings will be added to locals. */
+  private void openBlock(Scope scope, ArrayList<Binding> locals) {
+    block = new Block(scope, block, locals);
   }
 
   /** Close a lexical block (and lose all declarations it contained). */
diff --git a/src/test/java/net/starlark/java/eval/BUILD b/src/test/java/net/starlark/java/eval/BUILD
index 62579c9..4d0972d 100644
--- a/src/test/java/net/starlark/java/eval/BUILD
+++ b/src/test/java/net/starlark/java/eval/BUILD
@@ -30,6 +30,7 @@
         "StarlarkThreadDebuggingTest.java",
         "StarlarkThreadTest.java",
     ],
+    jvm_flags = ["-Dfile.encoding=UTF8"],
     deps = [
         "//src/main/java/net/starlark/java/annot",
         "//src/main/java/net/starlark/java/eval",
diff --git a/src/test/java/net/starlark/java/eval/testdata/assign.star b/src/test/java/net/starlark/java/eval/testdata/assign.star
new file mode 100644
index 0000000..87076e7
--- /dev/null
+++ b/src/test/java/net/starlark/java/eval/testdata/assign.star
@@ -0,0 +1,6 @@
+# tests of assignment
+
+# computation in a[...]=x expression.
+a = [0, 1, 2, 3, 4, 5]
+a[[i for i in range(6) if i == 2][0]] = "z"
+assert_eq(a, [0, 1, "z", 3, 4, 5])
diff --git a/src/test/java/net/starlark/java/syntax/ResolverTest.java b/src/test/java/net/starlark/java/syntax/ResolverTest.java
index 979fe30..6df2f2e 100644
--- a/src/test/java/net/starlark/java/syntax/ResolverTest.java
+++ b/src/test/java/net/starlark/java/syntax/ResolverTest.java
@@ -13,8 +13,10 @@
 // limitations under the License.
 package net.starlark.java.syntax;
 
+import static com.google.common.truth.Truth.assertThat;
 import static net.starlark.java.syntax.LexerTest.assertContainsError;
 
+import com.google.common.base.Joiner;
 import com.google.common.collect.ImmutableSet;
 import java.util.List;
 import org.junit.Test;
@@ -385,4 +387,54 @@
 
     assertValid("pre(0, a=0, *0, **0)");
   }
+
+  @Test
+  public void testBindingScopeAndIndex() throws Exception {
+    checkBindings(
+        "xᴳ₀ = 0", //
+        "yᴳ₁ = 1",
+        "zᴳ₂ = 2",
+        "xᴳ₀(xᴳ₀, yᴳ₁, preᴾ₀)",
+        "[xᴸ₀ for xᴸ₀ in xᴳ₀ if yᴳ₁]",
+        "def fᴳ₃(xᴸ₀ = xᴳ₀):",
+        "  xᴸ₀ = yᴸ₁",
+        "  yᴸ₁ = zᴳ₂");
+
+    // Note: loads bind globally, for now.
+    checkBindings("load('module', aᴳ₀='a', bᴳ₁='b')");
+
+    checkBindings(
+        "aᴳ₀, bᴳ₁ = 0, 0", //
+        "def fᴳ₂(aᴸ₀=bᴳ₁):",
+        "  aᴸ₀, bᴳ₁",
+        "  [(aᴸ₁, bᴳ₁) for aᴸ₁ in aᴸ₀]");
+
+    checkBindings("load('module', aᴳ₀='a', bᴳ₁='b')");
+  }
+
+  // checkBindings verifies the binding (scope and index) of each identifier.
+  // Every variable must be followed by a superscript letter (its scope)
+  // and a subscript numeral (its index). They are replaced by spaces, the
+  // file is resolved, and then the computed information is written over
+  // the spaces. The resulting string must match the input.
+  private void checkBindings(String... lines) throws Exception {
+    String src = Joiner.on("\n").join(lines);
+    StarlarkFile file = resolveFile(src.replaceAll("[₀₁₂₃₄₅₆₇₈₉ᴳᴸᴾᶠᶜ]", " "));
+    if (!file.ok()) {
+      throw new AssertionError("resolution failed: " + file.errors());
+    }
+    String[] out = new String[] {src};
+    new NodeVisitor() {
+      @Override
+      public void visit(Identifier id) {
+        // Replace ...x__... with ...xᴸ₀...
+        out[0] =
+            out[0].substring(0, id.getEndOffset())
+                + "ᴸᴳᶜᶠᴾ".charAt(id.getBinding().getScope().ordinal()) // follow order of enum
+                + "₀₁₂₃₄₅₆₇₈₉".charAt(id.getBinding().index) // 10 is plenty
+                + out[0].substring(id.getEndOffset() + 2);
+      }
+    }.visit(file);
+    assertThat(out[0]).isEqualTo(src);
+  }
 }