Get PlatformInfo out of ToolchainResolutionKey, replace with the ConfiguredTargetKeys that own the PlatformInfo.

PiperOrigin-RevId: 185770105
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
index 6a8dcca..0211c46 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
@@ -27,9 +27,9 @@
 import com.google.devtools.build.lib.events.Event;
 import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.packages.NoSuchThingException;
-import com.google.devtools.build.lib.skyframe.ConfiguredTargetFunction.ConfiguredValueCreationException;
 import com.google.devtools.build.lib.skyframe.RegisteredToolchainsFunction.InvalidToolchainLabelException;
 import com.google.devtools.build.lib.skyframe.ToolchainResolutionValue.ToolchainResolutionKey;
+import com.google.devtools.build.lib.skyframe.ToolchainUtil.ToolchainContextException;
 import com.google.devtools.build.lib.syntax.EvalException;
 import com.google.devtools.build.skyframe.SkyFunction;
 import com.google.devtools.build.skyframe.SkyFunctionException;
@@ -37,7 +37,9 @@
 import com.google.devtools.build.skyframe.SkyValue;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 import javax.annotation.Nullable;
 
 /** {@link SkyFunction} which performs toolchain resolution for a class of rules. */
@@ -46,7 +48,7 @@
   @Nullable
   @Override
   public SkyValue compute(SkyKey skyKey, Environment env)
-      throws SkyFunctionException, InterruptedException {
+      throws ToolchainResolutionFunctionException, InterruptedException {
     ToolchainResolutionKey key = (ToolchainResolutionKey) skyKey.argument();
 
     // This call could be combined with the call below, but this SkyFunction is evaluated so rarely
@@ -75,13 +77,24 @@
       throw new ToolchainResolutionFunctionException(e);
     }
 
+    Map<ConfiguredTargetKey, PlatformInfo> platforms;
+    try {
+      platforms =
+          ToolchainUtil.getPlatformInfo(
+              key.targetPlatformKey(), key.availableExecutionPlatformKeys(), env);
+    } catch (ToolchainContextException e) {
+      throw new ToolchainResolutionFunctionException(e);
+    }
     // Find the right one.
     boolean debug = configuration.getOptions().get(PlatformOptions.class).toolchainResolutionDebug;
     ImmutableMap<PlatformInfo, Label> resolvedToolchainLabels =
         resolveConstraints(
             key.toolchainType(),
-            key.availableExecutionPlatforms(),
-            key.targetPlatform(),
+            key.availableExecutionPlatformKeys()
+                .stream()
+                .map(platforms::get)
+                .collect(Collectors.toList()),
+            platforms.get(key.targetPlatformKey()),
             toolchains.registeredToolchains(),
             debug ? env.getListener() : null);
 
@@ -227,7 +240,7 @@
       super(e, Transience.PERSISTENT);
     }
 
-    public ToolchainResolutionFunctionException(ConfiguredValueCreationException e) {
+    public ToolchainResolutionFunctionException(ToolchainContextException e) {
       super(e, Transience.PERSISTENT);
     }
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
index f71c3bd..69ab599 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
@@ -37,10 +37,10 @@
   public static SkyKey key(
       BuildConfigurationValue.Key configurationKey,
       Label toolchainType,
-      PlatformInfo targetPlatform,
-      List<PlatformInfo> availableExecutionPlatforms) {
+      ConfiguredTargetKey targetPlatformKey,
+      List<ConfiguredTargetKey> availableExecutionPlatformKeys) {
     return ToolchainResolutionKey.create(
-        configurationKey, toolchainType, targetPlatform, availableExecutionPlatforms);
+        configurationKey, toolchainType, targetPlatformKey, availableExecutionPlatformKeys);
   }
 
   /** {@link SkyKey} implementation used for {@link ToolchainResolutionFunction}. */
@@ -55,20 +55,20 @@
 
     public abstract Label toolchainType();
 
-    public abstract PlatformInfo targetPlatform();
+    abstract ConfiguredTargetKey targetPlatformKey();
 
-    abstract ImmutableList<PlatformInfo> availableExecutionPlatforms();
+    abstract ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys();
 
     static ToolchainResolutionKey create(
         BuildConfigurationValue.Key configuration,
         Label toolchainType,
-        PlatformInfo targetPlatform,
-        List<PlatformInfo> availableExecutionPlatforms) {
+        ConfiguredTargetKey targetPlatformKey,
+        List<ConfiguredTargetKey> availableExecutionPlatformKeys) {
       return new AutoValue_ToolchainResolutionValue_ToolchainResolutionKey(
           configuration,
           toolchainType,
-          targetPlatform,
-          ImmutableList.copyOf(availableExecutionPlatforms));
+          targetPlatformKey,
+          ImmutableList.copyOf(availableExecutionPlatformKeys));
     }
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
index 28d5d9a..cf340d4 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
@@ -36,6 +36,7 @@
 import com.google.devtools.build.skyframe.ValueOrException;
 import com.google.devtools.build.skyframe.ValueOrException4;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -51,7 +52,7 @@
    * of the {@link ToolchainResolutionFunction}.
    */
   @Nullable
