Use toolchain type provider in toolchain resolution

Part of #6015.

Closes #6616.

PiperOrigin-RevId: 220465567
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/ToolchainContext.java b/src/main/java/com/google/devtools/build/lib/analysis/ToolchainContext.java
index 35c7005..368b53f 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/ToolchainContext.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/ToolchainContext.java
@@ -22,6 +22,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
 import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.cmdline.LabelSyntaxException;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.Immutable;
@@ -31,6 +32,7 @@
 import com.google.devtools.build.lib.skylarkinterface.SkylarkPrinter;
 import com.google.devtools.build.lib.syntax.EvalException;
 import com.google.devtools.build.lib.syntax.EvalUtils;
+import java.util.Optional;
 import java.util.Set;
 import javax.annotation.Nullable;
 
@@ -57,10 +59,10 @@
     Builder setTargetPlatform(PlatformInfo targetPlatform);
 
     /** Sets the toolchain types that were requested. */
-    Builder setRequiredToolchainTypes(Set<Label> requiredToolchainTypes);
+    Builder setRequiredToolchainTypes(Set<ToolchainTypeInfo> requiredToolchainTypes);
 
     /** Sets the map from toolchain type to toolchain provider. */
-    Builder setToolchains(ImmutableMap<Label, ToolchainInfo> toolchains);
+    Builder setToolchains(ImmutableMap<ToolchainTypeInfo, ToolchainInfo> toolchains);
 
     /** Sets the template variables that these toolchains provide. */
     Builder setTemplateVariableProviders(ImmutableList<TemplateVariableInfo> providers);
@@ -82,9 +84,9 @@
   public abstract PlatformInfo targetPlatform();
 
   /** Returns the toolchain types that were requested. */
-  public abstract ImmutableSet<Label> requiredToolchainTypes();
+  public abstract ImmutableSet<ToolchainTypeInfo> requiredToolchainTypes();
 
-  abstract ImmutableMap<Label, ToolchainInfo> toolchains();
+  abstract ImmutableMap<ToolchainTypeInfo, ToolchainInfo> toolchains();
 
   /** Returns the template variables that these toolchains provide. */
   public abstract ImmutableList<TemplateVariableInfo> templateVariableProviders();
@@ -97,7 +99,20 @@
    * required in this context.
    */
   @Nullable
