Update starlark apis that consume toolchain types to use

ToolchainTypeRequirement.

Part of Optional Toolchains (#14726).

Closes #14948.

PiperOrigin-RevId: 439641733
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/BUILD b/src/main/java/com/google/devtools/build/lib/analysis/BUILD
index 908a2c5..ff8ac37 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/analysis/BUILD
@@ -1837,6 +1837,7 @@
         "//src/main/java/com/google/devtools/build/lib/cmdline",
         "//src/main/java/com/google/devtools/build/lib/starlarkbuildapi/config:starlark_toolchain_type_requirement",
         "//third_party:auto_value",
+        "//third_party:guava",
     ],
 )
 
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/config/ToolchainTypeRequirement.java b/src/main/java/com/google/devtools/build/lib/analysis/config/ToolchainTypeRequirement.java
index 4ffba00..6b22dd6 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/config/ToolchainTypeRequirement.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/config/ToolchainTypeRequirement.java
@@ -14,6 +14,7 @@
 package com.google.devtools.build.lib.analysis.config;
 
 import com.google.auto.value.AutoValue;
+import com.google.common.base.Preconditions;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.starlarkbuildapi.config.StarlarkToolchainTypeRequirement;
 
@@ -33,6 +34,24 @@
         .mandatory(true);
   }
 
+  /**
+   * Returns the ToolchainTypeRequirement with the strictest restriction, or else the first.
+   * Mandatory toolchain type requirements are stricter than optional.
+   */
+  public static ToolchainTypeRequirement strictest(
+      ToolchainTypeRequirement first, ToolchainTypeRequirement second) {
+    Preconditions.checkArgument(
+        first.toolchainType().equals(second.toolchainType()),
+        "Cannot use strictest() for two instances with different type labels.");
+    if (first.mandatory()) {
+      return first;
+    }
+    if (second.mandatory()) {
+      return second;
+    }
+    return first;
+  }
+
   /** Returns the label of the toolchain type that is requested. */
   public abstract Label toolchainType();
 
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java b/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
index 01721f2..b869884 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
@@ -14,7 +14,6 @@
 
 package com.google.devtools.build.lib.analysis.starlark;
 
-import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static com.google.devtools.build.lib.analysis.BaseRuleClasses.RUN_UNDER;
 import static com.google.devtools.build.lib.analysis.BaseRuleClasses.TEST_RUNNER_EXEC_GROUP;
 import static com.google.devtools.build.lib.analysis.BaseRuleClasses.TIMEOUT_DEFAULT;
@@ -102,6 +101,7 @@
 import com.google.devtools.build.lib.util.FileTypeSet;
 import com.google.devtools.build.lib.util.Pair;
 import com.google.errorprone.annotations.FormatMethod;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
 import javax.annotation.Nullable;
@@ -401,7 +401,7 @@
             : Label.createUnvalidated(PackageIdentifier.EMPTY_PACKAGE_ID, "dummy_label"),
         bzlModule != null ? bzlModule.bzlTransitiveDigest() : new byte[0]);
 
-    builder.addRequiredToolchains(parseToolchains(toolchains, thread));
+    builder.addToolchainTypes(parseToolchainTypes(toolchains, thread));
     if (useToolchainTransition) {
       builder.useToolchainTransition(ToolchainTransitionMode.ENABLED);
     }
@@ -559,41 +559,21 @@
     return attributes.build();
   }
 
