ConfiguredTargetFunction should use the given ToolchainContextKey, if available.

Part of work on toolchain transitions, #10523.

PiperOrigin-RevId: 314605757
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ConfiguredTargetFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/ConfiguredTargetFunction.java
index 2764e94..7d3ecb6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ConfiguredTargetFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ConfiguredTargetFunction.java
@@ -291,7 +291,12 @@
 
       // Determine what toolchains are needed by this target.
       unloadedToolchainContexts =
-          computeUnloadedToolchainContexts(env, ruleClassProvider, defaultBuildOptions, ctgValue);
+          computeUnloadedToolchainContexts(
+              env,
+              ruleClassProvider,
+              defaultBuildOptions,
+              ctgValue,
+              configuredTargetKey.getToolchainContextKey());
       if (env.valuesMissing()) {
         return null;
       }
@@ -441,7 +446,8 @@
       Environment env,
       RuleClassProvider ruleClassProvider,
       BuildOptions defaultBuildOptions,
-      TargetAndConfiguration targetAndConfig)
+      TargetAndConfiguration targetAndConfig,
+      @Nullable ToolchainContextKey parentToolchainContextKey)
       throws InterruptedException, ToolchainException {
     if (!(targetAndConfig.getTarget() instanceof Rule)) {
       return null;
@@ -494,14 +500,19 @@
 
     Map<String, ToolchainContextKey> toolchainContextKeys = new HashMap<>();
     String targetUnloadedToolchainContext = "target-unloaded-toolchain-context";
-    toolchainContextKeys.put(
-        targetUnloadedToolchainContext,
-        ToolchainContextKey.key()
-            .configurationKey(toolchainConfig)
-            .requiredToolchainTypeLabels(requiredDefaultToolchains)
-            .execConstraintLabels(defaultExecConstraintLabels)
-            .shouldSanityCheckConfiguration(configuration.trimConfigurationsRetroactively())
-            .build());
+    ToolchainContextKey toolchainContextKey;
+    if (parentToolchainContextKey != null) {
+      toolchainContextKey = parentToolchainContextKey;
+    } else {
+      toolchainContextKey =
+          ToolchainContextKey.key()
+              .configurationKey(toolchainConfig)
+              .requiredToolchainTypeLabels(requiredDefaultToolchains)
+              .execConstraintLabels(defaultExecConstraintLabels)
+              .shouldSanityCheckConfiguration(configuration.trimConfigurationsRetroactively())
+              .build();
+    }
+    toolchainContextKeys.put(targetUnloadedToolchainContext, toolchainContextKey);
     for (Map.Entry<String, ExecGroup> group : execGroups.entrySet()) {
       ExecGroup execGroup = group.getValue();
       toolchainContextKeys.put(
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainsForTargetsTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainsForTargetsTest.java
index 3869019..e0e29e8 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainsForTargetsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainsForTargetsTest.java
@@ -25,6 +25,7 @@
 import com.google.devtools.build.lib.analysis.config.BuildOptions;
 import com.google.devtools.build.lib.analysis.util.AnalysisMock;
 import com.google.devtools.build.lib.analysis.util.AnalysisTestCase;
+import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.packages.RuleClassProvider;
 import com.google.devtools.build.lib.skyframe.util.SkyframeExecutorTestUtils;
 import com.google.devtools.build.lib.testutil.Suite;
@@ -122,8 +123,8 @@
                 env,
                 stateProvider.lateBoundRuleClassProvider(),
                 buildOptionsSupplier.get(),
-                key.targetAndConfiguration());
-        // TODO(#10523): Pass in the ToolchainContextKey.
+                key.targetAndConfiguration(),
+                key.configuredTargetKey().getToolchainContextKey());
         return env.valuesMissing() ? null : Value.create(toolchainCollection);
       } catch (ToolchainException e) {
         throw new ComputeUnloadedToolchainContextsException(e);
@@ -447,4 +448,47 @@
         .execGroup("temp")
         .hasResolvedToolchain("//toolchains:toolchain_1_impl");
   }
+
+  @Test
+  public void keepParentToolchainContext() throws Exception {
+    scratch.file(
+        "extra/BUILD",
+        "load('//toolchain:toolchain_def.bzl', 'test_toolchain')",
+        "toolchain_type(name = 'extra_toolchain')",
+        "toolchain(",
+        "    name = 'toolchain',",
+        "    toolchain_type = '//extra:extra_toolchain',",
+        "    exec_compatible_with = [],",
+        "    target_compatible_with = [],",
+        "    toolchain = ':toolchain_impl')",
+        "test_toolchain(",
+        "    name='toolchain_impl',",
+        "    data = 'foo')");
+    scratch.file("a/BUILD", "load('//toolchain:rule.bzl', 'my_rule')", "my_rule(name = 'a')");
+
+    useConfiguration("--extra_toolchains=//extra:toolchain");
+    ConfiguredTarget target = Iterables.getOnlyElement(update("//a").getTargetsToBuild());
+    ToolchainCollection<UnloadedToolchainContext> toolchainCollection =
+        getToolchainCollection(
+            target,
+            ConfiguredTargetKey.builder()
+                .setLabel(target.getOriginalLabel())
+                .setConfigurationKey(target.getConfigurationKey())
+                .setToolchainContextKey(
+                    ToolchainContextKey.key()
+                        .configurationKey(target.getConfigurationKey())
+                        .requiredToolchainTypeLabels(
+                            Label.parseAbsoluteUnchecked("//extra:extra_toolchain"))
+                        .build())
+                .build());
+
+    assertThat(toolchainCollection).isNotNull();
+    assertThat(toolchainCollection).hasDefaultExecGroup();
+    assertThat(toolchainCollection)
+        .defaultToolchainContext()
+        .hasToolchainType("//extra:extra_toolchain");
+    assertThat(toolchainCollection)
+        .defaultToolchainContext()
+        .hasResolvedToolchain("//extra:toolchain_impl");
+  }
 }