-  public static ToolchainContext createToolchainContext(
+  static ToolchainContext createToolchainContext(
       Environment env,
       String targetDescription,
       Set<Label> requiredToolchains,
@@ -79,7 +80,11 @@
 
     ImmutableBiMap<Label, Label> resolvedLabels =
         resolveToolchainLabels(
-            env, requiredToolchains, executionPlatform, targetPlatform, configurationKey);
+            env,
+            requiredToolchains,
+            configurationKey,
+            platforms.hostPlatformKey(),
+            platforms.targetPlatformKey());
     if (resolvedLabels == null) {
       return null;
     }
@@ -103,9 +108,17 @@
 
     abstract PlatformInfo targetPlatform();
 
+    abstract ConfiguredTargetKey hostPlatformKey();
+
+    abstract ConfiguredTargetKey targetPlatformKey();
+
     protected static PlatformDescriptors create(
-        PlatformInfo hostPlatform, PlatformInfo targetPlatform) {
-      return new AutoValue_ToolchainUtil_PlatformDescriptors(hostPlatform, targetPlatform);
+        PlatformInfo hostPlatform,
+        PlatformInfo targetPlatform,
+        ConfiguredTargetKey hostPlatformKey,
+        ConfiguredTargetKey targetPlatformKey) {
+      return new AutoValue_ToolchainUtil_PlatformDescriptors(
+          hostPlatform, targetPlatform, hostPlatformKey, targetPlatformKey);
     }
   }
 
@@ -136,6 +149,37 @@
   }
 
   @Nullable
+  static Map<ConfiguredTargetKey, PlatformInfo> getPlatformInfo(
+      ConfiguredTargetKey targetPlatformKey,
+      Iterable<ConfiguredTargetKey> hostPlatformKeys,
+      Environment env)
+      throws InterruptedException, ToolchainContextException {
+    Iterable<ConfiguredTargetKey> allKeys =
+        Iterables.concat(ImmutableList.of(targetPlatformKey), hostPlatformKeys);
+    Map<SkyKey, ValueOrException<ConfiguredValueCreationException>> values =
+        env.getValuesOrThrow(allKeys, ConfiguredValueCreationException.class);
+    boolean valuesMissing = env.valuesMissing();
+    Map<ConfiguredTargetKey, PlatformInfo> platforms = valuesMissing ? null : new HashMap<>();
+    try {
+      for (ConfiguredTargetKey key : allKeys) {
+        PlatformInfo platformInfo =
+            findPlatformInfo(
+                values.get(key),
+                key.equals(targetPlatformKey) ? "target platform" : "host platform");
+        if (!valuesMissing && platformInfo != null) {
+          platforms.put(key, platformInfo);
+        }
+      }
+    } catch (ConfiguredValueCreationException e) {
+      throw new ToolchainContextException(e);
+    }
+    if (valuesMissing) {
+      return null;
+    }
+    return platforms;
+  }
+
+  @Nullable
   private static PlatformDescriptors loadPlatformDescriptors(
       Environment env, BuildConfiguration configuration)
       throws InterruptedException, ToolchainContextException {
@@ -147,36 +191,29 @@
     Label hostPlatformLabel = platformConfiguration.getHostPlatform();
     Label targetPlatformLabel = platformConfiguration.getTargetPlatforms().get(0);
 
-    SkyKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
-    SkyKey targetPlatformKey = ConfiguredTargetKey.of(targetPlatformLabel, configuration);
-
-    Map<SkyKey, ValueOrException<ConfiguredValueCreationException>> values =
-        env.getValuesOrThrow(
-            ImmutableList.of(hostPlatformKey, targetPlatformKey),
-            ConfiguredValueCreationException.class);
-    boolean valuesMissing = env.valuesMissing();
-    try {
-      PlatformInfo hostPlatform = findPlatformInfo(values.get(hostPlatformKey), "host platform");
-      PlatformInfo targetPlatform =
-          findPlatformInfo(values.get(targetPlatformKey), "target platform");
-
-      if (valuesMissing) {
-        return null;
-      }
-
-      return PlatformDescriptors.create(hostPlatform, targetPlatform);
-    } catch (ConfiguredValueCreationException e) {
-      throw new ToolchainContextException(e);
+    ConfiguredTargetKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
+    ConfiguredTargetKey targetPlatformKey =
+        ConfiguredTargetKey.of(targetPlatformLabel, configuration);
+    Map<ConfiguredTargetKey, PlatformInfo> platformResult =
+        getPlatformInfo(targetPlatformKey, ImmutableList.of(hostPlatformKey), env);
+    if (env.valuesMissing()) {
+      return null;
     }
+
+    return PlatformDescriptors.create(
+        platformResult.get(hostPlatformKey),
+        platformResult.get(targetPlatformKey),
+        hostPlatformKey,
+        targetPlatformKey);
   }
 
   @Nullable
   private static ImmutableBiMap<Label, Label> resolveToolchainLabels(
       Environment env,
       Set<Label> requiredToolchains,
-      PlatformInfo executionPlatform,
-      PlatformInfo targetPlatform,
-      BuildConfigurationValue.Key configurationKey)
+      BuildConfigurationValue.Key configurationKey,
+      ConfiguredTargetKey executionPlatformKey,
+      ConfiguredTargetKey targetPlatformKey)
       throws InterruptedException, ToolchainContextException {
 
     // If there are no required toolchains, bail out early.
@@ -191,8 +228,8 @@
           ToolchainResolutionValue.key(
               configurationKey,
               toolchainType,
-              targetPlatform,
-              ImmutableList.of(executionPlatform)));
+              targetPlatformKey,
+              ImmutableList.of(executionPlatformKey)));
     }
 
     Map<
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
index 35d0705..4f8b3bc 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
@@ -16,12 +16,18 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.devtools.build.skyframe.EvaluationResultSubjectFactory.assertThatEvaluationResult;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.testing.EqualsTester;
+import com.google.devtools.build.lib.actions.Actions;
+import com.google.devtools.build.lib.analysis.ConfiguredTarget;
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
 import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
