Update ToolchainUtil to properly load and use the available execution
platforms, and to correctly merge the results returned by
ToolchainResolutionFunction (TRF).

Part of #4442.

Change-Id: I31d83fa73a93d39a0e18d05a43a1c8666ac5a2d2
PiperOrigin-RevId: 187324257
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
index cf340d4..3b84c81 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
@@ -14,22 +14,27 @@
 
 package com.google.devtools.build.lib.skyframe;
 
+import static java.util.stream.Collectors.joining;
+
 import com.google.auto.value.AutoValue;
 import com.google.common.base.Joiner;
+import com.google.common.base.Optional;
+import com.google.common.collect.HashBasedTable;
 import com.google.common.collect.ImmutableBiMap;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
+import com.google.common.collect.Table;
 import com.google.devtools.build.lib.analysis.ConfiguredTarget;
 import com.google.devtools.build.lib.analysis.PlatformConfiguration;
+import com.google.devtools.build.lib.analysis.PlatformOptions;
 import com.google.devtools.build.lib.analysis.ToolchainContext;
 import com.google.devtools.build.lib.analysis.config.BuildConfiguration;
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
 import com.google.devtools.build.lib.analysis.platform.PlatformProviderUtils;
 import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.events.Event;
 import com.google.devtools.build.lib.skyframe.ConfiguredTargetFunction.ConfiguredValueCreationException;
 import com.google.devtools.build.lib.skyframe.RegisteredToolchainsFunction.InvalidToolchainLabelException;
 import com.google.devtools.build.lib.skyframe.ToolchainResolutionFunction.NoToolchainFoundException;
-import com.google.devtools.build.lib.skyframe.ToolchainResolutionValue.ToolchainResolutionKey;
 import com.google.devtools.build.lib.syntax.EvalException;
 import com.google.devtools.build.skyframe.SkyFunction.Environment;
 import com.google.devtools.build.skyframe.SkyKey;
@@ -56,8 +61,7 @@
       Environment env,
       String targetDescription,
       Set<Label> requiredToolchains,
-      @Nullable BuildConfiguration configuration,
-      BuildConfigurationValue.Key configurationKey)
+      @Nullable BuildConfigurationValue.Key configurationKey)
       throws ToolchainContextException, InterruptedException {
 
     // In some cases this is called with a missing configuration, so we skip toolchain context.
@@ -65,107 +69,105 @@
       return null;
     }
 
-    // TODO(katre): Load several possible execution platforms, and select one based on available
-    // toolchains.
+    // This call could be combined with the call below, but this SkyFunction is evaluated so rarely
+    // it's not worth optimizing.
+    BuildConfigurationValue value = (BuildConfigurationValue) env.getValue(configurationKey);
+    if (env.valuesMissing()) {
+      return null;
+    }
+    BuildConfiguration configuration = value.getConfiguration();
 
