Add register_toolchains function to the WORKSPACE for registering toolchains to use.

Part of #2219.

Change-Id: Id6dfe6ec102f609bb19461242a098bf977be29ae
PiperOrigin-RevId: 161527986
diff --git a/src/main/java/com/google/devtools/build/lib/packages/Package.java b/src/main/java/com/google/devtools/build/lib/packages/Package.java
index e5a5598..4ec65dd 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/Package.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/Package.java
@@ -193,6 +193,8 @@
   private ImmutableList<Event> events;
   private ImmutableList<Postable> posts;
 
+  private ImmutableList<Label> registeredToolchainLabels;
+
   /**
    * Package initialization, part 1 of 3: instantiates a new package with the
    * given name.
@@ -317,6 +319,7 @@
     this.features = ImmutableSortedSet.copyOf(builder.features);
     this.events = ImmutableList.copyOf(builder.events);
     this.posts = ImmutableList.copyOf(builder.posts);
+    this.registeredToolchainLabels = ImmutableList.copyOf(builder.registeredToolchainLabels);
   }
 
   /**
@@ -643,6 +646,10 @@
     return defaultRestrictedTo;
   }
 
+  public ImmutableList<Label> getRegisteredToolchainLabels() {
+    return registeredToolchainLabels;
+  }
+
   @Override
   public String toString() {
     return "Package(" + name + ")="
@@ -694,6 +701,7 @@
    * {@link com.google.devtools.build.lib.skyframe.PackageFunction}.
    */
   public static class Builder {
+
     public static interface Helper {
       /**
        * Returns a fresh {@link Package} instance that a {@link Builder} will internally mutate
@@ -756,6 +764,8 @@
 
     protected ExternalPackageBuilder externalPackageData = new ExternalPackageBuilder();
 
+    protected List<Label> registeredToolchainLabels = new ArrayList<>();
+
     /**
      * True iff the "package" function has already been called in this package.
      */
@@ -1270,6 +1280,10 @@
       addRuleUnchecked(rule);
     }
 
+    void addRegisteredToolchainLabels(List<Label> toolchains) {
+      this.registeredToolchainLabels.addAll(toolchains);
+    }
+
     private Builder beforeBuild(boolean discoverAssumedInputFiles) throws InterruptedException {
       Preconditions.checkNotNull(pkg);
       Preconditions.checkNotNull(filename);
diff --git a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java
index 41dbdd7..ccb81e1 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java
@@ -46,11 +46,14 @@
 import com.google.devtools.build.lib.syntax.Mutability;
 import com.google.devtools.build.lib.syntax.ParserInputSource;
 import com.google.devtools.build.lib.syntax.Runtime;
+import com.google.devtools.build.lib.syntax.Runtime.NoneType;
 import com.google.devtools.build.lib.syntax.SkylarkList;
 import com.google.devtools.build.lib.syntax.SkylarkSignatureProcessor;
 import com.google.devtools.build.lib.vfs.Path;
 import java.io.File;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -60,7 +63,6 @@
  * Parser for WORKSPACE files.  Fills in an ExternalPackage.Builder
  */
 public class WorkspaceFactory {
-  public static final String BIND = "bind";
   private static final Pattern LEGAL_WORKSPACE_NAME = Pattern.compile("^\\p{Alpha}\\w*$");
 
   // List of static function added by #addWorkspaceFunctions. Used to trim them out from the
@@ -260,6 +262,7 @@
     if (aPackage.containsErrors()) {
       builder.setContainsErrors();
     }
+    builder.addRegisteredToolchainLabels(aPackage.getRegisteredToolchainLabels());
     for (Rule rule : aPackage.getTargets(Rule.class)) {
       try {
         // The old rule references another Package instance and we wan't to keep the invariant that
@@ -382,6 +385,56 @@
     };
   }
 
+  @SkylarkSignature(
+    name = "register_toolchains",
+    objectType = Object.class,
+    returnType = NoneType.class,
+    doc =
+        "Registers a toolchain created with the toolchain() rule so that it is available for "
+            + "toolchain resolution.",
+    extraPositionals =
+        @Param(
+          name = "toolchain_labels",
+          type = SkylarkList.class,
+          generic1 = String.class,
+          doc = "The labels of the toolchains to register."
+        ),
+    useAst = true,
+    useEnvironment = true
+  )
+  private static final BuiltinFunction.Factory newRegisterToolchainsFunction =
+      new BuiltinFunction.Factory("register_toolchains") {
+        public BuiltinFunction create(final RuleFactory ruleFactory) {
+          return new BuiltinFunction(
+              "register_toolchains", FunctionSignature.POSITIONALS, BuiltinFunction.USE_AST_ENV) {
+            public Object invoke(
+                SkylarkList<String> toolchainLabels, FuncallExpression ast, Environment env)
+                throws EvalException, InterruptedException {
+
+              // Collect the toolchain labels.
+              List<Label> toolchains = new ArrayList<>();
+              for (String rawLabel :
+                  toolchainLabels.getContents(String.class, "toolchain_labels")) {
+                try {
+                  toolchains.add(Label.parseAbsolute(rawLabel));
+                } catch (LabelSyntaxException e) {
+                  throw new EvalException(
+                      ast.getLocation(),
+                      String.format("Unable to parse toolchain %s: %s", rawLabel, e.getMessage()),
+                      e);
+                }
+              }
+
+              // Add to the package definition for later.
+              Package.Builder builder = PackageFactory.getContext(env, ast).pkgBuilder;
+              builder.addRegisteredToolchainLabels(toolchains);
+
+              return NONE;
+            }
+          };
+        }
+      };
+
   /**
    * Returns a function-value implementing the build rule "ruleClass" (e.g. cc_library) in the
    * specified package context.
@@ -426,15 +479,16 @@
 
   private static ImmutableMap<String, BaseFunction> createWorkspaceFunctions(
       boolean allowOverride, RuleFactory ruleFactory) {
-    ImmutableMap.Builder<String, BaseFunction> mapBuilder = ImmutableMap.builder();
-    mapBuilder.put(BIND, newBindFunction(ruleFactory));
+    Map<String, BaseFunction> map = new HashMap<>();
+    map.put("bind", newBindFunction(ruleFactory));
+    map.put("register_toolchains", newRegisterToolchainsFunction.apply(ruleFactory));
     for (String ruleClass : ruleFactory.getRuleClassNames()) {
-      if (!ruleClass.equals(BIND)) {
+      if (!map.containsKey(ruleClass)) {
         BaseFunction ruleFunction = newRuleFunction(ruleFactory, ruleClass, allowOverride);
-        mapBuilder.put(ruleClass, ruleFunction);
+        map.put(ruleClass, ruleFunction);
       }
     }
-    return mapBuilder.build();
+    return ImmutableMap.copyOf(map);
   }
 
   private void addWorkspaceFunctions(Environment workspaceEnv, StoredEventHandler localReporter) {
diff --git a/src/main/java/com/google/devtools/build/lib/rules/ExternalPackageUtil.java b/src/main/java/com/google/devtools/build/lib/rules/ExternalPackageUtil.java
index 4a73d05..3a68b46 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/ExternalPackageUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/ExternalPackageUtil.java
@@ -26,6 +26,7 @@
 import com.google.devtools.build.lib.packages.Package;
 import com.google.devtools.build.lib.packages.Rule;
 import com.google.devtools.build.lib.skyframe.PackageLookupValue;
+import com.google.devtools.build.lib.skyframe.PackageValue;
 import com.google.devtools.build.lib.skyframe.WorkspaceFileValue;
 import com.google.devtools.build.lib.syntax.EvalException;
 import com.google.devtools.build.lib.util.Preconditions;
@@ -171,6 +172,24 @@
     return rule;
   }
 
+  /**
+   * Loads the external package and then returns the registered toolchain labels.
+   *
+   * @param env the environment to use for lookups
+   */
+  @Nullable
+  public static List<Label> getRegisteredToolchainLabels(Environment env)
+      throws ExternalPackageException, InterruptedException {
+    PackageValue externalPackageValue =
+        (PackageValue) env.getValue(PackageValue.key(Label.EXTERNAL_PACKAGE_IDENTIFIER));
+    if (externalPackageValue == null) {
+      return null;
+    }
+
+    Package externalPackage = externalPackageValue.getPackage();
+    return externalPackage.getRegisteredToolchainLabels();
+  }
+
   /** Exception thrown when something goes wrong accessing a rule. */
   public static class ExternalPackageException extends SkyFunctionException {
     public ExternalPackageException(NoSuchPackageException cause, Transience transience) {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java
index 15e21e3..33fb4aa 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java
@@ -453,7 +453,9 @@
       env.getListener().post(post);
     }
 
-    packageFactory.afterDoneLoadingPackage(pkg);
+    if (packageFactory != null) {
+      packageFactory.afterDoneLoadingPackage(pkg);
+    }
     return new PackageValue(pkg);
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/syntax/FunctionSignature.java b/src/main/java/com/google/devtools/build/lib/syntax/FunctionSignature.java
index 5dc2a6a..37e6f06 100644
--- a/src/main/java/com/google/devtools/build/lib/syntax/FunctionSignature.java
+++ b/src/main/java/com/google/devtools/build/lib/syntax/FunctionSignature.java
@@ -589,6 +589,10 @@
     }
   }
 
+  /** A ready-made signature to allow only positional arguments and put them in a star parameter */
+  public static final FunctionSignature POSITIONALS =
+      FunctionSignature.of(0, 0, 0, true, false, "star");
+
   /** A ready-made signature to allow only keyword arguments and put them in a kwarg parameter */
   public static final FunctionSignature KWARGS =
       FunctionSignature.of(0, 0, 0, false, true, "kwargs");
diff --git a/src/test/java/com/google/devtools/build/lib/packages/WorkspaceFactoryTest.java b/src/test/java/com/google/devtools/build/lib/packages/WorkspaceFactoryTest.java
index 4fe5b71..6a69e94 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/WorkspaceFactoryTest.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/WorkspaceFactoryTest.java
@@ -18,6 +18,7 @@
 import static org.junit.Assert.fail;
 
 import com.google.common.collect.ImmutableList;
+import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.events.Event;
 import com.google.devtools.build.lib.events.StoredEventHandler;
 import com.google.devtools.build.lib.packages.Package.Builder;
@@ -83,6 +84,31 @@
         "workspace() function should be used only at the top of the WORKSPACE file");
   }
 
