Move `isPlatformSuitable` check into PlatformKeys.

PiperOrigin-RevId: 690606178
Change-Id: I4e3e83ac8399b75a27508a782d5fad92bd190281
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/PlatformKeys.java b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/PlatformKeys.java
index 931c15b..dc70328 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/PlatformKeys.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/PlatformKeys.java
@@ -14,13 +14,17 @@
 package com.google.devtools.build.lib.skyframe.toolchains;
 
 import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Table;
 import com.google.devtools.build.lib.analysis.PlatformConfiguration;
 import com.google.devtools.build.lib.analysis.config.CommonOptions;
 import com.google.devtools.build.lib.analysis.platform.ConstraintValueInfo;
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.events.Event;
 import com.google.devtools.build.lib.skyframe.ConfiguredTargetKey;
@@ -40,7 +44,8 @@
 /** Details of platforms used during toolchain resolution. */
 record PlatformKeys(
     ConfiguredTargetKey targetPlatformKey,
-    ImmutableList<ConfiguredTargetKey> executionPlatformKeys) {
+    ImmutableList<ConfiguredTargetKey> executionPlatformKeys,
+    ImmutableMap<ConfiguredTargetKey, PlatformInfo> platformInfos) {
 
   private static class Builder {
     // Input data.
@@ -108,7 +113,8 @@
       ImmutableList<ConfiguredTargetKey> executionPlatformKeys =
           filterExecutionPlatforms(execConstraintLabels);
 
-      return new PlatformKeys(resolvedTargetPlatformKey, executionPlatformKeys);
+      return new PlatformKeys(
+          resolvedTargetPlatformKey, executionPlatformKeys, ImmutableMap.copyOf(platformInfos));
     }
 
     private void findExecutionPlatformKeys()
@@ -264,4 +270,34 @@
 
     return null;
   }
+
+  @Nullable
+  PlatformInfo platformInfo(ConfiguredTargetKey configuredTargetKey) {
+    return platformInfos.get(configuredTargetKey);
+  }
+
+  public boolean isPlatformSuitable(
+      ConfiguredTargetKey executionPlatformKey,
+      ImmutableSet<ToolchainResolutionFunction.ToolchainType> toolchainTypes,
+      Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains) {
+    PlatformInfo executionPlatformInfo = platformInfo(executionPlatformKey);
+    if (executionPlatformInfo.checkToolchainTypes() && toolchainTypes.isEmpty()) {
+      // This can't be suitable.
+      return false;
+    } else if (toolchainTypes.isEmpty()) {
+      // Since there aren't any toolchains, we should be able to use any execution platform that
+      // has made it this far.
+      return true;
+    }
+
+    // Determine whether all mandatory toolchains are present.
+    return resolvedToolchains
+        .row(executionPlatformKey)
+        .keySet()
+        .containsAll(
+            toolchainTypes.stream()
+                .filter(ToolchainResolutionFunction.ToolchainType::mandatory)
+                .map(ToolchainResolutionFunction.ToolchainType::toolchainTypeInfo)
+                .collect(toImmutableSet()));
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunction.java
index ccef6ef..c536099 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunction.java
@@ -144,10 +144,10 @@
     }
   }
 