-    // Load the host and target platforms for the current configuration.
-    PlatformDescriptors platforms = loadPlatformDescriptors(env, configuration);
-    if (platforms == null) {
+    // Load the target and host platform keys.
+    PlatformConfiguration platformConfiguration =
+        configuration.getFragment(PlatformConfiguration.class);
+    if (platformConfiguration == null) {
+      return null;
+    }
+    Label hostPlatformLabel = platformConfiguration.getHostPlatform();
+    Label targetPlatformLabel = platformConfiguration.getTargetPlatforms().get(0);
+
+    ConfiguredTargetKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
+    ConfiguredTargetKey targetPlatformKey =
+        ConfiguredTargetKey.of(targetPlatformLabel, configuration);
+
+    // Load the host and target platforms early, to check for errors.
+    getPlatformInfo(ImmutableList.of(hostPlatformKey, targetPlatformKey), env);
+
+    // Load all available execution platform keys. This will find any errors in the execution
+    // platform definitions.
+    RegisteredExecutionPlatformsValue registeredExecutionPlatforms =
+        loadRegisteredExecutionPlatforms(env, configurationKey);
+    if (registeredExecutionPlatforms == null) {
       return null;
     }
 
-    // TODO(katre): This will change with remote execution.
-    PlatformInfo executionPlatform = platforms.hostPlatform();
-    PlatformInfo targetPlatform = platforms.targetPlatform();
-
-    ImmutableBiMap<Label, Label> resolvedLabels =
+    ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys =
+        new ImmutableList.Builder<ConfiguredTargetKey>()
+            .addAll(registeredExecutionPlatforms.registeredExecutionPlatformKeys())
+            .add(hostPlatformKey)
+            .build();
+    Optional<ResolvedToolchains> resolvedToolchains =
         resolveToolchainLabels(
             env,
             requiredToolchains,
+            configuration,
             configurationKey,
-            platforms.hostPlatformKey(),
-            platforms.targetPlatformKey());
-    if (resolvedLabels == null) {
+            availableExecutionPlatformKeys,
+            targetPlatformKey);
+    if (resolvedToolchains == null) {
       return null;
     }
 
-    ToolchainContext toolchainContext =
-        ToolchainContext.create(
-            targetDescription,
-            executionPlatform,
-            targetPlatform,
-            requiredToolchains,
-            resolvedLabels);
-    return toolchainContext;
-  }
-
-  /**
-   * Data class to hold platform descriptors loaded based on the current {@link BuildConfiguration}.
-   */
-  @AutoValue
-  protected abstract static class PlatformDescriptors {
-    abstract PlatformInfo hostPlatform();
-
-    abstract PlatformInfo targetPlatform();
-
-    abstract ConfiguredTargetKey hostPlatformKey();
-
-    abstract ConfiguredTargetKey targetPlatformKey();
-
-    protected static PlatformDescriptors create(
-        PlatformInfo hostPlatform,
-        PlatformInfo targetPlatform,
-        ConfiguredTargetKey hostPlatformKey,
-        ConfiguredTargetKey targetPlatformKey) {
-      return new AutoValue_ToolchainUtil_PlatformDescriptors(
-          hostPlatform, targetPlatform, hostPlatformKey, targetPlatformKey);
+    if (resolvedToolchains.isPresent()) {
+      return createContext(
+          env,
+          targetDescription,
+          resolvedToolchains.get().executionPlatformKey(),
+          resolvedToolchains.get().targetPlatformKey(),
+          requiredToolchains,
+          resolvedToolchains.get().toolchains());
+    } else {
+      // No toolchain could be resolved, but no error happened, so fall back to host platform.
+      return createContext(
+          env,
+          targetDescription,
+          hostPlatformKey,
+          targetPlatformKey,
+          requiredToolchains,
+          ImmutableBiMap.of());
     }
   }
 
-  /**
-   * Returns the {@link PlatformInfo} provider from the {@link ConfiguredTarget} in the {@link
-   * ValueOrException}, or {@code null} if the {@link ConfiguredTarget} is not present. If the
-   * {@link ConfiguredTarget} does not have a {@link PlatformInfo} provider, a {@link
-   * InvalidPlatformException} is thrown, wrapped in a {@link ToolchainContextException}.
-   */
-  @Nullable
-  private static PlatformInfo findPlatformInfo(
-      ValueOrException<ConfiguredValueCreationException> valueOrException, String platformType)
-      throws ConfiguredValueCreationException, ToolchainContextException {
-
-    ConfiguredTargetValue ctv = (ConfiguredTargetValue) valueOrException.get();
-    if (ctv == null) {
-      return null;
+  private static RegisteredExecutionPlatformsValue loadRegisteredExecutionPlatforms(
+      Environment env, BuildConfigurationValue.Key configurationKey)
+      throws InterruptedException, ToolchainContextException {
+    try {
+      RegisteredExecutionPlatformsValue registeredExecutionPlatforms =
+          (RegisteredExecutionPlatformsValue)
+              env.getValueOrThrow(
+                  RegisteredExecutionPlatformsValue.key(configurationKey),
+                  InvalidPlatformException.class);
+      if (registeredExecutionPlatforms == null) {
+        return null;
+      }
+      return registeredExecutionPlatforms;
+    } catch (InvalidPlatformException e) {
+      throw new ToolchainContextException(e);
     }
-
-    ConfiguredTarget configuredTarget = ctv.getConfiguredTarget();
-    PlatformInfo platformInfo = PlatformProviderUtils.platform(configuredTarget);
-    if (platformInfo == null) {
-      throw new ToolchainContextException(
-          new InvalidPlatformException(platformType, configuredTarget.getLabel()));
-    }
-
-    return platformInfo;
   }
 
   @Nullable
   static Map<ConfiguredTargetKey, PlatformInfo> getPlatformInfo(
-      ConfiguredTargetKey targetPlatformKey,
-      Iterable<ConfiguredTargetKey> hostPlatformKeys,
-      Environment env)
+      Iterable<ConfiguredTargetKey> platformKeys, Environment env)
       throws InterruptedException, ToolchainContextException {
-    Iterable<ConfiguredTargetKey> allKeys =
-        Iterables.concat(ImmutableList.of(targetPlatformKey), hostPlatformKeys);
+
     Map<SkyKey, ValueOrException<ConfiguredValueCreationException>> values =
-        env.getValuesOrThrow(allKeys, ConfiguredValueCreationException.class);
+        env.getValuesOrThrow(platformKeys, ConfiguredValueCreationException.class);
     boolean valuesMissing = env.valuesMissing();
     Map<ConfiguredTargetKey, PlatformInfo> platforms = valuesMissing ? null : new HashMap<>();
     try {
-      for (ConfiguredTargetKey key : allKeys) {
-        PlatformInfo platformInfo =
-            findPlatformInfo(
-                values.get(key),
-                key.equals(targetPlatformKey) ? "target platform" : "host platform");
+      for (ConfiguredTargetKey key : platformKeys) {
+        PlatformInfo platformInfo = findPlatformInfo(values.get(key));
         if (!valuesMissing && platformInfo != null) {
           platforms.put(key, platformInfo);
         }
@@ -179,57 +181,72 @@
     return platforms;
   }
 
+  /**
+   * Returns the {@link PlatformInfo} provider from the {@link ConfiguredTarget} in the {@link
+   * ValueOrException}, or {@code null} if the {@link ConfiguredTarget} is not present. If the
+   * {@link ConfiguredTarget} does not have a {@link PlatformInfo} provider, a {@link
+   * InvalidPlatformException} is thrown, wrapped in a {@link ToolchainContextException}.
+   */
   @Nullable
-  private static PlatformDescriptors loadPlatformDescriptors(
-      Environment env, BuildConfiguration configuration)
-      throws InterruptedException, ToolchainContextException {
-    PlatformConfiguration platformConfiguration =
-        configuration.getFragment(PlatformConfiguration.class);
-    if (platformConfiguration == null) {
-      return null;
-    }
-    Label hostPlatformLabel = platformConfiguration.getHostPlatform();
-    Label targetPlatformLabel = platformConfiguration.getTargetPlatforms().get(0);
+  private static PlatformInfo findPlatformInfo(
+      ValueOrException<ConfiguredValueCreationException> valueOrException)
+      throws ConfiguredValueCreationException, ToolchainContextException {
 
-    ConfiguredTargetKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
-    ConfiguredTargetKey targetPlatformKey =
-        ConfiguredTargetKey.of(targetPlatformLabel, configuration);
-    Map<ConfiguredTargetKey, PlatformInfo> platformResult =
-        getPlatformInfo(targetPlatformKey, ImmutableList.of(hostPlatformKey), env);
-    if (env.valuesMissing()) {
+    ConfiguredTargetValue ctv = (ConfiguredTargetValue) valueOrException.get();
+    if (ctv == null) {
       return null;
     }
 
-    return PlatformDescriptors.create(
-        platformResult.get(hostPlatformKey),
-        platformResult.get(targetPlatformKey),
-        hostPlatformKey,
-        targetPlatformKey);
+    ConfiguredTarget configuredTarget = ctv.getConfiguredTarget();
+    PlatformInfo platformInfo = PlatformProviderUtils.platform(configuredTarget);
+    if (platformInfo == null) {
+      throw new ToolchainContextException(
+          new InvalidPlatformException(configuredTarget.getLabel()));
+    }
+
+    return platformInfo;
+  }
+
+  /** Data class to hold the result of resolving toolchain labels. */
+  @AutoValue
+  protected abstract static class ResolvedToolchains {
+
+    abstract ConfiguredTargetKey executionPlatformKey();
+
+    abstract ConfiguredTargetKey targetPlatformKey();
+
+    abstract ImmutableBiMap<Label, Label> toolchains();
+
+    protected static ResolvedToolchains create(
+        ConfiguredTargetKey executionPlatformKey,
+        ConfiguredTargetKey targetPlatformKey,
+        Map<Label, Label> toolchains) {
+      return new AutoValue_ToolchainUtil_ResolvedToolchains(
+          executionPlatformKey, targetPlatformKey, ImmutableBiMap.copyOf(toolchains));
+    }
   }
 
   @Nullable
-  private static ImmutableBiMap<Label, Label> resolveToolchainLabels(
+  private static Optional<ResolvedToolchains> resolveToolchainLabels(
       Environment env,
       Set<Label> requiredToolchains,
+      BuildConfiguration configuration,
       BuildConfigurationValue.Key configurationKey,
-      ConfiguredTargetKey executionPlatformKey,
+      ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys,
       ConfiguredTargetKey targetPlatformKey)
       throws InterruptedException, ToolchainContextException {
 
     // If there are no required toolchains, bail out early.
     if (requiredToolchains.isEmpty()) {
-      return ImmutableBiMap.of();
+      return Optional.absent();
     }
 
     // Find the toolchains for the required toolchain types.
-    List<SkyKey> registeredToolchainKeys = new ArrayList<>();
+    List<ToolchainResolutionValue.Key> registeredToolchainKeys = new ArrayList<>();
     for (Label toolchainType : requiredToolchains) {
       registeredToolchainKeys.add(
           ToolchainResolutionValue.key(
-              configurationKey,
-              toolchainType,
-              targetPlatformKey,
-              ImmutableList.of(executionPlatformKey)));
+              configurationKey, toolchainType, targetPlatformKey, availableExecutionPlatformKeys));
     }
 
     Map<
@@ -246,8 +263,8 @@
                 EvalException.class);
     boolean valuesMissing = false;
 
-    // Load the toolchains.
-    ImmutableBiMap.Builder<Label, Label> builder = new ImmutableBiMap.Builder<>();
+    // Determine the potential set of toolchains.
+    Table<ConfiguredTargetKey, Label, Label> resolvedToolchains = HashBasedTable.create();
     List<Label> missingToolchains = new ArrayList<>();
     for (Map.Entry<
             SkyKey,
@@ -257,26 +274,19 @@
         entry : results.entrySet()) {
       try {
         Label requiredToolchainType =
-            ((ToolchainResolutionKey) entry.getKey().argument()).toolchainType();
+            ((ToolchainResolutionValue.Key) entry.getKey().argument()).toolchainType();
         ValueOrException4<
                 NoToolchainFoundException, ConfiguredValueCreationException,
                 InvalidToolchainLabelException, EvalException>
             valueOrException = entry.getValue();
         if (valueOrException.get() == null) {
           valuesMissing = true;
-        } else {
-          ToolchainResolutionValue toolchainResolutionValue =
-              (ToolchainResolutionValue) valueOrException.get();
-
-          // TODO(https://github.com/bazelbuild/bazel/issues/4442): Handle finding the best
-          // execution platform when multiple are available.
-          Label toolchainLabel =
-              Iterables.getFirst(
-                  toolchainResolutionValue.availableToolchainLabels().values(), null);
-          if (toolchainLabel != null) {
-            builder.put(requiredToolchainType, toolchainLabel);
-          }
+          continue;
         }
+
+        ToolchainResolutionValue toolchainResolutionValue =
+            (ToolchainResolutionValue) valueOrException.get();
+        addPlatformsAndLabels(resolvedToolchains, requiredToolchainType, toolchainResolutionValue);
       } catch (NoToolchainFoundException e) {
         // Save the missing type and continue looping to check for more.
         missingToolchains.add(e.missingToolchainType());
@@ -297,16 +307,89 @@
       return null;
     }
 
-    return builder.build();
+    boolean debug = configuration.getOptions().get(PlatformOptions.class).toolchainResolutionDebug;
+
+    // Find and return the first execution platform which has all required toolchains.
+    for (ConfiguredTargetKey executionPlatformKey : availableExecutionPlatformKeys) {
+      // PlatformInfo executionPlatform = platforms.get(executionPlatformKey);
+      Map<Label, Label> toolchains = resolvedToolchains.row(executionPlatformKey);
+      if (!toolchains.keySet().containsAll(requiredToolchains)) {
+        // Not all toolchains are present, keep going
+        continue;
+      }
+
+      if (debug) {
+        env.getListener()
+            .handle(
+                Event.info(
+                    String.format(
+                        "ToolchainUtil: Selected execution platform %s, %s",
+                        executionPlatformKey.getLabel(),
+                        toolchains
+                            .entrySet()
+                            .stream()
+                            .map(
+                                e ->
+                                    String.format(
+                                        "type %s -> toolchain %s", e.getKey(), e.getValue()))
+                            .collect(joining(", ")))));
+      }
+      return Optional.of(
+          ResolvedToolchains.create(executionPlatformKey, targetPlatformKey, toolchains));
+    }
+
+    return Optional.absent();
+  }
+
+  private static void addPlatformsAndLabels(
+      Table<ConfiguredTargetKey, Label, Label> resolvedToolchains,
+      Label requiredToolchainType,
+      ToolchainResolutionValue toolchainResolutionValue) {
+
+    for (Map.Entry<ConfiguredTargetKey, Label> entry :
+        toolchainResolutionValue.availableToolchainLabels().entrySet()) {
+      resolvedToolchains.put(entry.getKey(), requiredToolchainType, entry.getValue());
+    }
+  }
+
+  @Nullable
+  private static ToolchainContext createContext(
+      Environment env,
+      String targetDescription,
+      ConfiguredTargetKey executionPlatformKey,
+      ConfiguredTargetKey targetPlatformKey,
+      Set<Label> requiredToolchains,
+      ImmutableBiMap<Label, Label> toolchains)
+      throws ToolchainContextException, InterruptedException {
+
+    Map<ConfiguredTargetKey, PlatformInfo> platforms =
+        getPlatformInfo(ImmutableList.of(executionPlatformKey, targetPlatformKey), env);
+
+    if (platforms == null) {
+      return null;
+    }
+
+    return ToolchainContext.create(
+        targetDescription,
+        platforms.get(executionPlatformKey),
+        platforms.get(targetPlatformKey),
+        requiredToolchains,
+        toolchains);
   }
 
   /** Exception used when a platform label is not a valid platform. */
   static final class InvalidPlatformException extends Exception {
-    InvalidPlatformException(String platformType, Label label) {
-      super(
-          String.format(
-              "Target %s was found as the %s, but does not provide PlatformInfo",
-              label, platformType));
+    InvalidPlatformException(Label label) {
+      super(formatError(label));
+    }
+
+    InvalidPlatformException(Label label, ConfiguredValueCreationException e) {
+      super(formatError(label), e);
+    }
+
+    private static String formatError(Label label) {
+      return String.format(
+          "Target %s was referenced as a platform, but does not provide PlatformInfo", label);
     }
   }