-  /**
-   * Parses a sequence of label strings with a repo mapping.
-   *
-   * @param inputs sequence of input strings
-   * @param thread repository mapping
-   * @param adjective describes the purpose of the label; used for errors
-   * @throws EvalException if the label can't be parsed
-   */
-  private static ImmutableList<Label> parseLabels(
-      Iterable<String> inputs, StarlarkThread thread, String adjective) throws EvalException {
+  private static ImmutableList<Label> parseExecCompatibleWith(
+      Sequence<?> inputs, StarlarkThread thread) throws EvalException {
     ImmutableList.Builder<Label> parsedLabels = new ImmutableList.Builder<>();
     LabelConverter converter = LabelConverter.forThread(thread);
-    for (String input : inputs) {
+    for (String input : Sequence.cast(inputs, String.class, "exec_compatible_with")) {
       try {
         Label label = converter.convert(input);
         parsedLabels.add(label);
       } catch (LabelSyntaxException e) {
-        throw Starlark.errorf(
-            "Unable to parse %s label '%s': %s", adjective, input, e.getMessage());
+        throw Starlark.errorf("Unable to parse constraint label '%s': %s", input, e.getMessage());
       }
     }
     return parsedLabels.build();
   }
 
-  private static ImmutableList<Label> parseToolchains(Sequence<?> inputs, StarlarkThread thread)
-      throws EvalException {
-    return parseLabels(Sequence.cast(inputs, String.class, "toolchains"), thread, "toolchain");
-  }
-
-  private static ImmutableList<Label> parseExecCompatibleWith(
-      Sequence<?> inputs, StarlarkThread thread) throws EvalException {
-    return parseLabels(
-        Sequence.cast(inputs, String.class, "exec_compatible_with"), thread, "constraint");
-  }
-
   @Override
   public StarlarkAspect aspect(
       StarlarkFunction implementation,
@@ -698,7 +678,6 @@
           "An aspect cannot simultaneously have required providers and apply to generating rules.");
     }
 
-    ImmutableList<Label> toolchainTypes = parseToolchains(toolchains, thread);
     return new StarlarkDefinedAspect(
         implementation,
         attrAspects.build(),
@@ -712,9 +691,7 @@
         ImmutableSet.copyOf(Sequence.cast(fragments, String.class, "fragments")),
         HostTransition.INSTANCE,
         ImmutableSet.copyOf(Sequence.cast(hostFragments, String.class, "host_fragments")),
-        toolchainTypes.stream()
-            .map(tt -> ToolchainTypeRequirement.create(tt))
-            .collect(toImmutableSet()),
+        parseToolchainTypes(toolchains, thread),
         useToolchainTransition,
         applyToGeneratingRules);
   }
@@ -1049,13 +1026,62 @@
       return ExecGroup.copyFromDefault();
     }
 
-    ImmutableSet<Label> toolchainTypes = ImmutableSet.copyOf(parseToolchains(toolchains, thread));
+    ImmutableSet<ToolchainTypeRequirement> toolchainTypes = parseToolchainTypes(toolchains, thread);
     ImmutableSet<Label> constraints =
         ImmutableSet.copyOf(parseExecCompatibleWith(execCompatibleWith, thread));
     return ExecGroup.builder()
-        .requiredToolchains(toolchainTypes)
+        .toolchainTypes(toolchainTypes)
         .execCompatibleWith(constraints)
         .copyFrom(null)
         .build();
   }