-  private record ToolchainType(
+  record ToolchainType(
       ToolchainTypeRequirement toolchainTypeRequirement, ToolchainTypeInfo toolchainTypeInfo) {
 
-    private ToolchainType {
+    ToolchainType {
       Objects.requireNonNull(toolchainTypeRequirement, "toolchainTypeRequirement");
       Objects.requireNonNull(toolchainTypeInfo, "toolchainTypeInfo");
     }
@@ -274,10 +274,7 @@
     // Find and return the first execution platform which has all mandatory toolchains.
     Optional<ConfiguredTargetKey> selectedExecutionPlatformKey =
         findExecutionPlatformForToolchains(
-            toolchainTypes,
-            forcedExecutionPlatform,
-            platformKeys.executionPlatformKeys(),
-            resolvedToolchains);
+            toolchainTypes, forcedExecutionPlatform, platformKeys, resolvedToolchains);
 
     ImmutableSet<ToolchainTypeRequirement> toolchainTypeRequirements =
         toolchainTypes.stream()
@@ -337,19 +334,21 @@
   private static Optional<ConfiguredTargetKey> findExecutionPlatformForToolchains(
       ImmutableSet<ToolchainType> toolchainTypes,
       Optional<ConfiguredTargetKey> forcedExecutionPlatform,
-      ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys,
+      PlatformKeys platformKeys,
       Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains) {
 
     if (forcedExecutionPlatform.isPresent()) {
       // Is the forced platform suitable?
-      if (isPlatformSuitable(forcedExecutionPlatform.get(), toolchainTypes, resolvedToolchains)) {
+      if (platformKeys.isPlatformSuitable(
+          forcedExecutionPlatform.get(), toolchainTypes, resolvedToolchains)) {
         return forcedExecutionPlatform;
       }
     }
 
     var candidatePlatforms =
-        availableExecutionPlatformKeys.stream()
-            .filter(epk -> isPlatformSuitable(epk, toolchainTypes, resolvedToolchains));
+        platformKeys.executionPlatformKeys().stream()
+            .filter(
+                epk -> platformKeys.isPlatformSuitable(epk, toolchainTypes, resolvedToolchains));
 
     var toolchainTypeInfos =
         toolchainTypes.stream().map(ToolchainType::toolchainTypeInfo).collect(toImmutableSet());
@@ -360,27 +359,6 @@
             epk -> countToolchainsOnPlatform(epk, toolchainTypeInfos, resolvedToolchains)));
   }
 
-  private static boolean isPlatformSuitable(
-      ConfiguredTargetKey executionPlatformKey,
-      ImmutableSet<ToolchainType> toolchainTypes,
-      Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains) {
-    if (toolchainTypes.isEmpty()) {
-      // Since there aren't any toolchains, we should be able to use any execution platform that
-      // has made it this far.
-      return true;
-    }
-
-    // Determine whether all mandatory toolchains are present.
-    return resolvedToolchains
-        .row(executionPlatformKey)
-        .keySet()
-        .containsAll(
-            toolchainTypes.stream()
-                .filter(ToolchainType::mandatory)
-                .map(ToolchainType::toolchainTypeInfo)
-                .collect(toImmutableSet()));
-  }
-
   private static long countToolchainsOnPlatform(
       ConfiguredTargetKey executionPlatformKey,
       ImmutableSet<ToolchainTypeInfo> toolchainTypeInfos,
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunctionTest.java
index 71a925a..e6eceea 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/ToolchainResolutionFunctionTest.java
@@ -513,6 +513,41 @@
   }
 
   @Test
+  public void resolve_noToolchainType_checkPlatformAllowedToolchains() throws Exception {
+    // Define two new execution platforms, only one of which is compatible with the test toolchain.
+    scratch.file(
+        "allowed/BUILD",
+        """
+        platform(
+            name = "fails_match",
+            check_toolchain_types = True,
+            allowed_toolchain_types = [
+                # Empty, so doesn't match anything.
+            ],
+        )
+
+        platform(
+            name = "allows_all",
+        )
+        """);
+    rewriteModuleDotBazel(
+        "register_execution_platforms('//allowed:fails_match', '//allowed:allows_all')");
+
+    useConfiguration("--host_platform=//allowed:fails_match");
+    ToolchainContextKey key = ToolchainContextKey.key().configurationKey(targetConfigKey).build();
+
+    EvaluationResult<UnloadedToolchainContext> result = invokeToolchainResolution(key);
+
+    assertThatEvaluationResult(result).hasNoError();
+    UnloadedToolchainContext unloadedToolchainContext = result.get(key);
+    assertThat(unloadedToolchainContext).isNotNull();
+
+    assertThat(unloadedToolchainContext.toolchainTypes()).isEmpty();
+    // Even with no toolchains requested, should still select the first execution platform.
+    assertThat(unloadedToolchainContext).hasExecutionPlatform("//allowed:allows_all");
+  }
+
+  @Test
   public void resolve_noToolchainType_hostNotAvailable() throws Exception {
     scratch.file("host/BUILD", "platform(name = 'host')");
     scratch.file(