Build language: Fix evaluation of nested list comprehensions.

--
MOS_MIGRATED_REVID=93869549
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/ListComprehension.java b/src/main/java/com/google/devtools/build/lib/syntax/ListComprehension.java
index fde3b28..5d3b655 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/ListComprehension.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/ListComprehension.java
@@ -14,67 +14,133 @@
 
 package com.google.devtools.build.lib.syntax;
 
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
+import com.google.devtools.build.lib.events.Location;
 
+import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
-import java.util.Map;
+
+import javax.annotation.Nullable;
 
 /**
  * Syntax node for lists comprehension expressions.
+ *
+ * A list comprehension contains one or more clauses, e.g.
+ *   [a+d for a in b if c for d in e]
+ * contains three clauses: "for a in b", "if c", "for d in e".
+ * For and If clauses can happen in any order, except that the first one has to be a For.
+ *
+ * The code above can be expanded as:
+ * <pre>
+ *   for a in b:
+ *     if c:
+ *       for d in e:
+ *         result.append(a+d)
+ * </pre>
+ * result is initialized to [] and is the return value of the whole expression.
  */
 public final class ListComprehension extends Expression {
 
-  private final Expression elementExpression;
-  // This cannot be a map, because we need to both preserve order _and_ allow duplicate identifiers.
-  private final List<Map.Entry<LValue, Expression>> lists;
+  /**
+   * The interface implemented by ForClause and (later) IfClause.
+   * A list comprehension consists of one or more Clauses.
+   */
+  public interface Clause extends Serializable {
+    /**
+     * The evaluation of the list comprehension is based on recursion. Each clause may
+     * recursively call evalStep (a ForClause calls it once per element of its list, an
+     * IfClause will call it zero or one time), which evaluates the next clause. To know
+     * which clause is the next one, we pass a step argument (the index of that clause in
+     * the clauses list). Results are aggregated in the result argument, which is
+     * populated by evalStep.
+     *
+     * @param env environment in which we do the evaluation.
+     * @param result the aggregated results of the list comprehension.
+     * @param step the index of the next clause to evaluate.
+     */
+    abstract void eval(Environment env, List<Object> result, int step)
+        throws EvalException, InterruptedException;
+
+    abstract void validate(ValidationEnvironment env) throws EvalException;
+
+    /**
+     * The LValue defined in Clause, i.e. the loop variables for ForClause and null for
+     * IfClause. This is needed for SyntaxTreeVisitor.
+     */
+    @Nullable  // for the IfClause
+    public abstract LValue getLValue();
+
+    /**
+     * The Expression defined in Clause, i.e. the collection for ForClause and the
+     * condition for IfClause. This is needed for SyntaxTreeVisitor.
+     */
+    public abstract Expression getExpression();
+  }
+
+  // TODO(bazel-team): Support IfClause
 
   /**
-   * [elementExpr (for var in listExpr)+]
+   * A for clause in a list comprehension, e.g. "for a in b" in the example above.
    */
-  public ListComprehension(Expression elementExpression) {
-    this.elementExpression = elementExpression;
-    lists = new ArrayList<>();
-  }
+  public final class ForClause implements Clause {
+    private final LValue variables;
+    private final Expression list;
 
-  @Override
-  Object eval(Environment env) throws EvalException, InterruptedException {
-    if (lists.isEmpty()) {
-      return convert(new ArrayList<>(), env);
+    public ForClause(LValue variables, Expression list) {
+      this.variables = variables;
+      this.list = list;
     }
 
-    List<Map.Entry<LValue, Iterable<?>>> listValues = Lists.newArrayListWithCapacity(lists.size());
-    int size = 1;
-    for (Map.Entry<LValue, Expression> list : lists) {
-      Object listValueObject = list.getValue().eval(env);
-      final Iterable<?> listValue = EvalUtils.toIterable(listValueObject, getLocation());
-      int listSize = EvalUtils.size(listValue);
-      if (listSize == 0) {
-        return convert(new ArrayList<>(), env);
+    @Override
+    public void eval(Environment env, List<Object> result, int step)
+        throws EvalException, InterruptedException {
+      Object listValueObject = list.eval(env);
+      Location loc = getLocation();
+      Iterable<?> listValue = EvalUtils.toIterable(listValueObject, loc);
+      for (Object listElement : listValue) {
+        variables.assign(env, loc, listElement);
+        evalStep(env, result, step);
       }
-      size *= listSize;
-      listValues.add(Maps.<LValue, Iterable<?>>immutableEntry(list.getKey(), listValue));
     }
-    List<Object> resultList = Lists.newArrayListWithCapacity(size);
-    evalLists(env, listValues, resultList);
-    return convert(resultList, env);
-  }
 
-  private Object convert(List<Object> list, Environment env) throws EvalException {
-    if (env.isSkylarkEnabled()) {
-      return SkylarkList.list(list, getLocation());
-    } else {
+    @Override
+    public void validate(ValidationEnvironment env) throws EvalException {
+      variables.validate(env, getLocation());
+      list.validate(env);
+    }
+
+    @Override
+    public LValue getLValue() {
+      return variables;
+    }
+
+    @Override
+    public Expression getExpression() {
       return list;
     }
+
+    @Override
+    public String toString() {
+      return String.format("for %s in %s", variables.toString(), EvalUtils.prettyPrintValue(list));
+    }
+  }
+
+  private List<Clause> clauses;
+  /** The return expression, e.g. "a+d" in the example above */
+  private final Expression elementExpression;
+
+  public ListComprehension(Expression elementExpression) {
+    this.elementExpression = elementExpression;
+    clauses = new ArrayList<>();
   }
 
   @Override
   public String toString() {
     StringBuilder sb = new StringBuilder();
     sb.append('[').append(elementExpression);
-    for (Map.Entry<LValue, Expression> list : lists) {
-      sb.append(" for ").append(list.getKey()).append(" in ").append(list.getValue());
+    for (Clause clause : clauses) {
+      sb.append(' ').append(clause.toString());
     }
     sb.append(']');
     return sb.toString();
@@ -84,12 +150,19 @@
     return elementExpression;
   }
 
+  /**
+   * Add a new ForClause to the list comprehension. This is used only by the parser and must
+   * not be called once AST is complete.
+   * TODO(bazel-team): Remove this side-effect. Clauses should be passed to the constructor
+   * instead.
+   */
   public void add(Expression loopVar, Expression listExpression) {
-    lists.add(Maps.immutableEntry(new LValue(loopVar), listExpression));
+    Clause forClause = new ForClause(new LValue(loopVar), listExpression);
+    clauses.add(forClause);
   }
 
-  public List<Map.Entry<LValue, Expression>> getLists() {
-    return lists;
+  public List<Clause> getClauses() {
+    return Collections.unmodifiableList(clauses);
   }
 
   @Override
@@ -97,34 +170,36 @@
     visitor.visit(this);
   }
 
-  /**
-   * Evaluates element expression over all combinations of list element values.
-   *
-   * <p>Iterates over all elements in outermost list (list at index 0) and
-   * updates the value of the list variable in the environment on each
-   * iteration. If there are no other lists to iterate over added evaluation
-   * of the element expression to the result. Otherwise calls itself recursively
-   * with all the lists except the outermost.
-   */
-  private void evalLists(Environment env, List<Map.Entry<LValue, Iterable<?>>> listValues,
-      List<Object> result) throws EvalException, InterruptedException {
-    Map.Entry<LValue, Iterable<?>> listValue = listValues.get(0);
-    for (Object listElement : listValue.getValue()) {
-      listValue.getKey().assign(env, getLocation(), listElement);
-      if (listValues.size() == 1) {
-        result.add(elementExpression.eval(env));
-      } else {
-        evalLists(env, listValues.subList(1, listValues.size()), result);
-      }
-    }
+  @Override
+  Object eval(Environment env) throws EvalException, InterruptedException {
+    List<Object> result = new ArrayList<>();
+    evalStep(env, result, 0);
+    return env.isSkylarkEnabled() ? SkylarkList.list(result, getLocation()) : result;
   }
 
   @Override
   void validate(ValidationEnvironment env) throws EvalException {
-    for (Map.Entry<LValue, Expression> list : lists) {
-      list.getValue().validate(env);
-      list.getKey().validate(env, getLocation());
+    for (Clause clause : clauses) {
+      clause.validate(env);
     }
     elementExpression.validate(env);
   }
+
+  /**
+   * Evaluate the clause indexed by step, or elementExpression. When we evaluate the list
+   * comprehension, step is 0 and we evaluate the first clause. Each clause may
+   * recursively call evalStep any number of times. After the last clause,
+   * elementExpression is evaluated and added to the results.
+   *
+   * In the expanded example above, calling evalStep with a given step is equivalent to
+   * evaluating the line at index step (counting from 0).
+   */
+  private void evalStep(Environment env, List<Object> result, int step)
+      throws EvalException, InterruptedException {
+    if (step >= clauses.size()) {
+      result.add(elementExpression.eval(env));
+    } else {
+      clauses.get(step).eval(env, result, step + 1);
+    }
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java b/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java
index a373594..3127f71 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java
@@ -17,7 +17,6 @@
 import com.google.devtools.build.lib.syntax.IfStatement.ConditionalStatements;
 
 import java.util.List;
-import java.util.Map;
 
 /**
  * A visitor for visiting the nodes in the syntax tree left to right, top to
@@ -61,9 +60,11 @@
 
   public void visit(ListComprehension node) {
     visit(node.getElementExpression());
-    for (Map.Entry<LValue, Expression> list : node.getLists()) {
-      visit(list.getKey().getExpression());
-      visit(list.getValue());
+    for (ListComprehension.Clause clause : node.getClauses()) {
+      if (clause.getLValue() != null) {
+        visit(clause.getLValue().getExpression());
+      }
+      visit(clause.getExpression());
     }
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/syntax/EvaluationTest.java b/src/test/java/com/google/devtools/build/lib/syntax/EvaluationTest.java
index 9d83155..b410fa0 100644
--- a/src/test/java/com/google/devtools/build/lib/syntax/EvaluationTest.java
+++ b/src/test/java/com/google/devtools/build/lib/syntax/EvaluationTest.java
@@ -362,6 +362,21 @@
   }
 
   @Test
+  public void testNestedListComprehensions() throws Exception {
+    assertThat((Iterable<?>) eval(
+          "li = [[1, 2], [3, 4]]\n"
+          + "[j for i in li for j in i]"))
+        .containsExactly(1, 2, 3, 4).inOrder();
+
+    assertThat((Iterable<?>) eval(
+          "input = [['abc'], ['def', 'ghi']]\n"
+          + "['%s %s' % (b, c) for a in input for b in a for c in b]"))
+        .containsExactly(
+            "abc a", "abc b", "abc c", "def d", "def e", "def f", "ghi g", "ghi h", "ghi i")
+        .inOrder();
+  }
+
+  @Test
   public void testListComprehensionsMultipleVariables() throws Exception {
     assertThat(eval("[x + y for x, y in [(1, 2), (3, 4)]]").toString())
         .isEqualTo("[3, 7]");
diff --git a/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java b/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java
index 613a648..e5d4815 100644
--- a/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java
+++ b/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java
@@ -605,20 +605,32 @@
   }
 
   @Test
+  public void testListComprehensionEmptyList() throws Exception {
+    List<ListComprehension.Clause> clauses = ((ListComprehension) parseExpression(
+        "['foo/%s.java' % x for x in []]")).getClauses();
+    assertThat(clauses).hasSize(1);
+    assertThat(clauses.get(0).getExpression().toString()).isEqualTo("[]");
+    assertThat(clauses.get(0).getLValue().getExpression().toString()).isEqualTo("x");
+  }
+
+  @Test
   public void testListComprehension() throws Exception {
-    ListComprehension list =
-      (ListComprehension) parseExpression(
-          "['foo/%s.java' % x "
-          + "for x in []]");
-    assertThat(list.getLists()).hasSize(1);
+    List<ListComprehension.Clause> clauses = ((ListComprehension) parseExpression(
+        "['foo/%s.java' % x for x in ['bar', 'wiz', 'quux']]")).getClauses();
+    assertThat(clauses).hasSize(1);
+    assertThat(clauses.get(0).getLValue().getExpression().toString()).isEqualTo("x");
+    assertThat(clauses.get(0).getExpression()).isInstanceOf(ListLiteral.class);
+  }
 
-    list = (ListComprehension) parseExpression("['foo/%s.java' % x "
-        + "for x in ['bar', 'wiz', 'quux']]");
-    assertThat(list.getLists()).hasSize(1);
-
-    list = (ListComprehension) parseExpression("['%s/%s.java' % (x, y) "
-        + "for x in ['foo', 'bar'] for y in ['baz', 'wiz', 'quux']]");
-    assertThat(list.getLists()).hasSize(2);
+  @Test
+  public void testForForListComprehension() throws Exception {
+    List<ListComprehension.Clause> clauses = ((ListComprehension) parseExpression(
+        "['%s/%s.java' % (x, y) for x in ['foo', 'bar'] for y in list]")).getClauses();
+    assertThat(clauses).hasSize(2);
+    assertThat(clauses.get(0).getLValue().getExpression().toString()).isEqualTo("x");
+    assertThat(clauses.get(0).getExpression()).isInstanceOf(ListLiteral.class);
+    assertThat(clauses.get(1).getLValue().getExpression().toString()).isEqualTo("y");
+    assertThat(clauses.get(1).getExpression()).isInstanceOf(Ident.class);
   }
 
   @Test