+
+  private static ImmutableSet<ToolchainTypeRequirement> parseToolchainTypes(
+      Sequence<?> rawToolchains, StarlarkThread thread) throws EvalException {
+    Map<Label, ToolchainTypeRequirement> toolchainTypes = new HashMap<>();
+    LabelConverter converter = LabelConverter.forThread(thread);
+
+    for (Object rawToolchain : rawToolchains) {
+      ToolchainTypeRequirement toolchainType = parseToolchainType(converter, rawToolchain);
+      Label typeLabel = toolchainType.toolchainType();
+      ToolchainTypeRequirement previous = toolchainTypes.get(typeLabel);
+      if (previous != null) {
+        // Keep the one with the strictest requirements.
+        toolchainType = ToolchainTypeRequirement.strictest(previous, toolchainType);
+      }
+      toolchainTypes.put(typeLabel, toolchainType);
+    }
+
+    return ImmutableSet.copyOf(toolchainTypes.values());
+  }
+
+  private static ToolchainTypeRequirement parseToolchainType(
+      LabelConverter converter, Object rawToolchain) throws EvalException {
+    // Handle actual ToolchainTypeRequirement objects.
+    if (rawToolchain instanceof ToolchainTypeRequirement) {
+      return (ToolchainTypeRequirement) rawToolchain;
+    }
+
+    // Handle Label-like objects.
+    Label toolchainLabel = null;
+    if (rawToolchain instanceof Label) {
+      toolchainLabel = (Label) rawToolchain;
+    } else if (rawToolchain instanceof String) {
+      try {
+        toolchainLabel = converter.convert((String) rawToolchain);
+      } catch (LabelSyntaxException e) {
+        throw Starlark.errorf(
+            "Unable to parse toolchain_type label '%s': %s", rawToolchain, e.getMessage());
+      }
+    }
+
+    if (toolchainLabel != null) {
+      return ToolchainTypeRequirement.builder(toolchainLabel).mandatory(true).build();
+    }
+
+    // It's not a valid type.
+    throw Starlark.errorf(
+        "'toolchains' takes a toolchain_type, Label, or String, but instead got a %s",
+        rawToolchain.getClass().getSimpleName());
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/packages/ExecGroup.java b/src/main/java/com/google/devtools/build/lib/packages/ExecGroup.java
index e3ccd68..c8666ff 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/ExecGroup.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/ExecGroup.java
@@ -14,7 +14,6 @@
 
 package com.google.devtools.build.lib.packages;
 
-import static com.google.common.collect.ImmutableSet.toImmutableSet;
 
 import com.google.auto.value.AutoValue;
 import com.google.common.collect.ImmutableMap;
@@ -83,16 +82,6 @@
   @AutoValue.Builder
   public interface Builder {
 
-    /** Sets the required toolchain types. */
-    // TODO(katre): Remove this once all callers use toolchainTypes.
-    default Builder requiredToolchains(ImmutableSet<Label> toolchainTypes) {
-      ImmutableSet<ToolchainTypeRequirement> toolchainTypeRequirements =
-          toolchainTypes.stream()
-              .map(label -> ToolchainTypeRequirement.create(label))
-              .collect(toImmutableSet());
-      return this.toolchainTypes(toolchainTypeRequirements);
-    }
-
     /** Sets the toolchain type requirements. */
     default Builder toolchainTypes(ImmutableSet<ToolchainTypeRequirement> toolchainTypes) {
       toolchainTypes.forEach(this::addToolchainType);
diff --git a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
index 5f7e1fb..6715a30 100644
--- a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
+++ b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
@@ -334,13 +334,14 @@
                     + "Starlark rules. This flag may be removed in the future."),
         @Param(
             name = TOOLCHAINS_PARAM,
-            allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)},
+            allowedTypes = {@ParamType(type = Sequence.class, generic1 = Object.class)},
             named = true,
             defaultValue = "[]",
             doc =
-                "If set, the set of toolchains this rule requires. Toolchains will be "
-                    + "found by checking the current platform, and provided to the rule "
-                    + "implementation via <code>ctx.toolchain</code>."),
+                "If set, the set of toolchains this rule requires. The list can contain String,"
+                    + " Label, or StarlarkToolchainTypeApi objects, in any combination. Toolchains"
+                    + " will be found by checking the current platform, and provided to the rule"
+                    + " implementation via <code>ctx.toolchain</code>."),
         @Param(
             name = "incompatible_use_toolchain_transition",
             defaultValue = "False",
@@ -610,13 +611,14 @@
                     + "in host configuration."),
         @Param(
             name = TOOLCHAINS_PARAM,
-            allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)},
+            allowedTypes = {@ParamType(type = Sequence.class, generic1 = Object.class)},
             named = true,
             defaultValue = "[]",
             doc =
-                "If set, the set of toolchains this rule requires. Toolchains will be "
-                    + "found by checking the current platform, and provided to the rule "
-                    + "implementation via <code>ctx.toolchain</code>."),
+                "If set, the set of toolchains this rule requires. The list can contain String,"
+                    + " Label, or StarlarkToolchainTypeApi objects, in any combination. Toolchains"
+                    + " will be found by checking the current platform, and provided to the rule"
+                    + " implementation via <code>ctx.toolchain</code>."),
         @Param(
             name = "incompatible_use_toolchain_transition",
             defaultValue = "False",
@@ -690,11 +692,13 @@
       parameters = {
         @Param(
             name = TOOLCHAINS_PARAM,
-            allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)},
+            allowedTypes = {@ParamType(type = Sequence.class, generic1 = Object.class)},
             named = true,
             positional = false,
             defaultValue = "[]",
-            doc = "The set of toolchains this execution group requires."),
+            doc =
+                "The set of toolchains this execution group requires. The list can contain String,"
+                    + " Label, or StarlarkToolchainTypeApi objects, in any combination."),
         @Param(
             name = EXEC_COMPATIBLE_WITH_PARAM,
             allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)},
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/testing/BUILD b/src/test/java/com/google/devtools/build/lib/analysis/testing/BUILD
index 720ed20..f04a3fd 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/testing/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/analysis/testing/BUILD
@@ -20,6 +20,7 @@
         "ExecGroupSubject.java",
         "ResolvedToolchainContextSubject.java",
         "RuleClassSubject.java",