+  @Test
+  public void testRegisterToolchains() throws Exception {
+    WorkspaceFactoryHelper helper = parse("register_toolchains('//toolchain:tc1')");
+    assertThat(helper.getPackage().getRegisteredToolchainLabels())
+        .containsExactly(Label.parseAbsolute("//toolchain:tc1"));
+  }
+
+  @Test
+  public void testRegisterToolchains_multipleLabels() throws Exception {
+    WorkspaceFactoryHelper helper =
+        parse("register_toolchains(", "  '//toolchain:tc1',", "  '//toolchain:tc2')");
+    assertThat(helper.getPackage().getRegisteredToolchainLabels())
+        .containsExactly(
+            Label.parseAbsolute("//toolchain:tc1"), Label.parseAbsolute("//toolchain:tc2"));
+  }
+
+  @Test
+  public void testRegisterToolchains_multipleCalls() throws Exception {
+    WorkspaceFactoryHelper helper =
+        parse("register_toolchains('//toolchain:tc1')", "register_toolchains('//toolchain:tc2')");
+    assertThat(helper.getPackage().getRegisteredToolchainLabels())
+        .containsExactly(
+            Label.parseAbsolute("//toolchain:tc1"), Label.parseAbsolute("//toolchain:tc2"));
+  }
+
   private WorkspaceFactoryHelper parse(String... args) {
     return new WorkspaceFactoryHelper(args);
   }
