Allow passing location, ast, and environment to @SkylarkCallable methods

RELNOTES: None.
PiperOrigin-RevId: 188201686
diff --git a/src/main/java/com/google/devtools/build/lib/skylarkinterface/SkylarkCallable.java b/src/main/java/com/google/devtools/build/lib/skylarkinterface/SkylarkCallable.java
index 1b193fe..3ce0f70 100644
--- a/src/main/java/com/google/devtools/build/lib/skylarkinterface/SkylarkCallable.java
+++ b/src/main/java/com/google/devtools/build/lib/skylarkinterface/SkylarkCallable.java
@@ -20,6 +20,22 @@
 
 /**
  * A marker interface for Java methods which can be called from Skylark.
+ *
+ * <p>Methods annotated with this annotation are expected to meet certain requirements which
+ * are enforced by an annotation processor:</p>
+ * <ul>
+ * <li>The method must be public.</li>
+ * <li>If structField=true, there must be zero user-supplied parameters.</li>
+ * <li>Method parameters must be supplied in the following order:
+ *   <pre>method([positionals]*[other user-args](Location)(FuncallExpression)(Environment))</pre>
+ *   where Location, FuncallExpression, and Environment are supplied by the interpreter if and
+ *   only if useLocation, useAst, and useEnvironment are specified, respectively.
+ *  </li>
+ * <li>
+ *   The number of method parameters must match the number of annotation-declared parameters
+ *   plus the number of interpreter-supplied parameters.
+ * </li>
+ * </ul>
  */
 @Target({ElementType.METHOD})
 @Retention(RetentionPolicy.RUNTIME)
@@ -70,4 +86,25 @@
    * <code>None</code>). If not set and the Java method returns null, an error will be raised.
    */
   boolean allowReturnNones() default false;
+
+  /**
+   * If true, the location of the call site will be passed as an argument of the annotated function.
+   * (Thus, the annotated method signature must contain Location as a parameter. See the
+   * interface-level javadoc for details.)
+   */
+  boolean useLocation() default false;
+
+  /**
+   * If true, the AST of the call site will be passed as an argument of the annotated function.
+   * (Thus, the annotated method signature must contain FuncallExpression as a parameter. See the
+   * interface-level javadoc for details.)
+   */
+  boolean useAst() default false;
+
+  /**
+   * If true, the Skylark Environment will be passed as an argument of the annotated function.
+   * (Thus, the annotated method signature must contain Environment as a parameter. See the
+   * interface-level javadoc for details.)
+   */
+  boolean useEnvironment() default false;
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skylarkinterface/processor/SkylarkCallableProcessor.java b/src/main/java/com/google/devtools/build/lib/skylarkinterface/processor/SkylarkCallableProcessor.java
index cc98597..cd4b288 100644
--- a/src/main/java/com/google/devtools/build/lib/skylarkinterface/processor/SkylarkCallableProcessor.java
+++ b/src/main/java/com/google/devtools/build/lib/skylarkinterface/processor/SkylarkCallableProcessor.java
@@ -15,6 +15,7 @@
 package com.google.devtools.build.lib.skylarkinterface.processor;
 
 import com.google.devtools.build.lib.skylarkinterface.SkylarkCallable;
+import java.util.List;
 import java.util.Set;
 import javax.annotation.processing.AbstractProcessor;
 import javax.annotation.processing.Messager;
@@ -27,6 +28,7 @@
 import javax.lang.model.element.ExecutableElement;
 import javax.lang.model.element.Modifier;
 import javax.lang.model.element.TypeElement;