+        "StarlarkDefinedAspectSubject.java",
         "ToolchainCollectionSubject.java",
         "ToolchainContextSubject.java",
         "ToolchainInfoSubject.java",
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/testing/StarlarkDefinedAspectSubject.java b/src/test/java/com/google/devtools/build/lib/analysis/testing/StarlarkDefinedAspectSubject.java
new file mode 100644
index 0000000..ee9b93c
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/analysis/testing/StarlarkDefinedAspectSubject.java
@@ -0,0 +1,86 @@
+// Copyright 2022 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.analysis.testing;
+
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static com.google.common.truth.Truth.assertAbout;
+
+import com.google.common.base.Functions;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.truth.FailureMetadata;
+import com.google.common.truth.MapSubject;
+import com.google.common.truth.Subject;
+import com.google.devtools.build.lib.analysis.config.ToolchainTypeRequirement;
+import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.packages.StarlarkDefinedAspect;
+import java.util.Map;
+
+/** A Truth {@link Subject} for {@link StarlarkDefinedAspect}. */
+public class StarlarkDefinedAspectSubject extends Subject {
+  // Static data.
+
+  /** Entry point for test assertions related to {@link StarlarkDefinedAspect}. */
+  public static StarlarkDefinedAspectSubject assertThat(
+      StarlarkDefinedAspect starlarkDefinedAspect) {
+    return assertAbout(StarlarkDefinedAspectSubject::new).that(starlarkDefinedAspect);
+  }
+
+  /** Static method for getting the subject factory (for use with assertAbout()). */
+  public static Subject.Factory<StarlarkDefinedAspectSubject, StarlarkDefinedAspect>
+      starlarkDefinedAspects() {
+    return StarlarkDefinedAspectSubject::new;
+  }
+
+  // Instance fields.
+
+  private final StarlarkDefinedAspect actual;
+  private final Map<Label, ToolchainTypeRequirement> toolchainTypesMap;
+
+  protected StarlarkDefinedAspectSubject(
+      FailureMetadata failureMetadata, StarlarkDefinedAspect subject) {
+    super(failureMetadata, subject);
+    this.actual = subject;
+    this.toolchainTypesMap = makeToolchainTypesMap(subject);
+  }
+
+  private static ImmutableMap<Label, ToolchainTypeRequirement> makeToolchainTypesMap(
+      StarlarkDefinedAspect subject) {
+    return subject.getToolchainTypes().stream()
+        .collect(toImmutableMap(ToolchainTypeRequirement::toolchainType, Functions.identity()));
+  }
+
+  public MapSubject toolchainTypes() {
+    return check("getToolchainTypes()").that(toolchainTypesMap);
+  }
+
+  public ToolchainTypeRequirementSubject toolchainType(String toolchainTypeLabel) {
+    return toolchainType(Label.parseAbsoluteUnchecked(toolchainTypeLabel));
+  }
+
+  public ToolchainTypeRequirementSubject toolchainType(Label toolchainType) {
+    return check("toolchainType(%s)", toolchainType)
+        .about(ToolchainTypeRequirementSubject.toolchainTypeRequirements())
+        .that(toolchainTypesMap.get(toolchainType));
+  }
+
+  public void hasToolchainType(String toolchainTypeLabel) {
+    toolchainType(toolchainTypeLabel).isNotNull();
+  }
+
+  public void hasToolchainType(Label toolchainType) {
+    toolchainType(toolchainType).isNotNull();
+  }
+
+  // TODO(blaze-team): Add more useful methods.
+}
diff --git a/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java b/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
index b264bd1..88d06f3 100644
--- a/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
@@ -17,6 +17,7 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.devtools.build.lib.analysis.testing.ExecGroupSubject.assertThat;
 import static com.google.devtools.build.lib.analysis.testing.RuleClassSubject.assertThat;
+import static com.google.devtools.build.lib.analysis.testing.StarlarkDefinedAspectSubject.assertThat;
 import static org.junit.Assert.assertThrows;
 
 import com.google.common.base.Joiner;
@@ -25,7 +26,6 @@
 import com.google.common.collect.Iterables;
 import com.google.devtools.build.lib.analysis.ConfiguredRuleClassProvider;
 import com.google.devtools.build.lib.analysis.RuleContext;
