Refactor: Parse return statements without an expression properly

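This is an internal refactoring needed by the Skylark linter: a bare `return`
is now represented by a ReturnStatement with a null expression instead of a
synthesized `None` identifier. It does not change any user-visible behavior.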

RELNOTES: None
PiperOrigin-RevId: 166199367
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/Parser.java b/src/main/java/com/google/devtools/build/lib/syntax/Parser.java
index 56e6233..e948b84 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/Parser.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/Parser.java
@@ -1389,15 +1389,12 @@
     int end = token.right;
     expect(TokenKind.RETURN);
 
-    Expression expression;
-    if (STATEMENT_TERMINATOR_SET.contains(token.kind)) {
-        // this None makes the AST not correspond to the source exactly anymore
-        expression = new Identifier("None");
-        setLocation(expression, start, end);
-    } else {
-        expression = parseExpression();
+    Expression expression = null;
+    if (!STATEMENT_TERMINATOR_SET.contains(token.kind)) {
+      expression = parseExpression();
+      end = expression.getLocation().getEndOffset();
     }
-    return setLocation(new ReturnStatement(expression), start, expression);
+    return setLocation(new ReturnStatement(expression), start, end);
   }
 
   // create a comment node
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/ReturnStatement.java b/src/main/java/com/google/devtools/build/lib/syntax/ReturnStatement.java
index db4b4de..5235e58 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/ReturnStatement.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/ReturnStatement.java
@@ -15,6 +15,7 @@
 
 import com.google.devtools.build.lib.events.Location;
 import java.io.IOException;
+import javax.annotation.Nullable;
 
 /**
  * A wrapper Statement class for return expressions.
@@ -46,17 +47,21 @@
     }
   }
 
-  private final Expression returnExpression;
+  @Nullable private final Expression returnExpression;
 
-  public ReturnStatement(Expression returnExpression) {
+  public ReturnStatement(@Nullable Expression returnExpression) {
     this.returnExpression = returnExpression;
   }
 
   @Override
   void doExec(Environment env) throws EvalException, InterruptedException {
+    if (returnExpression == null) {
+      throw new ReturnException(getLocation(), Runtime.NONE);
+    }
     throw new ReturnException(returnExpression.getLocation(), returnExpression.eval(env));
   }
 
+  @Nullable
   public Expression getReturnExpression() {
     return returnExpression;
   }
@@ -65,9 +70,7 @@
   public void prettyPrint(Appendable buffer, int indentLevel) throws IOException {
     printIndent(buffer, indentLevel);
     buffer.append("return");
-    // "return" with no arg is represented internally as returning the None identifier.
-    if (!(returnExpression instanceof Identifier
-          && ((Identifier) returnExpression).getName().equals("None"))) {
+    if (returnExpression != null) {
       buffer.append(' ');
       returnExpression.prettyPrint(buffer, indentLevel);
     }
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java b/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java
index 288a7b0..9770a04 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/SyntaxTreeVisitor.java
@@ -129,7 +129,9 @@
   }
 
   public void visit(ReturnStatement node) {
-    visit(node.getReturnExpression());
+    if (node.getReturnExpression() != null) {
+      visit(node.getReturnExpression());
+    }
   }
 
   public void visit(FlowStatement node) {
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/UserDefinedFunction.java b/src/main/java/com/google/devtools/build/lib/syntax/UserDefinedFunction.java
index 6dae713..0740416 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/UserDefinedFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/UserDefinedFunction.java
@@ -80,7 +80,11 @@
           if (stmt instanceof ReturnStatement) {
             // Performance optimization.
             // Executing the statement would throw an exception, which is slow.
-            return ((ReturnStatement) stmt).getReturnExpression().eval(env);
+            Expression returnExpr = ((ReturnStatement) stmt).getReturnExpression();
+            if (returnExpr == null) {
+              return Runtime.NONE;
+            }
+            return returnExpr.eval(env);
           } else {
             stmt.exec(env);
           }