+import javax.lang.model.element.VariableElement;
 import javax.tools.Diagnostic;
 
 /**
@@ -35,8 +37,16 @@
  * <p>Checks the following invariants about {@link SkylarkCallable}-annotated methods:
  * <ul>
  * <li>The method must be public.</li>
- * <li>The number of method parameters much match the number of annotation-declared parameters.</li>
- * <li>If structField=true, there must be zero arguments.</li>
+ * <li>If structField=true, there must be zero user-supplied parameters.</li>
+ * <li>Method parameters must be supplied in the following order:
+ *   <pre>method([positionals]*[other user-args](Location)(FuncallExpression)(Environment))</pre>
+ *   where Location, FuncallExpression, and Environment are supplied by the interpreter if and
+ *   only if useLocation, useAst, and useEnvironment are specified, respectively.
+ *  </li>
+ * <li>
+ *   The number of method parameters must match the number of annotation-declared parameters
+ *   plus the number of interpreter-supplied parameters.
+ * </li>
  * </ul>
  *
  * <p>These properties can be relied upon at runtime without additional checks.
@@ -47,6 +57,10 @@
 
   private Messager messager;
 
+  private static final String LOCATION = "com.google.devtools.build.lib.events.Location";
+  private static final String AST = "com.google.devtools.build.lib.syntax.FuncallExpression";
+  private static final String ENVIRONMENT = "com.google.devtools.build.lib.syntax.Environment";
+
   @Override
   public synchronized void init(ProcessingEnvironment processingEnv) {
     super.init(processingEnv);
@@ -56,6 +70,7 @@
   @Override
   public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
     for (Element element : roundEnv.getElementsAnnotatedWith(SkylarkCallable.class)) {
+
       // Only methods are annotated with SkylarkCallable. This is verified by the
       // @Target(ElementType.METHOD) annotation.
       ExecutableElement methodElement = (ExecutableElement) element;
@@ -64,32 +79,115 @@
       if (!methodElement.getModifiers().contains(Modifier.PUBLIC)) {
         error(methodElement, "@SkylarkCallable annotated methods must be public.");
       }
-      if (annotation.parameters().length > 0 || annotation.mandatoryPositionals() >= 0) {
-         int numDeclaredArgs = annotation.parameters().length
-             + Math.max(0, annotation.mandatoryPositionals());
-        if (methodElement.getParameters().size() != numDeclaredArgs) {
-          error(methodElement, String.format(
-              "@SkylarkCallable annotated method has %d parameters, but annotation declared %d.",
-              methodElement.getParameters().size(), numDeclaredArgs));
-        }
-      }
-      if (annotation.structField()) {
-        if (!methodElement.getParameters().isEmpty()) {
-          error(methodElement,
-              "@SkylarkCallable annotated methods with structField=true must have zero arguments.");
-        }
+
+      try {
+        verifyNumberOfParameters(methodElement, annotation);
+        verifyExtraInterpreterParams(methodElement, annotation);
+      } catch (SkylarkCallableProcessorException exception) {
+        error(exception.methodElement, exception.errorMessage);
       }
     }
+
     return true;
   }
 
+  private void verifyNumberOfParameters(ExecutableElement methodElement, SkylarkCallable annotation)
+      throws SkylarkCallableProcessorException {
+    List<? extends VariableElement> methodSignatureParams = methodElement.getParameters();
+    int numExtraInterpreterParams = numExpectedExtraInterpreterParams(annotation);
+
+    if (annotation.parameters().length > 0 || annotation.mandatoryPositionals() >= 0) {
+      int numDeclaredArgs =
+          annotation.parameters().length + Math.max(0, annotation.mandatoryPositionals());
+      if (methodSignatureParams.size() != numDeclaredArgs + numExtraInterpreterParams) {
+        throw new SkylarkCallableProcessorException(
+            methodElement,
+            String.format(
+                "@SkylarkCallable annotated method has %d parameters, but annotation declared "
+                    + "%d user-supplied parameters and %d extra interpreter parameters.",
+                methodSignatureParams.size(), numDeclaredArgs, numExtraInterpreterParams));
+      }
+    }
+    if (annotation.structField()) {
+      if (methodSignatureParams.size() > 0) {
+        // TODO(cparsons): Allow structField methods to accept interpreter-supplied arguments.
+        throw new SkylarkCallableProcessorException(
+            methodElement,
+            "@SkylarkCallable annotated methods with structField=true must have zero arguments.");
+      }
+    }
+  }
+
+  private void verifyExtraInterpreterParams(ExecutableElement methodElement,
+      SkylarkCallable annotation) throws SkylarkCallableProcessorException {
+    List<? extends VariableElement> methodSignatureParams = methodElement.getParameters();
+    int currentIndex = methodSignatureParams.size() - numExpectedExtraInterpreterParams(annotation);
+
+    // TODO(cparsons): Matching by class name alone is somewhat brittle, but due to tangled
+    // dependencies, it is difficult for this processor to depend directly on the expected
+    // classes here.
+    if (annotation.useLocation()) {
+      if (!LOCATION.equals(methodSignatureParams.get(currentIndex).asType().toString())) {
+        throw new SkylarkCallableProcessorException(
+            methodElement,
+            String.format(
+                "Expected parameter index %d to be the %s type, matching useLocation, but was %s",
+                currentIndex,
+                LOCATION,
+                methodSignatureParams.get(currentIndex).asType().toString()));
+      }
+      currentIndex++;
+    }
+    if (annotation.useAst()) {
+      if (!AST.equals(methodSignatureParams.get(currentIndex).asType().toString())) {
+        throw new SkylarkCallableProcessorException(
+            methodElement,
+            String.format(
+                "Expected parameter index %d to be the %s type, matching useAst, but was %s",
+                currentIndex, AST, methodSignatureParams.get(currentIndex).asType().toString()));
+      }
+      currentIndex++;
+    }
+    if (annotation.useEnvironment()) {
+      if (!ENVIRONMENT.equals(methodSignatureParams.get(currentIndex).asType().toString())) {
+        throw new SkylarkCallableProcessorException(
+            methodElement,
+            String.format(
+                "Expected parameter index %d to be the %s type, matching useEnvironment, "
+                    + "but was %s",
+                currentIndex,
+                ENVIRONMENT,
+                methodSignatureParams.get(currentIndex).asType().toString()));
+      }
+    }
+  }
+
+  private int numExpectedExtraInterpreterParams(SkylarkCallable annotation) {
+    int numExtraInterpreterParams = 0;
+    numExtraInterpreterParams += annotation.useLocation() ? 1 : 0;
+    numExtraInterpreterParams += annotation.useAst() ? 1 : 0;
+    numExtraInterpreterParams += annotation.useEnvironment() ? 1 : 0;
+    return numExtraInterpreterParams;
+  }
+
   /**
    * Prints an error message & fails the compilation.
    *
    * @param e The element which has caused the error. Can be null
    * @param msg The error message
    */
