bazel syntax: catch stack overflow in parser

PiperOrigin-RevId: 313235589
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/Parser.java b/src/main/java/com/google/devtools/build/lib/syntax/Parser.java
index 101220b..1627df3 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/Parser.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/Parser.java
@@ -15,8 +15,10 @@
 package com.google.devtools.build.lib.syntax;
 
 import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.errorprone.annotations.FormatMethod;
 import java.util.ArrayList;
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -206,11 +208,23 @@
     List<SyntaxError> errors = new ArrayList<>();
     Lexer lexer = new Lexer(input, options, errors);
     Parser parser = new Parser(lexer, errors);
-    Expression result = parser.parseExpression();
-    while (parser.token.kind == TokenKind.NEWLINE) {
-      parser.nextToken();
+    Expression result = null;
+    try {
+      result = parser.parseExpression();
+      while (parser.token.kind == TokenKind.NEWLINE) {
+        parser.nextToken();
+      }
+      parser.expect(TokenKind.EOF);
+    } catch (StackOverflowError ex) {
+      // See rationale at parseFile.
+      parser.reportError(
+          lexer.end,
+          "internal error: stack overflow while parsing Starlark expression <<%s>>. Please report"
+              + " the bug.\n"
+              + "%s",
+          new String(input.getContent()),
+          Throwables.getStackTraceAsString(ex));
     }
-    parser.expect(TokenKind.EOF);
     if (!errors.isEmpty()) {
       throw new SyntaxError.Exception(errors);
     }
@@ -237,22 +251,24 @@
     return new ListExpression(locs, /*isTuple=*/ true, -1, elems, -1);
   }
 
-  private void reportError(int offset, String message) {
+  @FormatMethod
+  private void reportError(int offset, String format, Object... args) {
     errorsCount++;
     // Limit the number of reported errors to avoid spamming output.
     if (errorsCount <= 5) {
       Location location = locs.getLocation(offset);
-      errors.add(new SyntaxError(location, message));
+      errors.add(new SyntaxError(location, String.format(format, args)));
     }
   }
 
   private void syntaxError(String message) {
     if (!recoveryMode) {
-      String msg =
-          token.kind == TokenKind.INDENT
-              ? "indentation error"
-              : "syntax error at '" + tokenString(token.kind, token.value) + "': " + message;
-      reportError(token.start, msg);
+      if (token.kind == TokenKind.INDENT) {
+        reportError(token.start, "indentation error");
+      } else {
+        reportError(
+            token.start, "syntax error at '%s': %s", tokenString(token.kind, token.value), message);
+      }
       recoveryMode = true;
     }
   }
@@ -352,7 +368,7 @@
         error = "keyword '" + token.kind + "' not supported";
         break;
     }
-    reportError(token.start, error);
+    reportError(token.start, "%s", error);
   }
 
   private int nextToken() {
@@ -865,8 +881,9 @@
       if (lastOp != null && operatorPrecedence.get(prec).contains(TokenKind.EQUALS_EQUALS)) {
         reportError(
             token.start,
-            String.format(
-                "Operator '%s' is not associative with operator '%s'. Use parens.", lastOp, op));
+            "Operator '%s' is not associative with operator '%s'. Use parens.",
+            lastOp,
+            op);
       }
 
       int opOffset = nextToken();
@@ -930,17 +947,38 @@
   // file_input = ('\n' | stmt)* EOF
   private List<Statement> parseFileInput() {
     List<Statement> list =  new ArrayList<>();
-    while (token.kind != TokenKind.EOF) {
-      if (token.kind == TokenKind.NEWLINE) {
-        expectAndRecover(TokenKind.NEWLINE);
-      } else if (recoveryMode) {
-        // If there was a parse error, we want to recover here
-        // before starting a new top-level statement.
-        syncTo(STATEMENT_TERMINATOR_SET);
-        recoveryMode = false;
-      } else {
-        parseStatement(list);
+    try {
+      while (token.kind != TokenKind.EOF) {
+        if (token.kind == TokenKind.NEWLINE) {
+          expectAndRecover(TokenKind.NEWLINE);
+        } else if (recoveryMode) {
+          // If there was a parse error, we want to recover here
+          // before starting a new top-level statement.
+          syncTo(STATEMENT_TERMINATOR_SET);
+          recoveryMode = false;
+        } else {
+          parseStatement(list);
+        }
       }
+    } catch (StackOverflowError ex) {
+      // JVM threads have very limited stack, and deeply nested inputs can
+      // easily cause the parser to consume all available stack. It is hard
+      // to anticipate all the possible recursions in the parser, especially
+      // when considering error recovery. Consider a long list of dicts:
+      // even if the intended parse tree has a depth of only two,
+      // if each dict contains a syntax error, the parser will go into recovery
+      // and may discard each dict's closing '}', turning a shallow tree
+      // into a deep one (see b/157470754).
+      //
+      // So, for robustness, the parser treats StackOverflowError as a parse
+      // error, exhorting the user to report a bug.
+      reportError(
+          token.end,
+          "internal error: stack overflow in Starlark parser. Please report the bug and include"
+              + " the text of %s.\n"
+              + "%s",
+          locs.file(),
+          Throwables.getStackTraceAsString(ex));
     }
     return list;
   }
diff --git a/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java b/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java
index 9f87e1d..d64aae2 100644
--- a/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java
+++ b/src/test/java/com/google/devtools/build/lib/syntax/ParserTest.java
@@ -15,6 +15,7 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.truth.Truth.assertWithMessage;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.fail;
 
 import com.google.common.collect.ImmutableList;
@@ -1283,4 +1284,32 @@
     assertThat(parseExpressionError("1 if 2"))
         .contains("missing else clause in conditional expression or semicolon before if");
   }
+
+  @Test
+  public void testParseFileStackOverflow() throws Exception {
+    StarlarkFile file = StarlarkFile.parse(veryDeepExpression());
+    SyntaxError ex = LexerTest.assertContainsError(file.errors(), "internal error: stack overflow");
+    assertThat(ex.message()).contains("parseDictEntry"); // includes stack
+    assertThat(ex.message()).contains("Please report the bug");
+    assertThat(ex.message()).contains("include the text of foo.star"); // includes file name
+  }
+
+  @Test
+  public void testParseExpressionStackOverflow() throws Exception {
+    SyntaxError.Exception ex =
+        assertThrows(SyntaxError.Exception.class, () -> Expression.parse(veryDeepExpression()));
+    SyntaxError err = LexerTest.assertContainsError(ex.errors(), "internal error: stack overflow");
+    assertThat(err.message()).contains("parseDictEntry"); // includes stack
+    assertThat(err.message())
+        .contains("while parsing Starlark expression <<{{{{"); // includes expression
+    assertThat(err.message()).contains("Please report the bug");
+  }
+
+  private static ParserInput veryDeepExpression() {
+    StringBuilder s = new StringBuilder();
+    for (int i = 0; i < 1000; i++) {
+      s.append("{");
+    }
+    return ParserInput.create(s.toString(), "foo.star");
+  }
 }