diff --git a/src/test/java/com/google/devtools/build/lib/rules/ExternalPackageUtilTest.java b/src/test/java/com/google/devtools/build/lib/rules/ExternalPackageUtilTest.java
index 5902788..b6a0037 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/ExternalPackageUtilTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/ExternalPackageUtilTest.java
@@ -23,6 +23,7 @@
 import com.google.devtools.build.lib.analysis.BlazeDirectories;
 import com.google.devtools.build.lib.analysis.util.AnalysisMock;
 import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
+import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.cmdline.PackageIdentifier;
 import com.google.devtools.build.lib.events.NullEventHandler;
 import com.google.devtools.build.lib.packages.PackageFactory;
@@ -36,6 +37,7 @@
 import com.google.devtools.build.lib.skyframe.FileFunction;
 import com.google.devtools.build.lib.skyframe.FileStateFunction;
 import com.google.devtools.build.lib.skyframe.LocalRepositoryLookupFunction;
+import com.google.devtools.build.lib.skyframe.PackageFunction;
 import com.google.devtools.build.lib.skyframe.PackageLookupFunction;
 import com.google.devtools.build.lib.skyframe.PackageLookupFunction.CrossRepositoryLabelViolationStrategy;
 import com.google.devtools.build.lib.skyframe.PackageLookupValue.BuildFileName;
@@ -115,12 +117,15 @@
                         new PackageFactory.EmptyEnvironmentExtension()))
                 .build(ruleClassProvider, scratch.getFileSystem()),
             directories));
+    skyFunctions.put(
+        SkyFunctions.PACKAGE, new PackageFunction(null, null, null, null, null, null, null));
     skyFunctions.put(SkyFunctions.EXTERNAL_PACKAGE, new ExternalPackageFunction());
     skyFunctions.put(SkyFunctions.LOCAL_REPOSITORY_LOOKUP, new LocalRepositoryLookupFunction());
 
     // Helper Skyfunctions to call ExternalPackageUtil.
     skyFunctions.put(GET_RULE_BY_NAME_FUNCTION, new GetRuleByNameFunction());
     skyFunctions.put(GET_RULE_BY_RULE_CLASS_FUNCTION, new GetRuleByRuleClassFunction());
