Allow starlark transitions to whitelist based on rule definition location.

RELNOTES: None.
PiperOrigin-RevId: 240793931
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/Whitelist.java b/src/main/java/com/google/devtools/build/lib/analysis/Whitelist.java
index 2c35fee..122d016 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/Whitelist.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/Whitelist.java
@@ -52,21 +52,44 @@
   }
 
   /**
-   * Returns whether the rule in the given RuleContext is in a whitelist.
+   * Returns whether the rule in the given RuleContext *was defined* in a whitelist.
+   *
+   * @param ruleContext The context in which this check is being executed.
+   * @param whitelistName The name of the whitelist being used.
+   */
+  public static boolean isAvailableBasedOnRuleLocation(
+      RuleContext ruleContext, String whitelistName) {
+    return isAvailableFor(
+        ruleContext,
+        whitelistName,
+        ruleContext.getRule().getRuleClassObject().getRuleDefinitionEnvironmentLabel());
+  }
+
+  /**
+   * Returns whether the rule in the given RuleContext *was instantiated* in a whitelist.
    *
    * @param ruleContext The context in which this check is being executed.
    * @param whitelistName The name of the whitelist being used.
    */
   public static boolean isAvailable(RuleContext ruleContext, String whitelistName) {
+    return isAvailableFor(ruleContext, whitelistName, ruleContext.getLabel());
+  }
+
+  /**
+   * @param relevantLabel the label to check for in the whitelist. This allows features that
+   *     whitelist on rule definition location and features that whitelist on rule instantiation
+   *     location to share logic.
+   */
+  private static boolean isAvailableFor(
+      RuleContext ruleContext, String whitelistName, Label relevantLabel) {
     String attributeName = getAttributeNameFromWhitelistName(whitelistName);
     Preconditions.checkArgument(ruleContext.isAttrDefined(attributeName, LABEL));
     TransitiveInfoCollection packageGroup = ruleContext.getPrerequisite(attributeName, Mode.HOST);
-    Label label = ruleContext.getLabel();
     PackageSpecificationProvider packageSpecificationProvider =
         packageGroup.getProvider(PackageSpecificationProvider.class);
     requireNonNull(packageSpecificationProvider, packageGroup.getLabel().toString());
     return Streams.stream(packageSpecificationProvider.getPackageSpecifications())
-        .anyMatch(p -> p.containsPackage(label.getPackageIdentifier()));
+        .anyMatch(p -> p.containsPackage(relevantLabel.getPackageIdentifier()));
   }
 
   /**
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleConfiguredTargetUtil.java b/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleConfiguredTargetUtil.java
index 1e5626d..a3e9460 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleConfiguredTargetUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleConfiguredTargetUtil.java
@@ -117,8 +117,11 @@
         return null;
       }
       if (ruleClass.hasFunctionTransitionWhitelist()
-          && !Whitelist.isAvailable(ruleContext, FunctionSplitTransitionWhitelist.WHITELIST_NAME)) {
-        ruleContext.ruleError("Non-whitelisted use of Starlark transition");
+          && !Whitelist.isAvailableBasedOnRuleLocation(
+              ruleContext, FunctionSplitTransitionWhitelist.WHITELIST_NAME)) {
+        if (!Whitelist.isAvailable(ruleContext, FunctionSplitTransitionWhitelist.WHITELIST_NAME)) {
+          ruleContext.ruleError("Non-whitelisted use of Starlark transition");
+        }
       }
 
       Object target =
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/StarlarkAttrTransitionProviderTest.java b/src/test/java/com/google/devtools/build/lib/analysis/StarlarkAttrTransitionProviderTest.java
index a8d8132..e6e6e12 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/StarlarkAttrTransitionProviderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/StarlarkAttrTransitionProviderTest.java
@@ -113,11 +113,38 @@
   }
 
   @Test
-  public void testTargetNotInWhitelist() throws Exception {
-    writeBasicTestFiles();
+  public void testTargetAndRuleNotInWhitelist() throws Exception {
+    setSkylarkSemanticsOptions("--experimental_starlark_config_transitions=true");
+    writeWhitelistFile();
+    getAnalysisMock().ccSupport().setupCcToolchainConfigForCpu(mockToolsConfig, "armeabi-v7a");
+    scratch.file(
+        "test/not_whitelisted/my_rule.bzl",
+        "def transition_func(settings, attr):",
+        "  return {",
+        "      't0': {'//command_line_option:cpu': 'k8'},",
+        "      't1': {'//command_line_option:cpu': 'armeabi-v7a'},",
+        "  }",
+        "my_transition = transition(implementation = transition_func, inputs = [],",
+        "  outputs = ['//command_line_option:cpu'])",
+        "def impl(ctx): ",
+        "  return struct(",
+        "    split_attr_deps = ctx.split_attr.deps,",
+        "    split_attr_dep = ctx.split_attr.dep,",
+        "    k8_deps = ctx.split_attr.deps.get('k8', None),",
+        "    attr_deps = ctx.attr.deps,",
+        "    attr_dep = ctx.attr.dep)",
+        "my_rule = rule(",
+        "  implementation = impl,",
+        "  attrs = {",
+        "    'deps': attr.label_list(cfg = my_transition),",
+        "    'dep':  attr.label(cfg = my_transition),",
+        "    '_whitelist_function_transition': attr.label(",
+        "        default = '//tools/whitelists/function_transition_whitelist',",
+        "    ),",
+        "  })");
     scratch.file(
         "test/not_whitelisted/BUILD",
-        "load('//test/skylark:my_rule.bzl', 'my_rule')",
+        "load('//test/not_whitelisted:my_rule.bzl', 'my_rule')",
         "my_rule(name = 'test', dep = ':main')",
         "cc_binary(name = 'main', srcs = ['main.c'])");
 
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/StarlarkRuleTransitionProviderTest.java b/src/test/java/com/google/devtools/build/lib/analysis/StarlarkRuleTransitionProviderTest.java
index 5b1e701..a747d37 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/StarlarkRuleTransitionProviderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/StarlarkRuleTransitionProviderTest.java
@@ -659,4 +659,71 @@
     assertThat(configuration.getOptions().get(TestOptions.class).testArguments)
         .containsExactly("post-transition");
   }
+
+  @Test
+  public void testWhitelistOnRuleNotTargets() throws Exception {
+    // whitelists //test/...
+    writeWhitelistFile();
+    scratch.file(
+        "test/transitions.bzl",
+        "def _impl(settings, attr):",
+        "  return {'//command_line_option:test_arg': ['post-transition']}",
+        "my_transition = transition(implementation = _impl, inputs = [],",
+        "  outputs = ['//command_line_option:test_arg'])");
+    scratch.file(
+        "test/rules.bzl",
+        "load('//test:transitions.bzl', 'my_transition')",
+        "def _impl(ctx):",
+        "  return []",
+        "my_rule = rule(",
+        "  implementation = _impl,",
+        "  cfg = my_transition,",
+        "  attrs = {",
+        "    '_whitelist_function_transition': attr.label(",
+        "        default = '//tools/whitelists/function_transition_whitelist',",
+        "    ),",
+        "  })");
+    scratch.file(
+        "neverland/BUILD", "load('//test:rules.bzl', 'my_rule')", "my_rule(name = 'test')");
+    scratch.file("test/BUILD");
+    useConfiguration("--test_arg=pre-transition");
+
+    BuildConfiguration configuration = getConfiguration(getConfiguredTarget("//neverland:test"));
+    assertThat(configuration.getOptions().get(TestOptions.class).testArguments)
+        .containsExactly("post-transition");
+  }
+
+  // TODO(juliexxia): flip this test when this isn't allowed anymore.
+  @Test
+  public void testWhitelistOnTargetsStillWorks() throws Exception {
+    // whitelists //test/...
+    writeWhitelistFile();
+    scratch.file(
+        "neverland/transitions.bzl",
+        "def _impl(settings, attr):",
+        "  return {'//command_line_option:test_arg': ['post-transition']}",
+        "my_transition = transition(implementation = _impl, inputs = [],",
+        "  outputs = ['//command_line_option:test_arg'])");
+    scratch.file(
+        "neverland/rules.bzl",
+        "load('//neverland:transitions.bzl', 'my_transition')",
+        "def _impl(ctx):",
+        "  return []",
+        "my_rule = rule(",
+        "  implementation = _impl,",
+        "  cfg = my_transition,",
+        "  attrs = {",
+        "    '_whitelist_function_transition': attr.label(",
+        "        default = '//tools/whitelists/function_transition_whitelist',",
+        "    ),",
+        "  })");
+    scratch.file(
+        "test/BUILD", "load('//neverland:rules.bzl', 'my_rule')", "my_rule(name = 'test')");
+    scratch.file("neverland/BUILD");
+    useConfiguration("--test_arg=pre-transition");
+
+    BuildConfiguration configuration = getConfiguration(getConfiguredTarget("//test"));
+    assertThat(configuration.getOptions().get(TestOptions.class).testArguments)
+        .containsExactly("post-transition");
+  }
 }