-  public ToolchainInfo forToolchainType(Label toolchainType) {
+  public ToolchainInfo forToolchainType(Label toolchainTypeLabel) {
+    Optional<ToolchainTypeInfo> toolchainType =
+        toolchains().keySet().stream()
+            .filter(info -> info.typeLabel().equals(toolchainTypeLabel))
+            .findFirst();
+    if (toolchainType.isPresent()) {
+      return forToolchainType(toolchainType.get());
+    } else {
+      return null;
+    }
+  }
+
+  @Nullable
+  public ToolchainInfo forToolchainType(ToolchainTypeInfo toolchainType) {
     return toolchains().get(toolchainType);
   }
 
@@ -109,13 +124,19 @@
   @Override
   public void repr(SkylarkPrinter printer) {
     printer.append("<toolchain_context.resolved_labels: ");
-    printer.append(toolchains().keySet().stream().map(Label::toString).collect(joining(", ")));
+    printer.append(
+        toolchains().keySet().stream()
+            .map(ToolchainTypeInfo::typeLabel)
+            .map(Label::toString)
+            .collect(joining(", ")));
     printer.append(">");
   }
 
   private Label transformKey(Object key, Location loc) throws EvalException {
     if (key instanceof Label) {
       return (Label) key;
+    } else if (key instanceof ToolchainTypeInfo) {
+      return ((ToolchainTypeInfo) key).typeLabel();
     } else if (key instanceof String) {
       Label toolchainType;
       String rawLabel = (String) key;
@@ -137,23 +158,31 @@
 
   @Override
   public ToolchainInfo getIndex(Object key, Location loc) throws EvalException {
-    Label toolchainType = transformKey(key, loc);
+    Label toolchainTypeLabel = transformKey(key, loc);
 
-    if (!requiredToolchainTypes().contains(toolchainType)) {
+    if (!containsKey(key, loc)) {
       throw new EvalException(
           loc,
           String.format(
               "In %s, toolchain type %s was requested but only types [%s] are configured",
               targetDescription(),
-              toolchainType,
-              requiredToolchainTypes().stream().map(Label::toString).collect(joining())));
+              toolchainTypeLabel,
+              requiredToolchainTypes().stream()
+                  .map(ToolchainTypeInfo::typeLabel)
+                  .map(Label::toString)
+                  .collect(joining(", "))));
     }
-    return forToolchainType(toolchainType);
+    return forToolchainType(toolchainTypeLabel);
   }
 
   @Override
   public boolean containsKey(Object key, Location loc) throws EvalException {
-    Label toolchainType = transformKey(key, loc);
-    return toolchains().containsKey(toolchainType);
+    Label toolchainTypeLabel = transformKey(key, loc);
+    Optional<Label> matching =
+        toolchains().keySet().stream()
+            .map(ToolchainTypeInfo::typeLabel)
+            .filter(label -> label.equals(toolchainTypeLabel))
+            .findAny();
+    return matching.isPresent();
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/ToolchainResolver.java b/src/main/java/com/google/devtools/build/lib/analysis/ToolchainResolver.java
index c2f3819..e8c5c5d 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/ToolchainResolver.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/ToolchainResolver.java
@@ -19,7 +19,6 @@
 import static java.util.stream.Collectors.joining;
 
 import com.google.auto.value.AutoValue;
-import com.google.common.base.Joiner;
 import com.google.common.collect.HashBasedTable;
 import com.google.common.collect.ImmutableBiMap;
 import com.google.common.collect.ImmutableList;
@@ -31,6 +30,7 @@
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
 import com.google.devtools.build.lib.analysis.platform.PlatformProviderUtils;
 import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.events.Event;
 import com.google.devtools.build.lib.packages.Attribute;
@@ -71,7 +71,7 @@
 
   // Optional data.
   private String targetDescription = "";
-  private ImmutableSet<Label> requiredToolchainTypes = ImmutableSet.of();
+  private ImmutableSet<Label> requiredToolchainTypeLabels = ImmutableSet.of();
   private ImmutableSet<Label> execConstraintLabels = ImmutableSet.of();
 
   // Determined during execution.
@@ -98,9 +98,12 @@
     return this;
   }
 
-  /** Sets the required toolchain types that this resolver needs to find toolchains for. */
-  public ToolchainResolver setRequiredToolchainTypes(Set<Label> requiredToolchainTypes) {
-    this.requiredToolchainTypes = ImmutableSet.copyOf(requiredToolchainTypes);
+  /**
+   * Sets the labels of the required toolchain types that this resolver needs to find toolchains
+   * for.
+   */
+  public ToolchainResolver setRequiredToolchainTypes(Set<Label> requiredToolchainTypeLabels) {
+    this.requiredToolchainTypeLabels = ImmutableSet.copyOf(requiredToolchainTypeLabels);
     return this;
   }
 
@@ -137,9 +140,7 @@
     try {
       UnloadedToolchainContext.Builder unloadedToolchainContext =
           UnloadedToolchainContext.builder();
-      unloadedToolchainContext
-          .setTargetDescription(targetDescription)
-          .setRequiredToolchainTypes(requiredToolchainTypes);
+      unloadedToolchainContext.setTargetDescription(targetDescription);
 
       // Determine the configuration being used.
       BuildConfigurationValue value =
@@ -307,11 +308,11 @@
 
     // Find the toolchains for the required toolchain types.
     List<ToolchainResolutionValue.Key> registeredToolchainKeys = new ArrayList<>();
-    for (Label toolchainType : requiredToolchainTypes) {
+    for (Label toolchainTypeLabel : requiredToolchainTypeLabels) {
       registeredToolchainKeys.add(
           ToolchainResolutionValue.key(
               configurationKey,
-              toolchainType,
+              toolchainTypeLabel,
               platformKeys.targetPlatformKey(),
               platformKeys.executionPlatformKeys()));
     }
@@ -325,14 +326,14 @@
     boolean valuesMissing = false;
 
     // Determine the potential set of toolchains.
-    Table<ConfiguredTargetKey, Label, Label> resolvedToolchains = HashBasedTable.create();
+    Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains =
+        HashBasedTable.create();
+    ImmutableSet.Builder<ToolchainTypeInfo> requiredToolchainTypesBuilder = ImmutableSet.builder();
     List<Label> missingToolchains = new ArrayList<>();
     for (Map.Entry<
             SkyKey, ValueOrException2<NoToolchainFoundException, InvalidToolchainLabelException>>
         entry : results.entrySet()) {
       try {
-        Label requiredToolchainType =
-            ((ToolchainResolutionValue.Key) entry.getKey().argument()).toolchainType();
         ValueOrException2<NoToolchainFoundException, InvalidToolchainLabelException>
             valueOrException = entry.getValue();
         ToolchainResolutionValue toolchainResolutionValue =
@@ -342,11 +343,13 @@
           continue;
         }
 
+        ToolchainTypeInfo requiredToolchainType = toolchainResolutionValue.toolchainType();
+        requiredToolchainTypesBuilder.add(requiredToolchainType);
         resolvedToolchains.putAll(
             findPlatformsAndLabels(requiredToolchainType, toolchainResolutionValue));
       } catch (NoToolchainFoundException e) {
         // Save the missing type and continue looping to check for more.
-        missingToolchains.add(e.missingToolchainType());
+        missingToolchains.add(e.missingToolchainTypeLabel());
       }
     }
 
@@ -358,9 +361,11 @@
       throw new ValueMissingException();
     }
 
+    ImmutableSet<ToolchainTypeInfo> requiredToolchainTypes = requiredToolchainTypesBuilder.build();
+
     // Find and return the first execution platform which has all required toolchains.
     Optional<ConfiguredTargetKey> selectedExecutionPlatformKey;
-    if (requiredToolchainTypes.isEmpty()
+    if (requiredToolchainTypeLabels.isEmpty()
         && platformKeys.executionPlatformKeys().contains(platformKeys.hostPlatformKey())) {
       // Fall back to the legacy behavior: use the host platform if it's available, otherwise the
       // first execution platform.
@@ -369,12 +374,12 @@
       // If there are no toolchains, this will return the first execution platform.
       selectedExecutionPlatformKey =
           findExecutionPlatformForToolchains(
-              platformKeys.executionPlatformKeys(), resolvedToolchains);
+              requiredToolchainTypes, platformKeys.executionPlatformKeys(), resolvedToolchains);
     }
 
     if (!selectedExecutionPlatformKey.isPresent()) {
       throw new NoMatchingPlatformException(
-          requiredToolchainTypes,
+          requiredToolchainTypeLabels,
           platformKeys.executionPlatformKeys(),
           platformKeys.targetPlatformKey());
     }
@@ -387,11 +392,13 @@
       throw new ValueMissingException();
     }
 
+    unloadedToolchainContext.setRequiredToolchainTypes(requiredToolchainTypes);
     unloadedToolchainContext.setExecutionPlatform(
         platforms.get(selectedExecutionPlatformKey.get()));
     unloadedToolchainContext.setTargetPlatform(platforms.get(platformKeys.targetPlatformKey()));
 
-    Map<Label, Label> toolchains = resolvedToolchains.row(selectedExecutionPlatformKey.get());
+    Map<ToolchainTypeInfo, Label> toolchains =
+        resolvedToolchains.row(selectedExecutionPlatformKey.get());
     unloadedToolchainContext.setToolchainTypeToResolved(ImmutableBiMap.copyOf(toolchains));
   }
 
@@ -399,10 +406,11 @@
    * Adds all of toolchain labels from{@code toolchainResolutionValue} to {@code
    * resolvedToolchains}.
    */
-  private static Table<ConfiguredTargetKey, Label, Label> findPlatformsAndLabels(
-      Label requiredToolchainType, ToolchainResolutionValue toolchainResolutionValue) {
+  private static Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> findPlatformsAndLabels(
+      ToolchainTypeInfo requiredToolchainType, ToolchainResolutionValue toolchainResolutionValue) {
 
-    Table<ConfiguredTargetKey, Label, Label> resolvedToolchains = HashBasedTable.create();
+    Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains =
+        HashBasedTable.create();
     for (Map.Entry<ConfiguredTargetKey, Label> entry :
         toolchainResolutionValue.availableToolchainLabels().entrySet()) {
       resolvedToolchains.put(entry.getKey(), requiredToolchainType, entry.getValue());
@@ -415,10 +423,12 @@
    * resolvedToolchains} and has all required toolchain types.
    */
   private Optional<ConfiguredTargetKey> findExecutionPlatformForToolchains(
+      ImmutableSet<ToolchainTypeInfo> requiredToolchainTypes,
       ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys,
-      Table<ConfiguredTargetKey, Label, Label> resolvedToolchains) {
+      Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains) {
     for (ConfiguredTargetKey executionPlatformKey : availableExecutionPlatformKeys) {
-      Map<Label, Label> toolchains = resolvedToolchains.row(executionPlatformKey);
+      Map<ToolchainTypeInfo, Label> toolchains = resolvedToolchains.row(executionPlatformKey);
+
       if (!toolchains.keySet().containsAll(requiredToolchainTypes)) {
         // Not all toolchains are present, keep going
         continue;
@@ -427,7 +437,10 @@
       if (debug) {
         String selectedToolchains =
             toolchains.entrySet().stream()
-                .map(e -> String.format("type %s -> toolchain %s", e.getKey(), e.getValue()))
+                .map(
+                    e ->
+                        String.format(
+                            "type %s -> toolchain %s", e.getKey().typeLabel(), e.getValue()))
                 .collect(joining(", "));
         environment
             .getListener()
@@ -467,9 +480,10 @@
       Builder setTargetPlatform(PlatformInfo targetPlatform);
 
       /** Sets the toolchain types that were requested. */
-      Builder setRequiredToolchainTypes(Set<Label> requiredToolchainTypes);
+      Builder setRequiredToolchainTypes(Set<ToolchainTypeInfo> requiredToolchainTypes);
 
-      Builder setToolchainTypeToResolved(ImmutableBiMap<Label, Label> toolchainTypeToResolved);
+      Builder setToolchainTypeToResolved(
+          ImmutableBiMap<ToolchainTypeInfo, Label> toolchainTypeToResolved);
 
       UnloadedToolchainContext build();
     }
@@ -484,10 +498,10 @@
     abstract PlatformInfo targetPlatform();
 
     /** Returns the toolchain types that were requested. */
-    abstract ImmutableSet<Label> requiredToolchainTypes();
+    abstract ImmutableSet<ToolchainTypeInfo> requiredToolchainTypes();
 
     /** The map of toolchain type to resolved toolchain to be used. */
-    abstract ImmutableBiMap<Label, Label> toolchainTypeToResolved();
+    abstract ImmutableBiMap<ToolchainTypeInfo, Label> toolchainTypeToResolved();
 
     /** Returns the labels of the specific toolchains being used. */
     public ImmutableSet<Label> resolvedToolchainLabels() {
@@ -517,13 +531,15 @@
                   attribute ->
                       attribute.getName().equals(PlatformSemantics.RESOLVED_TOOLCHAINS_ATTR))
               .findFirst();
-      ImmutableMap.Builder<Label, ToolchainInfo> toolchains = new ImmutableMap.Builder<>();
+      ImmutableMap.Builder<ToolchainTypeInfo, ToolchainInfo> toolchains =
+          new ImmutableMap.Builder<>();
       ImmutableList.Builder<TemplateVariableInfo> templateVariableProviders =
           new ImmutableList.Builder<>();
       if (toolchainAttribute.isPresent()) {
         for (ConfiguredTargetAndData target : prerequisiteMap.get(toolchainAttribute.get())) {
           Label discoveredLabel = target.getTarget().getLabel();
-          Label toolchainType = toolchainTypeToResolved().inverse().get(discoveredLabel);
+          ToolchainTypeInfo toolchainType =
+              toolchainTypeToResolved().inverse().get(discoveredLabel);
 
           // If the toolchainType hadn't been resolved to an actual toolchain, resolution would have
           // failed with an error much earlier. This null check is just for safety.
@@ -560,17 +576,19 @@
   /** Exception used when no execution platform can be found. */
   static final class NoMatchingPlatformException extends ToolchainException {
     NoMatchingPlatformException(
-        Set<Label> requiredToolchains,
+        Set<Label> requiredToolchainTypeLabels,
         ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys,
         ConfiguredTargetKey targetPlatformKey) {
-      super(formatError(requiredToolchains, availableExecutionPlatformKeys, targetPlatformKey));
+      super(
+          formatError(
+              requiredToolchainTypeLabels, availableExecutionPlatformKeys, targetPlatformKey));
     }
 
     private static String formatError(
-        Set<Label> requiredToolchains,
+        Set<Label> requiredToolchainTypeLabels,
         ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys,
         ConfiguredTargetKey targetPlatformKey) {
-      if (requiredToolchains.isEmpty()) {
+      if (requiredToolchainTypeLabels.isEmpty()) {
         return String.format(
             "Unable to find an execution platform for target platform %s"
                 + " from available execution platforms [%s]",
@@ -582,7 +600,7 @@
       return String.format(
           "Unable to find an execution platform for toolchains [%s] and target platform %s"
               + " from available execution platforms [%s]",
-          Joiner.on(", ").join(requiredToolchains),
+          requiredToolchainTypeLabels.stream().map(Label::toString).collect(joining(", ")),
           targetPlatformKey.getLabel(),
           availableExecutionPlatformKeys.stream()
               .map(key -> key.getLabel().toString())
@@ -596,7 +614,7 @@
       super(
           String.format(
               "no matching toolchains found for types %s",
-              Joiner.on(", ").join(missingToolchainTypes)));
+              missingToolchainTypes.stream().map(Label::toString).collect(joining(", "))));
     }
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
index 427ca61..61d3f25 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
@@ -26,6 +26,7 @@
 import com.google.devtools.build.lib.analysis.platform.ConstraintValueInfo;
 import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo;
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.events.Event;
 import com.google.devtools.build.lib.events.EventHandler;
@@ -76,24 +77,13 @@
 
     // Find the right one.
     boolean debug = configuration.getOptions().get(PlatformOptions.class).toolchainResolutionDebug;
-    ImmutableMap<ConfiguredTargetKey, Label> resolvedToolchainLabels =
-        resolveConstraints(
-            key.toolchainType(),
-            key.availableExecutionPlatformKeys(),
-            key.targetPlatformKey(),
-            toolchains.registeredToolchains(),
-            env,
-            debug ? env.getListener() : null);
-    if (resolvedToolchainLabels == null) {
-      return null;
-    }
-
-    if (resolvedToolchainLabels.isEmpty()) {
-      throw new ToolchainResolutionFunctionException(
-          new NoToolchainFoundException(key.toolchainType()));
-    }
-
-    return ToolchainResolutionValue.create(resolvedToolchainLabels);
+    return resolveConstraints(
+        key.toolchainTypeLabel(),
+        key.availableExecutionPlatformKeys(),
+        key.targetPlatformKey(),
+        toolchains.registeredToolchains(),
+        env,
+        debug ? env.getListener() : null);
   }
 
   /**
@@ -102,8 +92,8 @@
    * platform.
    */
   @Nullable
-  private static ImmutableMap<ConfiguredTargetKey, Label> resolveConstraints(
-      Label toolchainType,
+  private static ToolchainResolutionValue resolveConstraints(
+      Label toolchainTypeLabel,
       List<ConfiguredTargetKey> availableExecutionPlatformKeys,
       ConfiguredTargetKey targetPlatformKey,
       ImmutableList<DeclaredToolchainInfo> toolchains,
@@ -135,11 +125,12 @@
     // check whether a platform has already been seen during processing.
     Set<ConfiguredTargetKey> platformKeysSeen = new HashSet<>();
     ImmutableMap.Builder<ConfiguredTargetKey, Label> builder = ImmutableMap.builder();
+    ToolchainTypeInfo toolchainType = null;
 
-    debugMessage(eventHandler, "Looking for toolchain of type %s...", toolchainType);
+    debugMessage(eventHandler, "Looking for toolchain of type %s...", toolchainTypeLabel);
     for (DeclaredToolchainInfo toolchain : toolchains) {
       // Make sure the type matches.
-      if (!toolchain.toolchainType().typeLabel().equals(toolchainType)) {
+      if (!toolchain.toolchainType().typeLabel().equals(toolchainTypeLabel)) {
         continue;
       }
       debugMessage(eventHandler, "  Considering toolchain %s...", toolchain.toolchainLabel());
@@ -164,6 +155,7 @@
 
         // Only add the toolchains if this is a new platform.
         if (!platformKeysSeen.contains(executionPlatformKey)) {
+          toolchainType = toolchain.toolchainType();
           builder.put(executionPlatformKey, toolchain.toolchainLabel());
           platformKeysSeen.add(executionPlatformKey);
         }
@@ -177,15 +169,18 @@
       debugMessage(
           eventHandler,
           "  For toolchain type %s, possible execution platforms and toolchains: {%s}",
-          toolchainType,
-          resolvedToolchainLabels
-              .entrySet()
-              .stream()
+          toolchainTypeLabel,
+          resolvedToolchainLabels.entrySet().stream()
               .map(e -> String.format("%s -> %s", e.getKey().getLabel(), e.getValue()))
               .collect(joining(", ")));
     }
 
-    return resolvedToolchainLabels;
+    if (toolchainType == null || resolvedToolchainLabels.isEmpty()) {
+      throw new ToolchainResolutionFunctionException(
+          new NoToolchainFoundException(toolchainTypeLabel));
+    }
+
+    return ToolchainResolutionValue.create(toolchainType, resolvedToolchainLabels);
   }
 
   /**
@@ -247,15 +242,15 @@
 
   /** Used to indicate that a toolchain was not found for the current request. */
   public static final class NoToolchainFoundException extends NoSuchThingException {
-    private final Label missingToolchainType;
+    private final Label missingToolchainTypeLabel;
 
-    public NoToolchainFoundException(Label missingToolchainType) {
-      super(String.format("no matching toolchain found for %s", missingToolchainType));
-      this.missingToolchainType = missingToolchainType;
+    public NoToolchainFoundException(Label missingToolchainTypeLabel) {
+      super(String.format("no matching toolchain found for %s", missingToolchainTypeLabel));
+      this.missingToolchainTypeLabel = missingToolchainTypeLabel;
     }
 
-    public Label missingToolchainType() {
-      return missingToolchainType;
+    public Label missingToolchainTypeLabel() {
+      return missingToolchainTypeLabel;
     }
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
index 341be5a..128269c 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
@@ -17,6 +17,7 @@
 import com.google.auto.value.AutoValue;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
 import com.google.devtools.build.skyframe.SkyFunctionName;
@@ -37,11 +38,11 @@
   // A key representing the input data.
   public static Key key(
       BuildConfigurationValue.Key configurationKey,
-      Label toolchainType,
+      Label toolchainTypeLabel,
       ConfiguredTargetKey targetPlatformKey,
       List<ConfiguredTargetKey> availableExecutionPlatformKeys) {
     return Key.create(
-        configurationKey, toolchainType, targetPlatformKey, availableExecutionPlatformKeys);
+        configurationKey, toolchainTypeLabel, targetPlatformKey, availableExecutionPlatformKeys);
   }
 
   /** {@link SkyKey} implementation used for {@link ToolchainResolutionFunction}. */
@@ -57,7 +58,7 @@
 
     abstract BuildConfigurationValue.Key configurationKey();
 
-    public abstract Label toolchainType();
+    public abstract Label toolchainTypeLabel();
 
     abstract ConfiguredTargetKey targetPlatformKey();
 
@@ -66,12 +67,12 @@
     @AutoCodec.Instantiator
     static Key create(
         BuildConfigurationValue.Key configurationKey,
-        Label toolchainType,
+        Label toolchainTypeLabel,
         ConfiguredTargetKey targetPlatformKey,
         List<ConfiguredTargetKey> availableExecutionPlatformKeys) {
       return new AutoValue_ToolchainResolutionValue_Key(
           configurationKey,
-          toolchainType,
+          toolchainTypeLabel,
           targetPlatformKey,
           ImmutableList.copyOf(availableExecutionPlatformKeys));
     }
@@ -79,10 +80,14 @@
 
   @AutoCodec.Instantiator
   public static ToolchainResolutionValue create(
+      ToolchainTypeInfo toolchainType,
       ImmutableMap<ConfiguredTargetKey, Label> availableToolchainLabels) {
-    return new AutoValue_ToolchainResolutionValue(availableToolchainLabels);
+    return new AutoValue_ToolchainResolutionValue(toolchainType, availableToolchainLabels);
   }
 
+  /** Returns the resolved details about the requested toolchain type. */
+  public abstract ToolchainTypeInfo toolchainType();
+
   /**
    * Returns the resolved set of toolchain labels (as {@link Label}) for the requested toolchain
    * type, keyed by the execution platforms (as {@link ConfiguredTargetKey}). Ordering is not
diff --git a/src/test/java/com/google/devtools/build/lib/BUILD b/src/test/java/com/google/devtools/build/lib/BUILD
index 4a2b649..d29b4a7 100644
--- a/src/test/java/com/google/devtools/build/lib/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/BUILD
@@ -812,6 +812,7 @@
         "//src/main/java/com/google/devtools/build/lib:skylarkinterface",
         "//src/main/java/com/google/devtools/build/lib:util",
         "//src/main/java/com/google/devtools/build/lib/actions",
+        "//src/main/java/com/google/devtools/build/lib/analysis/platform",
         "//src/main/java/com/google/devtools/build/lib/buildeventstream",
         "//src/main/java/com/google/devtools/build/lib/buildeventstream/proto:build_event_stream_java_proto",
         "//src/main/java/com/google/devtools/build/lib/causes",
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/ToolchainResolverTest.java b/src/test/java/com/google/devtools/build/lib/analysis/ToolchainResolverTest.java
index f518354..5ceade0 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/ToolchainResolverTest.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/ToolchainResolverTest.java
@@ -25,6 +25,7 @@
 import com.google.devtools.build.lib.analysis.ToolchainResolver.NoMatchingPlatformException;
 import com.google.devtools.build.lib.analysis.ToolchainResolver.UnloadedToolchainContext;
 import com.google.devtools.build.lib.analysis.ToolchainResolver.UnresolvedToolchainsException;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.analysis.util.AnalysisMock;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.packages.Attribute;
@@ -97,7 +98,8 @@
 
     useConfiguration("--platforms=//platforms:linux");
     ResolveToolchainsKey key =
-        ResolveToolchainsKey.create("test", ImmutableSet.of(testToolchainType), targetConfigKey);
+        ResolveToolchainsKey.create(
+            "test", ImmutableSet.of(testToolchainTypeLabel), targetConfigKey);
 
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
 
@@ -196,7 +198,7 @@
         ResolveToolchainsKey.create(
             "test",
             ImmutableSet.of(
-                testToolchainType, Label.parseAbsoluteUnchecked("//fake/toolchain:type_1")),
+                testToolchainTypeLabel, Label.parseAbsoluteUnchecked("//fake/toolchain:type_1")),
             targetConfigKey);
 
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
@@ -219,7 +221,7 @@
         ResolveToolchainsKey.create(
             "test",
             ImmutableSet.of(
-                testToolchainType,
+                testToolchainTypeLabel,
                 Label.parseAbsoluteUnchecked("//fake/toolchain:type_1"),
                 Label.parseAbsoluteUnchecked("//fake/toolchain:type_2")),
             targetConfigKey);
@@ -238,7 +240,8 @@
     scratch.file("invalid/BUILD", "filegroup(name = 'not_a_platform')");
     useConfiguration("--platforms=//invalid:not_a_platform");
     ResolveToolchainsKey key =
-        ResolveToolchainsKey.create("test", ImmutableSet.of(testToolchainType), targetConfigKey);
+        ResolveToolchainsKey.create(
+            "test", ImmutableSet.of(testToolchainTypeLabel), targetConfigKey);
 
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
 
@@ -261,7 +264,8 @@
     scratch.resolve("invalid").delete();
     useConfiguration("--platforms=//invalid:not_a_platform");
     ResolveToolchainsKey key =
-        ResolveToolchainsKey.create("test", ImmutableSet.of(testToolchainType), targetConfigKey);
+        ResolveToolchainsKey.create(
+            "test", ImmutableSet.of(testToolchainTypeLabel), targetConfigKey);
 
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
 
@@ -282,7 +286,8 @@
     scratch.file("invalid/BUILD", "filegroup(name = 'not_a_platform')");
     useConfiguration("--host_platform=//invalid:not_a_platform");
     ResolveToolchainsKey key =
-        ResolveToolchainsKey.create("test", ImmutableSet.of(testToolchainType), targetConfigKey);
+        ResolveToolchainsKey.create(
+            "test", ImmutableSet.of(testToolchainTypeLabel), targetConfigKey);
 
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
 
@@ -303,7 +308,8 @@
     scratch.file("invalid/BUILD", "filegroup(name = 'not_a_platform')");
     useConfiguration("--extra_execution_platforms=//invalid:not_a_platform");
     ResolveToolchainsKey key =
-        ResolveToolchainsKey.create("test", ImmutableSet.of(testToolchainType), targetConfigKey);
+        ResolveToolchainsKey.create(
+            "test", ImmutableSet.of(testToolchainTypeLabel), targetConfigKey);
 
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
 
@@ -343,7 +349,7 @@
     ResolveToolchainsKey key =
         ResolveToolchainsKey.create(
             "test",
-            ImmutableSet.of(testToolchainType),
+            ImmutableSet.of(testToolchainTypeLabel),
             ImmutableSet.of(Label.parseAbsoluteUnchecked("//constraints:linux")),
             targetConfigKey);
 
@@ -372,7 +378,7 @@
     ResolveToolchainsKey key =
         ResolveToolchainsKey.create(
             "test",
-            ImmutableSet.of(testToolchainType),
+            ImmutableSet.of(testToolchainTypeLabel),
             ImmutableSet.of(Label.parseAbsoluteUnchecked("//platforms:linux")),
             targetConfigKey);
 
@@ -451,7 +457,8 @@
 
     useConfiguration("--platforms=//platforms:linux");
     ResolveToolchainsKey key =
-        ResolveToolchainsKey.create("test", ImmutableSet.of(testToolchainType), targetConfigKey);
+        ResolveToolchainsKey.create(
+            "test", ImmutableSet.of(testToolchainTypeLabel), targetConfigKey);
 
     // Create the UnloadedToolchainContext.
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
@@ -479,8 +486,9 @@
   @Test
   public void unloadedToolchainContext_load_withTemplateVariables() throws Exception {
     // Add new toolchain rule that provides template variables.
-    Label variableToolchainType =
+    Label variableToolchainTypeLabel =
         Label.parseAbsoluteUnchecked("//variable:variable_toolchain_type");
+    ToolchainTypeInfo variableToolchainType = ToolchainTypeInfo.create(variableToolchainTypeLabel);
     scratch.file(
         "variable/variable_toolchain_def.bzl",
         "def _impl(ctx):",
@@ -512,7 +520,7 @@
 
     ResolveToolchainsKey key =
         ResolveToolchainsKey.create(
-            "test", ImmutableSet.of(variableToolchainType), targetConfigKey);
+            "test", ImmutableSet.of(variableToolchainTypeLabel), targetConfigKey);
 
     // Create the UnloadedToolchainContext.
     EvaluationResult<ResolveToolchainsValue> result = createToolchainContextBuilder(key);
diff --git a/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java b/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
index 43750e0..7349164 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
@@ -23,6 +23,7 @@
 import com.google.devtools.build.lib.analysis.platform.ConstraintValueInfo;
 import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo;
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.cmdline.PackageIdentifier;
 import com.google.devtools.build.lib.skyframe.RegisteredToolchainsValue;
@@ -46,7 +47,8 @@
   public ConstraintValueInfo linuxConstraint;
   public ConstraintValueInfo macConstraint;
 
-  public Label testToolchainType;
+  public Label testToolchainTypeLabel;
+  public ToolchainTypeInfo testToolchainType;
 
   protected static IterableSubject assertToolchainLabels(
       RegisteredToolchainsValue registeredToolchainsValue) {
@@ -174,7 +176,8 @@
         ImmutableList.of("//constraints:linux"),
         "bar");
 
-    testToolchainType = makeLabel("//toolchain:test_toolchain");
+    testToolchainTypeLabel = makeLabel("//toolchain:test_toolchain");
+    testToolchainType = ToolchainTypeInfo.create(testToolchainTypeLabel);
   }
 
   protected EvaluationResult<RegisteredToolchainsValue> requestToolchainsFromSkyframe(
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java
index e00912f..98c40fc 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java
@@ -48,8 +48,7 @@
     // Check that the number of toolchains created for this test is correct.
     assertThat(
             value.registeredToolchains().stream()
-                .filter(
-                    toolchain -> toolchain.toolchainType().typeLabel().equals(testToolchainType))
+                .filter(toolchain -> toolchain.toolchainType().equals(testToolchainType))
                 .collect(Collectors.toList()))
         .hasSize(2);
 
@@ -57,7 +56,7 @@
             value.registeredToolchains().stream()
                 .anyMatch(
                     toolchain ->
-                        toolchain.toolchainType().typeLabel().equals(testToolchainType)
+                        toolchain.toolchainType().equals(testToolchainType)
                             && toolchain.execConstraints().get(setting).equals(linuxConstraint)
                             && toolchain.targetConstraints().get(setting).equals(macConstraint)
                             && toolchain
@@ -69,7 +68,7 @@
             value.registeredToolchains().stream()
                 .anyMatch(
                     toolchain ->
-                        toolchain.toolchainType().typeLabel().equals(testToolchainType)
+                        toolchain.toolchainType().equals(testToolchainType)
                             && toolchain.execConstraints().get(setting).equals(macConstraint)
                             && toolchain.targetConstraints().get(setting).equals(linuxConstraint)
                             && toolchain
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
index 7026749..126224a 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
@@ -89,7 +89,7 @@
   public void testResolution_singleExecutionPlatform() throws Exception {
     SkyKey key =
         ToolchainResolutionValue.key(
-            targetConfigKey, testToolchainType, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
+            targetConfigKey, testToolchainTypeLabel, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
 
     assertThatEvaluationResult(result).hasNoError();
@@ -116,7 +116,7 @@
     SkyKey key =
         ToolchainResolutionValue.key(
             targetConfigKey,
-            testToolchainType,
+            testToolchainTypeLabel,
             LINUX_CTKEY,
             ImmutableList.of(LINUX_CTKEY, MAC_CTKEY));
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
@@ -139,7 +139,7 @@
 
     SkyKey key =
         ToolchainResolutionValue.key(
-            targetConfigKey, testToolchainType, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
+            targetConfigKey, testToolchainTypeLabel, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
 
     assertThatEvaluationResult(result)
@@ -154,24 +154,30 @@
     new EqualsTester()
         .addEqualityGroup(
             ToolchainResolutionValue.create(
+                testToolchainType,
                 ImmutableMap.of(LINUX_CTKEY, makeLabel("//test:toolchain_impl_1"))),
             ToolchainResolutionValue.create(
+                testToolchainType,
                 ImmutableMap.of(LINUX_CTKEY, makeLabel("//test:toolchain_impl_1"))))
         // Different execution platform, same label.
         .addEqualityGroup(
             ToolchainResolutionValue.create(
+                testToolchainType,
                 ImmutableMap.of(MAC_CTKEY, makeLabel("//test:toolchain_impl_1"))))
         // Same execution platform, different label.
         .addEqualityGroup(
             ToolchainResolutionValue.create(
+                testToolchainType,
                 ImmutableMap.of(LINUX_CTKEY, makeLabel("//test:toolchain_impl_2"))))
         // Different execution platform, different label.
         .addEqualityGroup(
             ToolchainResolutionValue.create(
+                testToolchainType,
                 ImmutableMap.of(MAC_CTKEY, makeLabel("//test:toolchain_impl_2"))))
         // Multiple execution platforms.
         .addEqualityGroup(
             ToolchainResolutionValue.create(
+                testToolchainType,
                 ImmutableMap.<ConfiguredTargetKey, Label>builder()
                     .put(LINUX_CTKEY, makeLabel("//test:toolchain_impl_1"))
                     .put(MAC_CTKEY, makeLabel("//test:toolchain_impl_1"))