+    skyFunctions.put(GET_REGISTERED_TOOLCHAINS_FUNCTION, new GetRegisteredToolchainsFunction());
 
     RecordingDifferencer differencer = new RecordingDifferencer();
     MemoizingEvaluator evaluator = new InMemoryMemoizingEvaluator(skyFunctions, differencer);
@@ -134,7 +139,6 @@
       return;
     }
     scratch.overwriteFile("WORKSPACE", "http_archive(name = 'foo', url = 'http://foo')");
-    invalidatePackages(false);
 
     SkyKey key = getRuleByNameKey("foo");
     EvaluationResult<GetRuleByNameValue> result = getRuleByName(key);
@@ -152,7 +156,6 @@
       return;
     }
     scratch.overwriteFile("WORKSPACE", "http_archive(name = 'foo', url = 'http://foo')");
-    invalidatePackages(false);
 
     SkyKey key = getRuleByNameKey("bar");
     EvaluationResult<GetRuleByNameValue> result = getRuleByName(key);
@@ -173,7 +176,6 @@
         "WORKSPACE",
         "http_archive(name = 'foo', url = 'http://foo')",
         "http_archive(name = 'bar', url = 'http://bar')");
-    invalidatePackages(false);
 
     SkyKey key = getRuleByRuleClassKey("http_archive");
     EvaluationResult<GetRuleByRuleClassValue> result = getRuleByRuleClass(key);
@@ -201,7 +203,6 @@
         "WORKSPACE",
         "http_archive(name = 'foo', url = 'http://foo')",
         "http_archive(name = 'bar', url = 'http://bar')");
-    invalidatePackages(false);
 
     SkyKey key = getRuleByRuleClassKey("new_git_repository");
     EvaluationResult<GetRuleByRuleClassValue> result = getRuleByRuleClass(key);
@@ -213,6 +214,21 @@
     assertThat(rules).isEmpty();
   }
 
+  @Test
+  public void getRegisteredToolchains() throws Exception {
+    scratch.overwriteFile(
+        "WORKSPACE", "register_toolchains(", "  '//toolchain:tc1',", "  '//toolchain:tc2')");
+
+    SkyKey key = getRegisteredToolchainsKey();
+    EvaluationResult<GetRegisteredToolchainsValue> result = getRegisteredToolchains(key);
+
+    assertThatEvaluationResult(result).hasNoError();
+
+    assertThat(result.get(key).registeredToolchainLabels())
+        .containsExactly(makeLabel("//toolchain:tc1"), makeLabel("//toolchain:tc2"))
+        .inOrder();
+  }
+
   // HELPER SKYFUNCTIONS
 
   // GetRuleByName.
@@ -310,4 +326,53 @@
       return null;
     }
   }
+
+  // GetRegisteredToolchains.
+  SkyKey getRegisteredToolchainsKey() {
+    return LegacySkyKey.create(GET_REGISTERED_TOOLCHAINS_FUNCTION, "singleton");
+  }
+
+  EvaluationResult<GetRegisteredToolchainsValue> getRegisteredToolchains(SkyKey key)
+      throws InterruptedException {
+    return driver.<GetRegisteredToolchainsValue>evaluate(
+        ImmutableList.of(key),
+        false,
+        SkyframeExecutor.DEFAULT_THREAD_COUNT,
+        NullEventHandler.INSTANCE);
+  }
+
+  private static final SkyFunctionName GET_REGISTERED_TOOLCHAINS_FUNCTION =
+      SkyFunctionName.create("GET_REGISTERED_TOOLCHAINS");
+
+  @AutoValue
+  abstract static class GetRegisteredToolchainsValue implements SkyValue {
+    abstract ImmutableList<Label> registeredToolchainLabels();
+
+    static GetRegisteredToolchainsValue create(Iterable<Label> registeredToolchainLabels) {
+      return new AutoValue_ExternalPackageUtilTest_GetRegisteredToolchainsValue(
+          ImmutableList.copyOf(registeredToolchainLabels));
+    }
+  }
+
+  private static final class GetRegisteredToolchainsFunction implements SkyFunction {
+
+    @Nullable
+    @Override
+    public SkyValue compute(SkyKey skyKey, Environment env)
+        throws SkyFunctionException, InterruptedException {
+      String ruleName = (String) skyKey.argument();
+
+      List<Label> registeredToolchainLabels = ExternalPackageUtil.getRegisteredToolchainLabels(env);
+      if (registeredToolchainLabels == null) {
+        return null;
+      }
+      return GetRegisteredToolchainsValue.create(registeredToolchainLabels);
+    }
+
+    @Nullable
+    @Override
+    public String extractTag(SkyKey skyKey) {
+      return null;
+    }
+  }
 }