starlark: implement "flat frames" optimization
The resolver now statically maps each local variable to a small
integer, and the set of local bindings (including from
comprehensions) is accumulated into each Resolver.Function.
The evaluator uses an array instead of a LinkedHashMap to
represent the environment, and no longer needs a hack to
save/restore the values of comprehensions.
Also, added a test of the resolver.
PiperOrigin-RevId: 343529575
diff --git a/src/main/java/net/starlark/java/eval/Eval.java b/src/main/java/net/starlark/java/eval/Eval.java
index e5fcc83..d650715 100644
--- a/src/main/java/net/starlark/java/eval/Eval.java
+++ b/src/main/java/net/starlark/java/eval/Eval.java
@@ -322,16 +322,17 @@
private static void assignIdentifier(StarlarkThread.Frame fr, Identifier id, Object value)
throws EvalException {
Resolver.Binding bind = id.getBinding();
- String name = id.getName();
switch (bind.getScope()) {
case LOCAL:
- fr.locals.put(name, value);
+ fr.locals[bind.getIndex()] = value;
break;
case GLOBAL:
// Updates a module binding and sets its 'exported' flag.
// (Only load bindings are not exported.
// But exportedGlobals does at run time what should be done in the resolver.)
+ // TODO(adonovan): use a flat array for Module.globals too.
Module module = fn(fr).getModule();
+ String name = id.getName();
module.setGlobal(name, value);
module.exportedGlobals.add(name);
break;
@@ -631,52 +632,26 @@
private static Object evalIdentifier(StarlarkThread.Frame fr, Identifier id)
throws EvalException, InterruptedException {
- String name = id.getName();
Resolver.Binding bind = id.getBinding();
- if (bind == null) {
- // Legacy behavior, to be removed.
- Object result = fr.locals.get(name);
- if (result != null) {
- return result;
- }
- result = fn(fr).getModule().get(name);
- if (result != null) {
- return result;
- }
-
- // Assuming resolution was successfully applied before execution
- // (which is not yet true for copybara, but will be soon),
- // then the identifier must have been resolved but the
- // resolution was not annotated onto the syntax tree---because
- // it's a BUILD file that may share trees with the prelude.
- // So this error does not mean "undefined variable" (morally a
- // static error), but "variable was (dynamically) referenced
- // before being bound", as in 'print(x); x=1'.
- fr.setErrorLocation(id.getStartLocation());
- throw Starlark.errorf("variable '%s' is referenced before assignment", name);
- }
-
Object result;
switch (bind.getScope()) {
case LOCAL:
- result = fr.locals.get(name);
+ result = fr.locals[bind.getIndex()];
break;
case GLOBAL:
- result = fn(fr).getModule().getGlobal(name);
+ result = fn(fr).getModule().getGlobal(id.getName());
break;
case PREDECLARED:
// TODO(adonovan): call getPredeclared
- result = fn(fr).getModule().get(name);
+ result = fn(fr).getModule().get(id.getName());
break;
default:
throw new IllegalStateException(bind.toString());
}
if (result == null) {
- // Since Scope was set, we know that the local/global variable is defined,
- // but its assignment was not yet executed.
fr.setErrorLocation(id.getStartLocation());
throw Starlark.errorf(
- "%s variable '%s' is referenced before assignment.", bind.getScope(), name);
+ "%s variable '%s' is referenced before assignment.", bind.getScope(), id.getName());
}
return result;
}
@@ -734,22 +709,6 @@
final StarlarkList<Object> list =
comp.isDict() ? null : StarlarkList.of(fr.thread.mutability());
- // Save previous value (if any) of local variables bound in a 'for' clause
- // so we can restore them later.
- // TODO(adonovan): throw all this away when we implement flat environments.
- List<Object> saved = new ArrayList<>(); // alternating keys and values
- for (Comprehension.Clause clause : comp.getClauses()) {
- if (clause instanceof Comprehension.For) {
- for (Identifier ident :
- Identifier.boundIdentifiers(((Comprehension.For) clause).getVars())) {
- String name = ident.getName();
- Object value = fr.locals.get(ident.getName()); // may be null
- saved.add(name);
- saved.add(value);
- }
- }
- }
-
// The Lambda class serves as a recursive lambda closure.
class Lambda {
// execClauses(index) recursively executes the clauses starting at index,
@@ -806,18 +765,6 @@
}
new Lambda().execClauses(0);
- // Restore outer scope variables.
- // This loop implicitly undefines comprehension variables.
- for (int i = 0; i != saved.size(); ) {
- String name = (String) saved.get(i++);
- Object value = saved.get(i++);
- if (value != null) {
- fr.locals.put(name, value);
- } else {
- fr.locals.remove(name);
- }
- }
-
return comp.isDict() ? dict : list;
}
diff --git a/src/main/java/net/starlark/java/eval/StarlarkFunction.java b/src/main/java/net/starlark/java/eval/StarlarkFunction.java
index c67f8e5..afd5077 100644
--- a/src/main/java/net/starlark/java/eval/StarlarkFunction.java
+++ b/src/main/java/net/starlark/java/eval/StarlarkFunction.java
@@ -35,7 +35,7 @@
doc = "The type of functions declared in Starlark.")
public final class StarlarkFunction implements StarlarkCallable {
- private final Resolver.Function rfn;
+ final Resolver.Function rfn;
private final Module module; // a function closes over its defining module
private final Tuple defaultValues;
@@ -152,11 +152,8 @@
Object[] arguments = processArgs(thread.mutability(), positional, named);
StarlarkThread.Frame fr = thread.frame(0);
- ImmutableList<String> names = rfn.getParameterNames();
- for (int i = 0; i < names.size(); ++i) {
- fr.locals.put(names.get(i), arguments[i]);
- }
-
+ fr.locals = new Object[rfn.getLocals().size()];
+ System.arraycopy(arguments, 0, fr.locals, 0, rfn.getParameterNames().size());
return Eval.execFunctionBody(fr, rfn.getBody());
}
diff --git a/src/main/java/net/starlark/java/eval/StarlarkThread.java b/src/main/java/net/starlark/java/eval/StarlarkThread.java
index e859ec5..5019d00 100644
--- a/src/main/java/net/starlark/java/eval/StarlarkThread.java
+++ b/src/main/java/net/starlark/java/eval/StarlarkThread.java
@@ -17,7 +17,6 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
@@ -136,8 +135,9 @@
// location (loc) should not be overrwritten.
private boolean errorLocationSet;
- // The locals of this frame, if fn is a StarlarkFunction, otherwise empty.
- Map<String, Object> locals;
+ // The locals of this frame, if fn is a StarlarkFunction, otherwise null.
+ // Set by StarlarkFunction.fastcall.
+ @Nullable Object[] locals;
@Nullable private Object profileSpan; // current span of walltime call profiler
@@ -177,7 +177,16 @@
@Override
public ImmutableMap<String, Object> getLocals() {
- return ImmutableMap.copyOf(this.locals);
+ // TODO(adonovan): provide a more efficient API.
+ ImmutableMap.Builder<String, Object> env = ImmutableMap.builder();
+ if (fn instanceof StarlarkFunction) {
+ for (int i = 0; i < locals.length; i++) {
+ if (locals[i] != null) {
+ env.put(((StarlarkFunction) fn).rfn.getLocals().get(i).getName(), locals[i]);
+ }
+ }
+ }
+ return env.build();
}
@Override
@@ -214,14 +223,6 @@
Debug.threadHook.onPushFirst(this);
}
- if (fn instanceof StarlarkFunction) {
- StarlarkFunction sfn = (StarlarkFunction) fn;
- fr.locals = Maps.newLinkedHashMapWithExpectedSize(sfn.getParameterNames().size());
- } else {
- // built-in function
- fr.locals = ImmutableMap.of();
- }
-
fr.loc = fn.getLocation();
// Start wall-time call profile span.
diff --git a/src/main/java/net/starlark/java/syntax/Resolver.java b/src/main/java/net/starlark/java/syntax/Resolver.java
index 6aa7d63..fe3c634 100644
--- a/src/main/java/net/starlark/java/syntax/Resolver.java
+++ b/src/main/java/net/starlark/java/syntax/Resolver.java
@@ -44,19 +44,30 @@
*/
public final class Resolver extends NodeVisitor {
- // TODO(adonovan): use "keyword" (not "named") and "required" (not "mandatory") terminology
- // everywhere, including the spec.
+ // TODO(adonovan):
+ // - use "keyword" (not "named") and "required" (not "mandatory") terminology everywhere,
+ // including the spec.
+ // - move the "no if statements at top level" check to bazel's check{Build,*}Syntax
+ // (that's a spec change), or put it behind a FileOptions flag (no spec change).
+ // - remove restriction on nested def:
+ // 1. use FREE for scope of references to outer LOCALs, which become CELLs.
+ // 2. implement closures in eval/.
+ // - make loads bind locals by default.
/** Scope discriminates the scope of a binding: global, local, etc. */
public enum Scope {
- // TODO(adonovan): Add UNIVERSAL, FREE, CELL.
- // (PREDECLARED vs UNIVERSAL allows us to represent the app-dependent and fixed parts of the
- // predeclared environment separately, reducing the amount of copying.)
+ // TODO(adonovan): Add UNIVERSAL. Separating PREDECLARED from UNIVERSAL allows
+ // us to represent the app-dependent and fixed parts of the predeclared
+ // environment separately, reducing the amount of copying.
/** Binding is local to a function, comprehension, or file (e.g. load). */
LOCAL,
/** Binding occurs outside any function or comprehension. */
GLOBAL,
+ /** Binding is local to a function, comprehension, or file, plus its nested functions. */
+ CELL, // TODO(adonovan): implement nested def
+ /** Binding is a CELL of some enclosing function. */
+ FREE, // TODO(adonovan): implement nested def
/** Binding is predeclared by the core or application. */
PREDECLARED;
@@ -74,7 +85,7 @@
private final Scope scope;
@Nullable private final Identifier first; // first binding use, if syntactic
- private final int index; // within its block (currently unused)
+ final int index; // index within function (LOCAL) or module (GLOBAL)
private Binding(Scope scope, @Nullable Identifier first, int index) {
this.scope = scope;
@@ -82,11 +93,22 @@
this.index = index;
}
+ /** Returns the name of this binding's identifier. */
+ @Nullable
+ public String getName() {
+ return first != null ? first.getName() : null;
+ }
+
/** Returns the scope of the binding. */
public Scope getScope() {
return scope;
}
+ /** Returns the index of a binding within its function (LOCAL) or module (GLOBAL). */
+ public int getIndex() {
+ return index;
+ }
+
@Override
public String toString() {
return first == null
@@ -108,6 +130,7 @@
private final int numKeywordOnlyParams;
private final ImmutableList<String> parameterNames;
private final boolean isToplevel;
+ private final ImmutableList<Binding> locals;
private Function(
String name,
@@ -116,7 +139,8 @@
ImmutableList<Statement> body,
boolean hasVarargs,
boolean hasKwargs,
- int numKeywordOnlyParams) {
+ int numKeywordOnlyParams,
+ List<Binding> locals) {
this.name = name;
this.location = loc;
this.params = params;
@@ -132,6 +156,7 @@
this.parameterNames = names.build();
this.isToplevel = name.equals("<toplevel>");
+ this.locals = ImmutableList.copyOf(locals);
}
/**
@@ -143,6 +168,11 @@
return name;
}
+ /** Returns the function's local bindings, parameters first. */
+ public ImmutableList<Binding> getLocals() {
+ return locals;
+ }
+
/** Returns the location of the function's identifier. */
public Location getLocation() {
return location;
@@ -230,12 +260,14 @@
private static class Block {
private final Map<String, Binding> bindings = new HashMap<>();
+ private final ArrayList<Binding> locals; // of enclosing function
private final Scope scope;
@Nullable private final Block parent;
- Block(Scope scope, @Nullable Block parent) {
+ Block(Scope scope, @Nullable Block parent, ArrayList<Binding> locals) {
this.scope = scope;
this.parent = parent;
+ this.locals = locals;
}
}
@@ -253,7 +285,7 @@
this.module = module;
this.options = options;
- this.block = new Block(Scope.PREDECLARED, null);
+ this.block = new Block(Scope.PREDECLARED, /*parent=*/ null, /*locals=*/ null);
for (String name : module.getNames()) {
block.bindings.put(name, PREDECLARED);
}
@@ -277,7 +309,7 @@
* in order).
*/
// TODO(adonovan): eliminate this first pass by using go.starlark.net one-pass approach.
- private void createBindings(Iterable<Statement> stmts) {
+ private void createBindingsForBlock(Iterable<Statement> stmts) {
for (Statement stmt : stmts) {
createBindings(stmt);
}
@@ -286,19 +318,19 @@
private void createBindings(Statement stmt) {
switch (stmt.kind()) {
case ASSIGNMENT:
- createBindings(((AssignmentStatement) stmt).getLHS());
+ createBindingsForLHS(((AssignmentStatement) stmt).getLHS());
break;
case IF:
IfStatement ifStmt = (IfStatement) stmt;
- createBindings(ifStmt.getThenBlock());
+ createBindingsForBlock(ifStmt.getThenBlock());
if (ifStmt.getElseBlock() != null) {
- createBindings(ifStmt.getElseBlock());
+ createBindingsForBlock(ifStmt.getElseBlock());
}
break;
case FOR:
ForStatement forStmt = (ForStatement) stmt;
- createBindings(forStmt.getVars());
- createBindings(forStmt.getBody());
+ createBindingsForLHS(forStmt.getVars());
+ createBindingsForBlock(forStmt.getBody());
break;
case DEF:
DefStatement def = (DefStatement) stmt;
@@ -341,7 +373,7 @@
}
}
- private void createBindings(Expression lhs) {
+ private void createBindingsForLHS(Expression lhs) {
for (Identifier id : Identifier.boundIdentifiers(lhs)) {
bind(id);
}
@@ -502,11 +534,13 @@
Comprehension.For for0 = (Comprehension.For) clauses.get(0);
visit(for0.getIterable());
- openBlock(Scope.LOCAL);
+ // A comprehension defines a distinct lexical block in the same function.
+ openBlock(Scope.LOCAL, this.block.locals);
+
for (Comprehension.Clause clause : clauses) {
if (clause instanceof Comprehension.For) {
Comprehension.For forClause = (Comprehension.For) clause;
- createBindings(forClause.getVars());
+ createBindingsForLHS(forClause.getVars());
}
}
for (int i = 0; i < clauses.size(); i++) {
@@ -553,7 +587,8 @@
}
// Enter function block.
- openBlock(Scope.LOCAL);
+ ArrayList<Binding> locals = new ArrayList<>();
+ openBlock(Scope.LOCAL, locals);
// Check parameter order and convert to run-time order:
// positionals, keyword-only, *args, **kwargs.
@@ -627,7 +662,7 @@
bindParam(params, starStar);
}
- createBindings(body);
+ createBindingsForBlock(body);
visitAll(body);
closeBlock();
@@ -638,7 +673,8 @@
body,
star != null && star.getIdentifier() != null,
starStar != null,
- numKeywordOnlyParams);
+ numKeywordOnlyParams,
+ locals);
}
private void bindParam(ImmutableList.Builder<Parameter> params, Parameter param) {
@@ -700,8 +736,13 @@
}
// new binding
- // TODO(adonovan): accumulate locals in the enclosing function/file block.
- bind = new Binding(block.scope, id, block.bindings.size());
+ if (block.scope == Scope.LOCAL) {
+ // Accumulate local bindings in the enclosing function.
+ bind = new Binding(block.scope, id, block.locals.size());
+ block.locals.add(bind);
+ } else { // GLOBAL
+ bind = new Binding(block.scope, id, block.bindings.size());
+ }
block.bindings.put(id.getName(), bind);
id.setBinding(bind);
return false;
@@ -741,35 +782,30 @@
}
}
- private void resolveToplevelStatements(List<Statement> statements) {
- // Check that load() statements are on top.
- if (options.requireLoadStatementsFirst()) {
- checkLoadAfterStatement(statements);
- }
-
- openBlock(Scope.GLOBAL);
-
- // Add a binding for each variable defined by statements, not including definitions that appear
- // in sub-scopes of the given statements (function bodies and comprehensions).
- createBindings(statements);
-
- // Second pass: ensure that all symbols have been defined.
- visitAll(statements);
- closeBlock();
- }
-
/**
* Performs static checks, including resolution of identifiers in {@code file} in the environment
* defined by {@code module}. The StarlarkFile is mutated. Errors are appended to {@link
* StarlarkFile#errors}.
*/
public static void resolveFile(StarlarkFile file, Module module) {
+ Resolver r = new Resolver(file.errors, module, file.getOptions());
+
+ ArrayList<Binding> locals = new ArrayList<>();
+ r.openBlock(Scope.GLOBAL, locals);
+
ImmutableList<Statement> stmts = file.getStatements();
- Resolver r = new Resolver(file.errors, module, file.getOptions());
- r.resolveToplevelStatements(stmts);
- // Check that no closeBlock was forgotten.
- Preconditions.checkState(r.block.parent == null);
+ // Check that load() statements are on top.
+ if (r.options.requireLoadStatementsFirst()) {
+ r.checkLoadAfterStatement(stmts);
+ }
+
+ // First pass: creating bindings for statements in this block.
+ r.createBindingsForBlock(stmts);
+
+ // Second pass: visit all references.
+ r.visitAll(stmts);
+ r.closeBlock();
// If the final statement is an expression, synthesize a return statement.
int n = stmts.size();
@@ -791,7 +827,8 @@
/*body=*/ stmts,
/*hasVarargs=*/ false,
/*hasKwargs=*/ false,
- /*numKeywordOnlyParams=*/ 0));
+ /*numKeywordOnlyParams=*/ 0,
+ locals));
}
/**
@@ -804,7 +841,10 @@
List<SyntaxError> errors = new ArrayList<>();
Resolver r = new Resolver(errors, module, options);
+ ArrayList<Binding> locals = new ArrayList<>();
+ r.openBlock(Scope.LOCAL, locals); // for bindings in list comprehensions
r.visit(expr);
+ r.closeBlock();
if (!errors.isEmpty()) {
throw new SyntaxError.Exception(errors);
@@ -818,12 +858,13 @@
ImmutableList.of(ReturnStatement.make(expr)),
/*hasVarargs=*/ false,
/*hasKwargs=*/ false,
- /*numKeywordOnlyParams=*/ 0);
+ /*numKeywordOnlyParams=*/ 0,
+ locals);
}
- /** Open a new lexical block that will contain the future declarations. */
- private void openBlock(Scope scope) {
- block = new Block(scope, block);
+ /** Open a new lexical block for future bindings. Local bindings will be added to locals. */
+ private void openBlock(Scope scope, ArrayList<Binding> locals) {
+ block = new Block(scope, block, locals);
}
/** Close a lexical block (and lose all declarations it contained). */
diff --git a/src/test/java/net/starlark/java/eval/BUILD b/src/test/java/net/starlark/java/eval/BUILD
index 62579c9..4d0972d 100644
--- a/src/test/java/net/starlark/java/eval/BUILD
+++ b/src/test/java/net/starlark/java/eval/BUILD
@@ -30,6 +30,7 @@
"StarlarkThreadDebuggingTest.java",
"StarlarkThreadTest.java",
],
+ jvm_flags = ["-Dfile.encoding=UTF8"],
deps = [
"//src/main/java/net/starlark/java/annot",
"//src/main/java/net/starlark/java/eval",
diff --git a/src/test/java/net/starlark/java/eval/testdata/assign.star b/src/test/java/net/starlark/java/eval/testdata/assign.star
new file mode 100644
index 0000000..87076e7
--- /dev/null
+++ b/src/test/java/net/starlark/java/eval/testdata/assign.star
@@ -0,0 +1,6 @@
+# tests of assignment
+
+# computation in a[...]=x expression.
+a = [0, 1, 2, 3, 4, 5]
+a[[i for i in range(6) if i == 2][0]] = "z"
+assert_eq(a, [0, 1, "z", 3, 4, 5])
diff --git a/src/test/java/net/starlark/java/syntax/ResolverTest.java b/src/test/java/net/starlark/java/syntax/ResolverTest.java
index 979fe30..6df2f2e 100644
--- a/src/test/java/net/starlark/java/syntax/ResolverTest.java
+++ b/src/test/java/net/starlark/java/syntax/ResolverTest.java
@@ -13,8 +13,10 @@
// limitations under the License.
package net.starlark.java.syntax;
+import static com.google.common.truth.Truth.assertThat;
import static net.starlark.java.syntax.LexerTest.assertContainsError;
+import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import org.junit.Test;
@@ -385,4 +387,54 @@
assertValid("pre(0, a=0, *0, **0)");
}
+
+ @Test
+ public void testBindingScopeAndIndex() throws Exception {
+ checkBindings(
+ "xᴳ₀ = 0", //
+ "yᴳ₁ = 1",
+ "zᴳ₂ = 2",
+ "xᴳ₀(xᴳ₀, yᴳ₁, preᴾ₀)",
+ "[xᴸ₀ for xᴸ₀ in xᴳ₀ if yᴳ₁]",
+ "def fᴳ₃(xᴸ₀ = xᴳ₀):",
+ " xᴸ₀ = yᴸ₁",
+ " yᴸ₁ = zᴳ₂");
+
+ // Note: loads bind globally, for now.
+ checkBindings("load('module', aᴳ₀='a', bᴳ₁='b')");
+
+ checkBindings(
+ "aᴳ₀, bᴳ₁ = 0, 0", //
+ "def fᴳ₂(aᴸ₀=bᴳ₁):",
+ " aᴸ₀, bᴳ₁",
+ " [(aᴸ₁, bᴳ₁) for aᴸ₁ in aᴸ₀]");
+
+ checkBindings("load('module', aᴳ₀='a', bᴳ₁='b')");
+ }
+
+ // checkBindings verifies the binding (scope and index) of each identifier.
+ // Every variable must be followed by a superscript letter (its scope)
+ // and a subscript numeral (its index). They are replaced by spaces, the
+ // file is resolved, and then the computed information is written over
+ // the spaces. The resulting string must match the input.
+ private void checkBindings(String... lines) throws Exception {
+ String src = Joiner.on("\n").join(lines);
+ StarlarkFile file = resolveFile(src.replaceAll("[₀₁₂₃₄₅₆₇₈₉ᴳᴸᴾᶠᶜ]", " "));
+ if (!file.ok()) {
+ throw new AssertionError("resolution failed: " + file.errors());
+ }
+ String[] out = new String[] {src};
+ new NodeVisitor() {
+ @Override
+ public void visit(Identifier id) {
+ // Replace ...x__... with ...xᴸ₀...
+ out[0] =
+ out[0].substring(0, id.getEndOffset())
+ + "ᴸᴳᶜᶠᴾ".charAt(id.getBinding().getScope().ordinal()) // follow order of enum
+ + "₀₁₂₃₄₅₆₇₈₉".charAt(id.getBinding().index) // 10 is plenty
+ + out[0].substring(id.getEndOffset() + 2);
+ }
+ }.visit(file);
+ assertThat(out[0]).isEqualTo(src);
+ }
}