-  public void error(Element e, String msg) {
+  private void error(Element e, String msg) {
     messager.printMessage(Diagnostic.Kind.ERROR, msg, e);
   }
+
+  private static class SkylarkCallableProcessorException extends Exception {
+    private final ExecutableElement methodElement;
+    private final String errorMessage;
+
+    private SkylarkCallableProcessorException(
+        ExecutableElement methodElement, String errorMessage) {
+      this.methodElement = methodElement;
+      this.errorMessage = errorMessage;
+    }
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/FuncallExpression.java b/src/main/java/com/google/devtools/build/lib/syntax/FuncallExpression.java
index 2639c27..841a94a 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/FuncallExpression.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/FuncallExpression.java
@@ -261,18 +261,18 @@
     return numPositionalArgs;
   }
 
-   @Override
-   public void prettyPrint(Appendable buffer) throws IOException {
-     function.prettyPrint(buffer);
-     buffer.append('(');
-     String sep = "";
-     for (Argument.Passed arg : arguments) {
-       buffer.append(sep);
-       arg.prettyPrint(buffer);
-       sep = ", ";
-     }
-     buffer.append(')');
-   }
+  @Override
+  public void prettyPrint(Appendable buffer) throws IOException {
+    function.prettyPrint(buffer);
+    buffer.append('(');
+    String sep = "";
+    for (Argument.Passed arg : arguments) {
+      buffer.append(sep);
+      arg.prettyPrint(buffer);
+      sep = ", ";
+    }
+    buffer.append(')');
+  }
 
   @Override
   public String toString() {
@@ -361,7 +361,7 @@
               "method invocation returned None, please file a bug report: "
                   + methodName
                   + Printer.printAbbreviatedList(
-                      ImmutableList.copyOf(args), "(", ", ", ")", null));
+                  ImmutableList.copyOf(args), "(", ", ", ")", null));
         }
       }
       // TODO(bazel-team): get rid of this, by having everyone use the Skylark data structures