+import com.google.devtools.build.lib.collect.nestedset.Order;
 import com.google.devtools.build.lib.rules.platform.ToolchainTestCase;
 import com.google.devtools.build.lib.skyframe.util.SkyframeExecutorTestUtils;
 import com.google.devtools.build.skyframe.EvaluationResult;
@@ -33,9 +39,34 @@
 /** Tests for {@link ToolchainResolutionValue} and {@link ToolchainResolutionFunction}. */
 @RunWith(JUnit4.class)
 public class ToolchainResolutionFunctionTest extends ToolchainTestCase {
+  private static final ConfiguredTargetKey LINUX_CTKEY =
+      ConfiguredTargetKey.of(Label.parseAbsoluteUnchecked("//linux:key"), null, false);
+  private static final ConfiguredTargetKey MAC_CTKEY =
+      ConfiguredTargetKey.of(Label.parseAbsoluteUnchecked("//mac:key"), null, false);
+
+  private static ConfiguredTargetValue createConfiguredTargetValue(
+      ConfiguredTarget configuredTarget) {
+    return new ConfiguredTargetValue(
+        configuredTarget,
+        new Actions.GeneratingActions(ImmutableList.of(), ImmutableMap.of()),
+        NestedSetBuilder.emptySet(Order.STABLE_ORDER),
+        /*removeActionsAfterEvaluation=*/ false);
+  }
 
   private EvaluationResult<ToolchainResolutionValue> invokeToolchainResolution(SkyKey key)
       throws InterruptedException {
+    ConfiguredTarget mockLinuxTarget = mock(ConfiguredTarget.class);
+    when(mockLinuxTarget.get(PlatformInfo.SKYLARK_CONSTRUCTOR)).thenReturn(linuxPlatform);
+    ConfiguredTarget mockMacTarget = mock(ConfiguredTarget.class);
+    when(mockMacTarget.get(PlatformInfo.SKYLARK_CONSTRUCTOR)).thenReturn(macPlatform);
+    getSkyframeExecutor()
+        .getDifferencerForTesting()
+        .inject(
+            ImmutableMap.of(
+                LINUX_CTKEY,
+                createConfiguredTargetValue(mockLinuxTarget),
+                MAC_CTKEY,
+                createConfiguredTargetValue(mockMacTarget)));
 
     try {
       getSkyframeExecutor().getSkyframeBuildView().enableAnalysis(true);
@@ -50,7 +81,7 @@
   public void testResolution_singleExecutionPlatform() throws Exception {
     SkyKey key =
         ToolchainResolutionValue.key(
-            targetConfigKey, testToolchainType, linuxPlatform, ImmutableList.of(macPlatform));
+            targetConfigKey, testToolchainType, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
 
     assertThatEvaluationResult(result).hasNoError();
@@ -78,8 +109,8 @@
         ToolchainResolutionValue.key(
             targetConfigKey,
             testToolchainType,
-            linuxPlatform,
-            ImmutableList.of(linuxPlatform, macPlatform));
+            LINUX_CTKEY,
+            ImmutableList.of(LINUX_CTKEY, MAC_CTKEY));
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
 
     assertThatEvaluationResult(result).hasNoError();
@@ -100,7 +131,7 @@
 
     SkyKey key =
         ToolchainResolutionValue.key(
-            targetConfigKey, testToolchainType, linuxPlatform, ImmutableList.of(macPlatform));
+            targetConfigKey, testToolchainType, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
 
     assertThatEvaluationResult(result)