-import com.google.devtools.build.lib.analysis.config.ToolchainTypeRequirement;
 import com.google.devtools.build.lib.analysis.config.transitions.NoTransition;
 import com.google.devtools.build.lib.analysis.starlark.StarlarkAttrModule;
 import com.google.devtools.build.lib.analysis.starlark.StarlarkRuleClassFunctions.StarlarkRuleFunction;
@@ -656,15 +656,26 @@
 
   @Test
   public void testAspectAddToolchain() throws Exception {
-    scratch.file("test/BUILD", "toolchain_type(name = 'my_toolchain_type')");
     evalAndExport(
-        ev, "def _impl(ctx): pass", "a1 = aspect(_impl, toolchains=['//test:my_toolchain_type'])");
+        ev,
+        "def _impl(ctx): pass",
+        "a1 = aspect(_impl,",
+        "    toolchains=[",
+        "        '//test:my_toolchain_type1',",
+        "        config_common.toolchain_type('//test:my_toolchain_type2'),",
+        "        config_common.toolchain_type('//test:my_toolchain_type3', mandatory=False),",
+        "        config_common.toolchain_type('//test:my_toolchain_type4', mandatory=True),",
+        "    ],",
+        ")");
     StarlarkDefinedAspect a = (StarlarkDefinedAspect) ev.lookup("a1");
-    // TODO(https://github.com/bazelbuild/bazel/issues/14726): Add tests of optional toolchains.
-    assertThat(a.getToolchainTypes())
-        .containsExactly(
-            ToolchainTypeRequirement.create(
-                Label.parseAbsoluteUnchecked("//test:my_toolchain_type")));
+    assertThat(a).hasToolchainType("//test:my_toolchain_type1");
+    assertThat(a).toolchainType("//test:my_toolchain_type1").isMandatory();
+    assertThat(a).hasToolchainType("//test:my_toolchain_type2");
+    assertThat(a).toolchainType("//test:my_toolchain_type2").isMandatory();
+    assertThat(a).hasToolchainType("//test:my_toolchain_type3");
+    assertThat(a).toolchainType("//test:my_toolchain_type3").isOptional();
+    assertThat(a).hasToolchainType("//test:my_toolchain_type4");
+    assertThat(a).toolchainType("//test:my_toolchain_type4").isMandatory();
   }
 
   @Test
@@ -2299,25 +2310,60 @@
 
   @Test
   public void testRuleAddToolchain() throws Exception {
-    scratch.file("test/BUILD", "toolchain_type(name = 'my_toolchain_type')");
     evalAndExport(
         ev,
         "def impl(ctx): return None",
-        "r1 = rule(impl, toolchains=['//test:my_toolchain_type'])");
-    // TODO(https://github.com/bazelbuild/bazel/issues/14726): Add tests of optional toolchains.
+        "r1 = rule(impl,",
+        "    toolchains=[",
+        "        '//test:my_toolchain_type1',",
+        "        config_common.toolchain_type('//test:my_toolchain_type2'),",
+        "        config_common.toolchain_type('//test:my_toolchain_type3', mandatory=False),",
+        "        config_common.toolchain_type('//test:my_toolchain_type4', mandatory=True),",
+        "    ],",
+        ")");
     RuleClass c = ((StarlarkRuleFunction) ev.lookup("r1")).getRuleClass();