@@ -394,7 +394,11 @@
   // exactly and copy that behaviour.
   // Throws an EvalException when it cannot find a matching function.
   private Pair<MethodDescriptor, List<Object>> findJavaMethod(
-      Class<?> objClass, String methodName, List<Object> args, Map<String, Object> kwargs)
+      Class<?> objClass,
+      String methodName,
+      List<Object> args,
+      Map<String, Object> kwargs,
+      Environment environment)
       throws EvalException {
     Pair<MethodDescriptor, List<Object>> matchingMethod = null;
     List<MethodDescriptor> methods = getMethods(objClass, methodName);
@@ -402,9 +406,10 @@
     if (methods != null) {
       for (MethodDescriptor method : methods) {
         if (method.getAnnotation().structField()) {
+          // TODO(cparsons): Allow structField methods to accept interpreter-supplied arguments.
           return new Pair<>(method, null);
         } else {
-          argumentListConversionResult = convertArgumentList(args, kwargs, method);
+          argumentListConversionResult = convertArgumentList(args, kwargs, method, environment);
           if (argumentListConversionResult.getArguments() != null) {
             if (matchingMethod == null) {
               matchingMethod = new Pair<>(method, argumentListConversionResult.getArguments());
@@ -471,26 +476,36 @@
    * any. If there is a type or argument mismatch, returns a result containing an error message.
    */
   private ArgumentListConversionResult convertArgumentList(
-      List<Object> args, Map<String, Object> kwargs, MethodDescriptor method) {
+      List<Object> args,
+      Map<String, Object> kwargs,
+      MethodDescriptor method,
+      Environment environment) {
     ImmutableList.Builder<Object> builder = ImmutableList.builder();
-    Class<?>[] params = method.getMethod().getParameterTypes();
+    Class<?>[] javaMethodSignatureParams = method.getMethod().getParameterTypes();
     SkylarkCallable callable = method.getAnnotation();
+    int numExtraInterpreterParams = 0;
+    numExtraInterpreterParams += callable.useLocation() ? 1 : 0;
+    numExtraInterpreterParams += callable.useAst() ? 1 : 0;
+    numExtraInterpreterParams += callable.useEnvironment() ? 1 : 0;
+
     int mandatoryPositionals = callable.mandatoryPositionals();
     if (mandatoryPositionals < 0) {
       if (callable.parameters().length > 0) {
         mandatoryPositionals = 0;
       } else {
-        mandatoryPositionals = params.length;
+        mandatoryPositionals = javaMethodSignatureParams.length - numExtraInterpreterParams;
       }
     }
-    if (mandatoryPositionals > args.size()
-        || args.size() > mandatoryPositionals + callable.parameters().length) {
+    if (mandatoryPositionals > args.size()) {
+      return ArgumentListConversionResult.fromError("too few arguments");
+    }
+    if (args.size() > mandatoryPositionals + callable.parameters().length) {
       return ArgumentListConversionResult.fromError("too many arguments");
     }
     // First process the legacy positional parameters.
     int i = 0;
     if (mandatoryPositionals > 0) {
-      for (Class<?> param : params) {
+      for (Class<?> param : javaMethodSignatureParams) {
         Object value = args.get(i);
         if (!param.isAssignableFrom(value.getClass())) {
           return ArgumentListConversionResult.fromError(
@@ -500,7 +515,7 @@
         }
         builder.add(value);
         i++;
-        if (mandatoryPositionals >= 0 && i >= mandatoryPositionals) {
+        if (i >= mandatoryPositionals) {
           // Stops for specified parameters instead.
           break;
         }
@@ -554,9 +569,21 @@
     if (!keys.isEmpty()) {
       return ArgumentListConversionResult.fromError(
           String.format("unexpected keyword%s %s",
-          keys.size() > 1 ? "s" : "",
-          Joiner.on(",").join(Iterables.transform(keys, s -> "'" + s + "'"))));
+              keys.size() > 1 ? "s" : "",
+              Joiner.on(",").join(Iterables.transform(keys, s -> "'" + s + "'"))));
     }
+
+    // Then add any skylark-info arguments (for example the Environment).
+    if (callable.useLocation()) {
+      builder.add(getLocation());
+    }
+    if (callable.useAst()) {
+      builder.add(this);
+    }
+    if (callable.useEnvironment()) {
+      builder.add(environment);
+    }
+
     return ArgumentListConversionResult.fromArgumentList(builder.build());
   }
 
@@ -690,7 +717,7 @@
         objClass = value.getClass();
       }
       Pair<MethodDescriptor, List<Object>> javaMethod =
-          call.findJavaMethod(objClass, method, positionalArgs, keyWordArgs);
+          call.findJavaMethod(objClass, method, positionalArgs, keyWordArgs, env);
       if (javaMethod.first.getAnnotation().structField()) {
         // Not a method but a callable attribute
         try {