Remove PlatformInfo from ToolchainResolutionKey and replace it with the ConfiguredTargetKeys of the targets that own the PlatformInfo.

PiperOrigin-RevId: 185770105
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
index 28d5d9a..cf340d4 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
@@ -36,6 +36,7 @@
 import com.google.devtools.build.skyframe.ValueOrException;
 import com.google.devtools.build.skyframe.ValueOrException4;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -51,7 +52,7 @@
    * of the {@link ToolchainResolutionFunction}.
    */
   @Nullable
-  public static ToolchainContext createToolchainContext(
+  static ToolchainContext createToolchainContext(
       Environment env,
       String targetDescription,
       Set<Label> requiredToolchains,
@@ -79,7 +80,11 @@
 
     ImmutableBiMap<Label, Label> resolvedLabels =
         resolveToolchainLabels(
-            env, requiredToolchains, executionPlatform, targetPlatform, configurationKey);
+            env,
+            requiredToolchains,
+            configurationKey,
+            platforms.hostPlatformKey(),
+            platforms.targetPlatformKey());
     if (resolvedLabels == null) {
       return null;
     }
@@ -103,9 +108,23 @@
 
     abstract PlatformInfo targetPlatform();
+
+    /** Key of the configured target that provides the host platform's {@link PlatformInfo}. */
+    abstract ConfiguredTargetKey hostPlatformKey();
+
+    /** Key of the configured target that provides the target platform's {@link PlatformInfo}. */
+    abstract ConfiguredTargetKey targetPlatformKey();

+    /**
+     * Creates a descriptor from the host and target {@link PlatformInfo} and the keys of the
+     * configured targets that provide them.
+     */
     protected static PlatformDescriptors create(
-        PlatformInfo hostPlatform, PlatformInfo targetPlatform) {
-      return new AutoValue_ToolchainUtil_PlatformDescriptors(hostPlatform, targetPlatform);
+        PlatformInfo hostPlatform,
+        PlatformInfo targetPlatform,
+        ConfiguredTargetKey hostPlatformKey,
+        ConfiguredTargetKey targetPlatformKey) {
+      return new AutoValue_ToolchainUtil_PlatformDescriptors(
+          hostPlatform, targetPlatform, hostPlatformKey, targetPlatformKey);
     }
   }
 
@@ -136,6 +149,56 @@
  }

+  /**
+   * Requests the configured targets for the target platform and the given host platforms, and
+   * returns the {@link PlatformInfo} provided by each of them.
+   *
+   * <p>Every key is passed to {@code findPlatformInfo} even when some Skyframe values are still
+   * missing, so a {@link ConfiguredValueCreationException} raised for any platform is rethrown as
+   * a {@link ToolchainContextException} on this call.
+   *
+   * @param targetPlatformKey key of the target platform's configured target
+   * @param hostPlatformKeys keys of the host platforms' configured targets; iterated only once
+   * @param env the Skyframe environment used to request the values
+   * @return a map from each platform key to its {@link PlatformInfo} (keys for which none was
+   *     found are omitted), or {@code null} if any requested Skyframe value is missing
+   * @throws ToolchainContextException if configuring one of the platform targets failed
+   */
   @Nullable
+  static Map<ConfiguredTargetKey, PlatformInfo> getPlatformInfo(
+      ConfiguredTargetKey targetPlatformKey,
+      Iterable<ConfiguredTargetKey> hostPlatformKeys,
+      Environment env)
+      throws InterruptedException, ToolchainContextException {
+    // Materialize the keys once: they are iterated twice below, and a lazily concatenated view
+    // would silently drop the host keys if hostPlatformKeys could only be iterated once.
+    ImmutableList<ConfiguredTargetKey> allKeys =
+        ImmutableList.<ConfiguredTargetKey>builder()
+            .add(targetPlatformKey)
+            .addAll(hostPlatformKeys)
+            .build();
+    Map<SkyKey, ValueOrException<ConfiguredValueCreationException>> values =
+        env.getValuesOrThrow(allKeys, ConfiguredValueCreationException.class);
+    boolean valuesMissing = env.valuesMissing();
+
+    Map<ConfiguredTargetKey, PlatformInfo> platforms = new HashMap<>();
+    try {
+      // Check every key even if values are missing, so configuration errors still surface.
+      for (ConfiguredTargetKey key : allKeys) {
+        PlatformInfo platformInfo =
+            findPlatformInfo(
+                values.get(key),
+                key.equals(targetPlatformKey) ? "target platform" : "host platform");
+        if (platformInfo != null) {
+          platforms.put(key, platformInfo);
+        }
+      }
+    } catch (ConfiguredValueCreationException e) {
+      throw new ToolchainContextException(e);
+    }
+    return valuesMissing ? null : platforms;
+  }
+
+  @Nullable
   private static PlatformDescriptors loadPlatformDescriptors(
       Environment env, BuildConfiguration configuration)
       throws InterruptedException, ToolchainContextException {
@@ -147,36 +191,29 @@
     Label hostPlatformLabel = platformConfiguration.getHostPlatform();
     Label targetPlatformLabel = platformConfiguration.getTargetPlatforms().get(0);
 
-    SkyKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
-    SkyKey targetPlatformKey = ConfiguredTargetKey.of(targetPlatformLabel, configuration);
-
-    Map<SkyKey, ValueOrException<ConfiguredValueCreationException>> values =
-        env.getValuesOrThrow(
-            ImmutableList.of(hostPlatformKey, targetPlatformKey),
-            ConfiguredValueCreationException.class);
-    boolean valuesMissing = env.valuesMissing();
-    try {
-      PlatformInfo hostPlatform = findPlatformInfo(values.get(hostPlatformKey), "host platform");
-      PlatformInfo targetPlatform =
-          findPlatformInfo(values.get(targetPlatformKey), "target platform");
-
-      if (valuesMissing) {
-        return null;
-      }
-
-      return PlatformDescriptors.create(hostPlatform, targetPlatform);
-    } catch (ConfiguredValueCreationException e) {
-      throw new ToolchainContextException(e);
+    ConfiguredTargetKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
+    ConfiguredTargetKey targetPlatformKey =
+        ConfiguredTargetKey.of(targetPlatformLabel, configuration);
+    Map<ConfiguredTargetKey, PlatformInfo> platformResult =
+        getPlatformInfo(targetPlatformKey, ImmutableList.of(hostPlatformKey), env);
+    if (env.valuesMissing()) {
+      return null;
     }
+
+    return PlatformDescriptors.create(
+        platformResult.get(hostPlatformKey),
+        platformResult.get(targetPlatformKey),
+        hostPlatformKey,
+        targetPlatformKey);
   }
 
   @Nullable
   private static ImmutableBiMap<Label, Label> resolveToolchainLabels(
       Environment env,
       Set<Label> requiredToolchains,
-      PlatformInfo executionPlatform,
-      PlatformInfo targetPlatform,
-      BuildConfigurationValue.Key configurationKey)
+      BuildConfigurationValue.Key configurationKey,
+      ConfiguredTargetKey executionPlatformKey,
+      ConfiguredTargetKey targetPlatformKey)
       throws InterruptedException, ToolchainContextException {
 
     // If there are no required toolchains, bail out early.
@@ -191,8 +228,8 @@
           ToolchainResolutionValue.key(
               configurationKey,
               toolchainType,
-              targetPlatform,
-              ImmutableList.of(executionPlatform)));
+              targetPlatformKey,
+              ImmutableList.of(executionPlatformKey)));
     }
 
     Map<