-    assertThat(c).hasToolchainType("//test:my_toolchain_type");
+    assertThat(c).hasToolchainType("//test:my_toolchain_type1");
+    assertThat(c).toolchainType("//test:my_toolchain_type1").isMandatory();
+    assertThat(c).hasToolchainType("//test:my_toolchain_type2");
+    assertThat(c).toolchainType("//test:my_toolchain_type2").isMandatory();
+    assertThat(c).hasToolchainType("//test:my_toolchain_type3");
+    assertThat(c).toolchainType("//test:my_toolchain_type3").isOptional();
+    assertThat(c).hasToolchainType("//test:my_toolchain_type4");
+    assertThat(c).toolchainType("//test:my_toolchain_type4").isMandatory();
+  }
+
+  @Test
+  public void testRuleAddToolchain_duplicate() throws Exception {
+    evalAndExport(
+        ev,
+        "def impl(ctx): return None",
+        "r1 = rule(impl,",
+        "    toolchains=[",
+        "        '//test:my_toolchain_type1',",
+        "        config_common.toolchain_type('//test:my_toolchain_type1'),",
+        "        config_common.toolchain_type('//test:my_toolchain_type2', mandatory = False),",
+        "        config_common.toolchain_type('//test:my_toolchain_type2', mandatory = True),",
+        "        config_common.toolchain_type('//test:my_toolchain_type3', mandatory = False),",
+        "        config_common.toolchain_type('//test:my_toolchain_type3', mandatory = False),",
+        "    ],",
+        ")");
+
+    RuleClass c = ((StarlarkRuleFunction) ev.lookup("r1")).getRuleClass();
+    assertThat(c).hasToolchainType("//test:my_toolchain_type1");
+    assertThat(c).toolchainType("//test:my_toolchain_type1").isMandatory();
+    assertThat(c).hasToolchainType("//test:my_toolchain_type2");
+    assertThat(c).toolchainType("//test:my_toolchain_type2").isMandatory();
+    assertThat(c).hasToolchainType("//test:my_toolchain_type3");
+    assertThat(c).toolchainType("//test:my_toolchain_type3").isOptional();
   }
 
   @Test
   public void testRuleAddExecutionConstraints() throws Exception {
     registerDummyStarlarkFunction();
-    scratch.file("test/BUILD", "toolchain_type(name = 'my_toolchain_type')");
     evalAndExport(
         ev,
         "r1 = rule(",
         "  implementation = impl,",
-        "  toolchains=['//test:my_toolchain_type'],",
         "  exec_compatible_with=['//constraint:cv1', '//constraint:cv2'],",
         ")");
     RuleClass c = ((StarlarkRuleFunction) ev.lookup("r1")).getRuleClass();
@@ -2330,28 +2376,37 @@
   @Test
   public void testRuleAddExecGroup() throws Exception {
     registerDummyStarlarkFunction();
-    scratch.file("test/BUILD", "toolchain_type(name = 'my_toolchain_type')");
     evalAndExport(
         ev,
         "plum = rule(",
         "  implementation = impl,",
         "  exec_groups = {",
         "    'group': exec_group(",
-        "      toolchains=['//test:my_toolchain_type'],",
+        "      toolchains=[",
+        "        '//test:my_toolchain_type1',",
+        "        config_common.toolchain_type('//test:my_toolchain_type2'),",
+        "        config_common.toolchain_type('//test:my_toolchain_type3', mandatory=False),",
+        "        config_common.toolchain_type('//test:my_toolchain_type4', mandatory=True),",
+        "      ],",
         "      exec_compatible_with=['//constraint:cv1', '//constraint:cv2'],",
         "    ),",
         "  },",
         ")");
     RuleClass plum = ((StarlarkRuleFunction) ev.lookup("plum")).getRuleClass();
     assertThat(plum.getToolchainTypes()).isEmpty();
-    // TODO(https://github.com/bazelbuild/bazel/issues/14726): Add tests of optional toolchains.
-    assertThat(plum.getExecGroups().get("group")).hasToolchainType("//test:my_toolchain_type");
-    assertThat(plum.getExecGroups().get("group"))
-        .toolchainType("//test:my_toolchain_type")
-        .isMandatory();
+    ExecGroup execGroup = plum.getExecGroups().get("group");
+    assertThat(execGroup).hasToolchainType("//test:my_toolchain_type1");
+    assertThat(execGroup).toolchainType("//test:my_toolchain_type1").isMandatory();
+    assertThat(execGroup).hasToolchainType("//test:my_toolchain_type2");
+    assertThat(execGroup).toolchainType("//test:my_toolchain_type2").isMandatory();
+    assertThat(execGroup).hasToolchainType("//test:my_toolchain_type3");
+    assertThat(execGroup).toolchainType("//test:my_toolchain_type3").isOptional();
+    assertThat(execGroup).hasToolchainType("//test:my_toolchain_type4");
+    assertThat(execGroup).toolchainType("//test:my_toolchain_type4").isMandatory();
+
     assertThat(plum.getExecutionPlatformConstraints()).isEmpty();
-    assertThat(plum.getExecGroups().get("group")).hasExecCompatibleWith("//constraint:cv1");
-    assertThat(plum.getExecGroups().get("group")).hasExecCompatibleWith("//constraint:cv2");
+    assertThat(execGroup).hasExecCompatibleWith("//constraint:cv1");
+    assertThat(execGroup).hasExecCompatibleWith("//constraint:cv2");
   }
 
   @Test
@@ -2390,17 +2445,27 @@
 
   @Test
   public void testCreateExecGroup() throws Exception {
-    scratch.file("test/BUILD", "toolchain_type(name = 'my_toolchain_type')");
     evalAndExport(
         ev,
         "group = exec_group(",
-        "  toolchains=['//test:my_toolchain_type'],",
+        "  toolchains=[",
+        "    '//test:my_toolchain_type1',",
+        "    config_common.toolchain_type('//test:my_toolchain_type2'),",
+        "    config_common.toolchain_type('//test:my_toolchain_type3', mandatory=False),",
+        "    config_common.toolchain_type('//test:my_toolchain_type4', mandatory=True),",
+        "  ],",
         "  exec_compatible_with=['//constraint:cv1', '//constraint:cv2'],",
         ")");
     ExecGroup group = ((ExecGroup) ev.lookup("group"));
-    // TODO(https://github.com/bazelbuild/bazel/issues/14726): Add tests of optional toolchains.
-    assertThat(group).hasToolchainType("//test:my_toolchain_type");
-    assertThat(group).toolchainType("//test:my_toolchain_type").isMandatory();
+    assertThat(group).hasToolchainType("//test:my_toolchain_type1");
+    assertThat(group).toolchainType("//test:my_toolchain_type1").isMandatory();
+    assertThat(group).hasToolchainType("//test:my_toolchain_type2");
+    assertThat(group).toolchainType("//test:my_toolchain_type2").isMandatory();
+    assertThat(group).hasToolchainType("//test:my_toolchain_type3");
+    assertThat(group).toolchainType("//test:my_toolchain_type3").isOptional();
+    assertThat(group).hasToolchainType("//test:my_toolchain_type4");
+    assertThat(group).toolchainType("//test:my_toolchain_type4").isMandatory();
+
     assertThat(group).hasExecCompatibleWith("//constraint:cv1");
     assertThat(group).hasExecCompatibleWith("//constraint:cv2");
   }
diff --git a/src/test/java/com/google/devtools/build/lib/starlark/util/BUILD b/src/test/java/com/google/devtools/build/lib/starlark/util/BUILD
index 5a36fc3..017a907b 100644
--- a/src/test/java/com/google/devtools/build/lib/starlark/util/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/starlark/util/BUILD
@@ -24,6 +24,7 @@
         "//src/main/java/com/google/devtools/build/lib/packages",
         "//src/main/java/com/google/devtools/build/lib/packages/semantics",
         "//src/main/java/com/google/devtools/build/lib/pkgcache",
+        "//src/main/java/com/google/devtools/build/lib/rules/config",
         "//src/main/java/com/google/devtools/build/lib/rules/platform",
         "//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
         "//src/main/java/com/google/devtools/common/options",
diff --git a/src/test/java/com/google/devtools/build/lib/starlark/util/BazelEvaluationTestCase.java b/src/test/java/com/google/devtools/build/lib/starlark/util/BazelEvaluationTestCase.java
index 95b7cfb..361135a 100644
--- a/src/test/java/com/google/devtools/build/lib/starlark/util/BazelEvaluationTestCase.java
+++ b/src/test/java/com/google/devtools/build/lib/starlark/util/BazelEvaluationTestCase.java
@@ -29,6 +29,7 @@
 import com.google.devtools.build.lib.packages.BazelStarlarkContext;
 import com.google.devtools.build.lib.packages.SymbolGenerator;
 import com.google.devtools.build.lib.packages.semantics.BuildLanguageOptions;
+import com.google.devtools.build.lib.rules.config.ConfigStarlarkCommon;
 import com.google.devtools.build.lib.rules.platform.PlatformCommon;
 import com.google.devtools.build.lib.testutil.TestConstants;
 import com.google.devtools.common.options.Options;
@@ -138,6 +139,7 @@
   private static Object newModule(ImmutableMap.Builder<String, Object> predeclared) {
     StarlarkModules.addPredeclared(predeclared);
     predeclared.put("platform_common", new PlatformCommon());
+    predeclared.put("config_common", new ConfigStarlarkCommon());
 
     // Return the module's client data. (This one uses dummy values for tests.)
     return BazelModuleContext.create(