Update ToolchainResolutionFunction to handle optional toolchain types

Part of Optional Toolchains (#14726).

Closes #15357.

PiperOrigin-RevId: 445399501
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
index e2cdeef..43b52ee 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
@@ -161,6 +161,10 @@
       return new AutoValue_ToolchainResolutionFunction_ToolchainType(
           toolchainTypeRequirement, toolchainTypeInfo);
     }
+
+    public boolean mandatory() {
+      return toolchainTypeRequirement().mandatory();
+    }
   }
 
   /**
@@ -428,44 +432,40 @@
     // Determine the potential set of toolchains.
     Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains =
         HashBasedTable.create();
-    ImmutableSet.Builder<ToolchainTypeInfo> requiredToolchainTypesBuilder = ImmutableSet.builder();
-    List<Label> missingToolchains = new ArrayList<>();
+    List<Label> missingMandatoryToolchains = new ArrayList<>();
     for (SingleToolchainResolutionKey key : registeredToolchainKeys) {
       SingleToolchainResolutionValue singleToolchainResolutionValue =
           (SingleToolchainResolutionValue)
               results.getOrThrow(key, InvalidToolchainLabelException.class);
-        if (singleToolchainResolutionValue == null) {
-          valuesMissing = true;
-          continue;
-        }
+      if (singleToolchainResolutionValue == null) {
+        valuesMissing = true;
+        continue;
+      }
 
-      if (singleToolchainResolutionValue.availableToolchainLabels().isEmpty()) {
-        // Save the missing type and continue looping to check for more.
-        // TODO(katre): Handle mandatory/optional.
-        missingToolchains.add(key.toolchainType().toolchainType());
-      } else {
+      if (!singleToolchainResolutionValue.availableToolchainLabels().isEmpty()) {
         ToolchainTypeInfo requiredToolchainType = singleToolchainResolutionValue.toolchainType();
-        requiredToolchainTypesBuilder.add(requiredToolchainType);
         resolvedToolchains.putAll(
             findPlatformsAndLabels(requiredToolchainType, singleToolchainResolutionValue));
+      } else if (key.toolchainType().mandatory()) {
+        // Save the missing type and continue looping to check for more.
+        missingMandatoryToolchains.add(key.toolchainType().toolchainType());
       }
+      // TODO(katre): track missing optional toolchains?
     }
 
-    if (!missingToolchains.isEmpty()) {
-      throw new UnresolvedToolchainsException(missingToolchains);
+    // Verify that all mandatory toolchain types have a toolchain.
+    if (!missingMandatoryToolchains.isEmpty()) {
+      throw new UnresolvedToolchainsException(missingMandatoryToolchains);
     }
 
     if (valuesMissing) {
       throw new ValueMissingException();
     }
 
-    ImmutableSet<ToolchainTypeInfo> requiredToolchainTypes = requiredToolchainTypesBuilder.build();
-
-    // Find and return the first execution platform which has all required toolchains.
-    // TODO(katre): Handle mandatory/optional.
+    // Find and return the first execution platform which has all mandatory toolchains.
     Optional<ConfiguredTargetKey> selectedExecutionPlatformKey =
         findExecutionPlatformForToolchains(
-            requiredToolchainTypes,
+            toolchainTypes,
             forcedExecutionPlatform,
             platformKeys.executionPlatformKeys(),
             resolvedToolchains);
@@ -520,43 +520,43 @@
    * resolvedToolchains} and has all required toolchain types.
    */
   private static Optional<ConfiguredTargetKey> findExecutionPlatformForToolchains(
-      ImmutableSet<ToolchainTypeInfo> requiredToolchainTypes,
+      ImmutableSet<ToolchainType> toolchainTypes,
       Optional<ConfiguredTargetKey> forcedExecutionPlatform,
       ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys,
       Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains) {
 
     if (forcedExecutionPlatform.isPresent()) {
       // Is the forced platform suitable?
-      if (isPlatformSuitable(
-          forcedExecutionPlatform.get(), requiredToolchainTypes, resolvedToolchains)) {
+      if (isPlatformSuitable(forcedExecutionPlatform.get(), toolchainTypes, resolvedToolchains)) {
         return forcedExecutionPlatform;
       }
     }
 
+    // Choose the first execution platform that has all mandatory toolchains.
     return availableExecutionPlatformKeys.stream()
-        .filter(epk -> isPlatformSuitable(epk, requiredToolchainTypes, resolvedToolchains))
+        .filter(epk -> isPlatformSuitable(epk, toolchainTypes, resolvedToolchains))
         .findFirst();
   }
 
   private static boolean isPlatformSuitable(
       ConfiguredTargetKey executionPlatformKey,
-      ImmutableSet<ToolchainTypeInfo> requiredToolchainTypes,
+      ImmutableSet<ToolchainType> toolchainTypes,
       Table<ConfiguredTargetKey, ToolchainTypeInfo, Label> resolvedToolchains) {
-    if (requiredToolchainTypes.isEmpty()) {
+    if (toolchainTypes.isEmpty()) {
       // Since there aren't any toolchains, we should be able to use any execution platform that
       // has made it this far.
       return true;
     }
 
-    if (!resolvedToolchains.containsRow(executionPlatformKey)) {
-      return false;
-    }
-
-    // Unless all toolchains are present, ignore this execution platform.
+    // Determine whether all mandatory toolchains are present.
     return resolvedToolchains
         .row(executionPlatformKey)
         .keySet()
-        .containsAll(requiredToolchainTypes);
+        .containsAll(
+            toolchainTypes.stream()
+                .filter(ToolchainType::mandatory)
+                .map(ToolchainType::toolchainTypeInfo)
+                .collect(toImmutableSet()));
   }
 
   private static final class ValueMissingException extends Exception {
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
index e6b7981..6a84433 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
@@ -19,6 +19,7 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.analysis.config.ToolchainTypeRequirement;
+import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.rules.platform.ToolchainTestCase;
 import com.google.devtools.build.lib.skyframe.ConstraintValueLookupUtil.InvalidConstraintValueException;
@@ -86,6 +87,92 @@
     assertThat(unloadedToolchainContext).hasTargetPlatform("//platforms:linux");
   }
 
+  // TODO(katre): Add further tests for optional/mandatory/mixed toolchains.
+
+  @Test
+  public void resolve_optional() throws Exception {
+    // This should select platform mac, toolchain extra_toolchain_mac, because platform
+    // mac is listed first.
+    addToolchain(
+        "extra",
+        "extra_toolchain_linux",
+        ImmutableList.of("//constraints:linux"),
+        ImmutableList.of("//constraints:linux"),
+        "baz");
+    addToolchain(
+        "extra",
+        "extra_toolchain_mac",
+        ImmutableList.of("//constraints:mac"),
+        ImmutableList.of("//constraints:linux"),
+        "baz");
+    rewriteWorkspace(
+        "register_toolchains('//extra:extra_toolchain_linux', '//extra:extra_toolchain_mac')",
+        "register_execution_platforms('//platforms:mac', '//platforms:linux')");
+
+    useConfiguration("--platforms=//platforms:linux");
+    ToolchainContextKey key =
+        ToolchainContextKey.key()
+            .configurationKey(targetConfigKey)
+            .toolchainTypes(testToolchainType)
+            .build();
+
+    EvaluationResult<UnloadedToolchainContext> result = invokeToolchainResolution(key);
+
+    assertThatEvaluationResult(result).hasNoError();
+    UnloadedToolchainContext unloadedToolchainContext = result.get(key);
+    assertThat(unloadedToolchainContext).isNotNull();
+
+    assertThat(unloadedToolchainContext).hasToolchainType(testToolchainTypeLabel);
+    assertThat(unloadedToolchainContext).hasResolvedToolchain("//extra:extra_toolchain_mac_impl");
+    assertThat(unloadedToolchainContext).hasExecutionPlatform("//platforms:mac");
+    assertThat(unloadedToolchainContext).hasTargetPlatform("//platforms:linux");
+  }
+
+  @Test
+  public void resolve_multiple() throws Exception {
+    Label secondToolchainTypeLabel = Label.parseAbsoluteUnchecked("//second:toolchain_type");
+    ToolchainTypeRequirement secondToolchainTypeRequirement =
+        ToolchainTypeRequirement.create(secondToolchainTypeLabel);
+    ToolchainTypeInfo secondToolchainTypeInfo = ToolchainTypeInfo.create(secondToolchainTypeLabel);
+    scratch.file("second/BUILD", "toolchain_type(name = 'toolchain_type')");
+
+    addToolchain(
+        "main",
+        "main_toolchain_linux",
+        ImmutableList.of("//constraints:linux"),
+        ImmutableList.of("//constraints:linux"),
+        "baz");
+    addToolchain(
+        "main",
+        "second_toolchain_linux",
+        secondToolchainTypeLabel,
+        ImmutableList.of("//constraints:linux"),
+        ImmutableList.of("//constraints:linux"),
+        "baz");
+    rewriteWorkspace(
+        "register_toolchains('//main:all',)", "register_execution_platforms('//platforms:linux')");
+
+    useConfiguration("--platforms=//platforms:linux");
+    ToolchainContextKey key =
+        ToolchainContextKey.key()
+            .configurationKey(targetConfigKey)
+            .toolchainTypes(testToolchainType, secondToolchainTypeRequirement)
+            .build();
+
+    EvaluationResult<UnloadedToolchainContext> result = invokeToolchainResolution(key);
+
+    assertThatEvaluationResult(result).hasNoError();
+    UnloadedToolchainContext unloadedToolchainContext = result.get(key);
+    assertThat(unloadedToolchainContext).isNotNull();
+
+    assertThat(unloadedToolchainContext).hasToolchainType(testToolchainTypeLabel);
+    assertThat(unloadedToolchainContext).hasResolvedToolchain("//main:main_toolchain_linux_impl");
+    assertThat(unloadedToolchainContext).hasToolchainType(secondToolchainTypeLabel);
+    assertThat(unloadedToolchainContext).hasResolvedToolchain("//main:second_toolchain_linux_impl");
+    assertThat(unloadedToolchainContext).hasExecutionPlatform("//platforms:linux");
+    assertThat(unloadedToolchainContext).hasTargetPlatform("//platforms:linux");
+  }
+
   @Test
   public void resolve_mandatory_missing() throws Exception {
     // There is no toolchain for the requested type.
@@ -106,6 +193,91 @@
   }
 
   @Test
+  public void resolve_multiple_optional() throws Exception {
+    Label secondToolchainTypeLabel = Label.parseAbsoluteUnchecked("//second:toolchain_type");
+    ToolchainTypeRequirement secondToolchainTypeRequirement =
+        ToolchainTypeRequirement.builder(secondToolchainTypeLabel).mandatory(false).build();
+    ToolchainTypeInfo secondToolchainTypeInfo = ToolchainTypeInfo.create(secondToolchainTypeLabel);
+    scratch.file("second/BUILD", "toolchain_type(name = 'toolchain_type')");
+
+    addToolchain(
+        "main",
+        "main_toolchain_linux",
+        ImmutableList.of("//constraints:linux"),
+        ImmutableList.of("//constraints:linux"),
+        "baz");
+    addToolchain(
+        "main",
+        "second_toolchain_linux",
+        secondToolchainTypeLabel,
+        ImmutableList.of("//constraints:linux"),
+        ImmutableList.of("//constraints:linux"),
+        "baz");
+    rewriteWorkspace(
+        "register_toolchains('//main:all',)", "register_execution_platforms('//platforms:linux')");
+
+    useConfiguration("--platforms=//platforms:linux");
+    ToolchainContextKey key =
+        ToolchainContextKey.key()
+            .configurationKey(targetConfigKey)
+            .toolchainTypes(testToolchainType, secondToolchainTypeRequirement)
+            .build();
+
+    EvaluationResult<UnloadedToolchainContext> result = invokeToolchainResolution(key);
+
+    assertThatEvaluationResult(result).hasNoError();
+    UnloadedToolchainContext unloadedToolchainContext = result.get(key);
+    assertThat(unloadedToolchainContext).isNotNull();
+
+    assertThat(unloadedToolchainContext).hasToolchainType(testToolchainTypeLabel);
+    assertThat(unloadedToolchainContext).hasResolvedToolchain("//main:main_toolchain_linux_impl");
+    assertThat(unloadedToolchainContext).hasToolchainType(secondToolchainTypeLabel);
+    assertThat(unloadedToolchainContext).hasResolvedToolchain("//main:second_toolchain_linux_impl");
+    assertThat(unloadedToolchainContext).hasExecutionPlatform("//platforms:linux");
+    assertThat(unloadedToolchainContext).hasTargetPlatform("//platforms:linux");
+  }
+
+  @Test
+  public void resolve_multiple_optional_missing() throws Exception {
+    Label secondToolchainTypeLabel = Label.parseAbsoluteUnchecked("//second:toolchain_type");
+    ToolchainTypeRequirement secondToolchainTypeRequirement =
+        ToolchainTypeRequirement.builder(secondToolchainTypeLabel).mandatory(false).build();
+    ToolchainTypeInfo secondToolchainTypeInfo = ToolchainTypeInfo.create(secondToolchainTypeLabel);
+    scratch.file("second/BUILD", "toolchain_type(name = 'toolchain_type')");
+
+    addToolchain(
+        "main",
+        "main_toolchain_linux",
+        ImmutableList.of("//constraints:linux"),
+        ImmutableList.of("//constraints:linux"),
+        "baz");
+    rewriteWorkspace(
+        "register_toolchains('//main:all',)", "register_execution_platforms('//platforms:linux')");
+
+    useConfiguration("--platforms=//platforms:linux");
+    ToolchainContextKey key =
+        ToolchainContextKey.key()
+            .configurationKey(targetConfigKey)
+            .toolchainTypes(testToolchainType, secondToolchainTypeRequirement)
+            .build();
+
+    EvaluationResult<UnloadedToolchainContext> result = invokeToolchainResolution(key);
+
+    assertThatEvaluationResult(result).hasNoError();
+    UnloadedToolchainContext unloadedToolchainContext = result.get(key);
+    assertThat(unloadedToolchainContext).isNotNull();
+
+    assertThat(unloadedToolchainContext).hasToolchainType(testToolchainTypeLabel);
+    assertThat(unloadedToolchainContext).hasResolvedToolchain("//main:main_toolchain_linux_impl");
+    assertThat(unloadedToolchainContext).hasToolchainType(secondToolchainTypeLabel);
+    assertThat(unloadedToolchainContext)
+        .resolvedToolchainLabels()
+        .doesNotContain(Label.parseAbsoluteUnchecked("//main:second_toolchain_linux_impl"));
+    assertThat(unloadedToolchainContext).hasExecutionPlatform("//platforms:linux");
+    assertThat(unloadedToolchainContext).hasTargetPlatform("//platforms:linux");
+  }
+
+  @Test
   public void resolve_toolchainTypeAlias() throws Exception {
     addToolchain(
         "extra",
@@ -221,6 +393,29 @@
   }
 
   @Test
+  public void resolve_optional_unavailableToolchainType_single() throws Exception {
+    reporter.removeHandler(failFastHandler);
+    scratch.file("fake/toolchain/BUILD", "");
+    useConfiguration("--host_platform=//platforms:linux", "--platforms=//platforms:linux");
+    ToolchainContextKey key =
+        ToolchainContextKey.key()
+            .configurationKey(targetConfigKey)
+            .toolchainTypes(optionalToolchainType)
+            .build();
+
+    EvaluationResult<UnloadedToolchainContext> result = invokeToolchainResolution(key);
+
+    assertThatEvaluationResult(result).hasNoError();
+    UnloadedToolchainContext unloadedToolchainContext = result.get(key);
+    assertThat(unloadedToolchainContext).isNotNull();
+
+    assertThat(unloadedToolchainContext).hasToolchainType(optionalToolchainTypeLabel);
+    assertThat(unloadedToolchainContext).resolvedToolchainLabels().isEmpty();
+    assertThat(unloadedToolchainContext).hasExecutionPlatform("//platforms:linux");
+    assertThat(unloadedToolchainContext).hasTargetPlatform("//platforms:linux");
+  }
+
+  @Test
   public void resolve_unavailableToolchainType_multiple() throws Exception {
     reporter.removeHandler(failFastHandler);
     scratch.file("fake/toolchain/BUILD", "");
diff --git a/src/test/shell/integration/toolchain_test.sh b/src/test/shell/integration/toolchain_test.sh
index 3585045..24f3307 100755
--- a/src/test/shell/integration/toolchain_test.sh
+++ b/src/test/shell/integration/toolchain_test.sh
@@ -150,6 +150,8 @@
 function write_register_toolchain() {
   local pkg="${1}"
   local toolchain_name="${2:-test_toolchain}"
+  local exec_compatible_with="${3:-"[]"}"
+  local target_compatible_with="${4:-"[]"}"
 
   cat >> WORKSPACE <<EOF
 register_toolchains('//register/${pkg}:${toolchain_name}_1')
@@ -171,8 +173,8 @@
 toolchain(
     name = '${toolchain_name}_1',
     toolchain_type = '//${pkg}/toolchain:${toolchain_name}',
-    exec_compatible_with = [],
-    target_compatible_with = [],
+    exec_compatible_with = ${exec_compatible_with},
+    target_compatible_with = ${target_compatible_with},
     toolchain = ':${toolchain_name}_impl_1',
     visibility = ['//visibility:public'])
 EOF
@@ -484,6 +486,50 @@
   expect_log 'Using toolchain: rule message: "this is the rule", toolchain 1 extra_str: "foo from test_toolchain_1", toolchain 2 extra_str: "foo from test_toolchain_2"'
 }
 
+function test_multiple_toolchain_use_in_rule_with_optional_missing {
+  local -r pkg="${FUNCNAME[0]}"
+  write_test_toolchain "${pkg}" test_toolchain_1
+  write_test_toolchain "${pkg}" test_toolchain_2
+
+  write_register_toolchain "${pkg}" test_toolchain_1
+
+  # The rule uses two separate toolchains.
+  mkdir -p "${pkg}/toolchain"
+  cat > "${pkg}/toolchain/rule_use_toolchains.bzl" <<EOF
+def _impl(ctx):
+  toolchain_1 = ctx.toolchains['//${pkg}/toolchain:test_toolchain_1']
+  toolchain_2 = ctx.toolchains['//${pkg}/toolchain:test_toolchain_2']
+  message = ctx.attr.message
+  print(
+      'Using toolchain: rule message: "%s", toolchain 1 extra_str: "%s", toolchain 2 is none: %s' %
+         (message, toolchain_1.extra_str, toolchain_2 == None))
+  return []
+
+use_toolchains = rule(
+    implementation = _impl,
+    attrs = {
+        'message': attr.string(),
+    },
+    toolchains = [
+        '//${pkg}/toolchain:test_toolchain_1',
+        config_common.toolchain_type('//${pkg}/toolchain:test_toolchain_2', mandatory = False),
+    ],
+)
+EOF
+
+  mkdir -p "${pkg}/demo"
+  cat > "${pkg}/demo/BUILD" <<EOF
+load('//${pkg}/toolchain:rule_use_toolchains.bzl', 'use_toolchains')
+# Use the toolchain.
+use_toolchains(
+    name = 'use',
+    message = 'this is the rule')
+EOF
+
+  bazel build "//${pkg}/demo:use" &> $TEST_log || fail "Build failed"
+  expect_log 'Using toolchain: rule message: "this is the rule", toolchain 1 extra_str: "foo from test_toolchain_1", toolchain 2 is none: True'
+}
+
 function test_multiple_toolchain_use_in_rule_one_missing {
   local -r pkg="${FUNCNAME[0]}"
   write_test_toolchain "${pkg}" test_toolchain_1
@@ -2377,6 +2423,146 @@
   expect_log "foo_tool = <target @rules_foo//foo_tools:foo_tool>"
 }
 
+function test_exec_platform_order_with_mandatory_toolchains {
+  local -r pkg="${FUNCNAME[0]}"
+
+  # Add two possible execution platforms.
+  mkdir -p "${pkg}/platforms"
+  cat > "${pkg}/platforms/BUILD" <<EOF
+package(default_visibility = ['//visibility:public'])
+constraint_setting(name = 'setting')
+constraint_value(name = 'value1', constraint_setting = ':setting')
+constraint_value(name = 'value2', constraint_setting = ':setting')
+
+platform(
+    name = 'platform1',
+    constraint_values = [':value1'],
+    visibility = ['//visibility:public'])
+platform(
+    name = 'platform2',
+    constraint_values = [':value2'],
+    visibility = ['//visibility:public'])
+EOF
+  # Register them in order.
+  cat >> WORKSPACE <<EOF
+register_execution_platforms("//${pkg}/platforms:platform1", "//${pkg}/platforms:platform2")
+EOF
+
+  # Create a toolchain that only works with platform2
+  write_test_toolchain "${pkg}" test_toolchain
+  write_register_toolchain "${pkg}" test_toolchain "['//${pkg}/platforms:value2']"
+
+  # The rule must receive the toolchain.
+  mkdir -p "${pkg}/toolchain"
+  cat > "${pkg}/toolchain/rule_use_toolchains.bzl" <<EOF
+def _impl(ctx):
+  toolchain = ctx.toolchains['//${pkg}/toolchain:test_toolchain']
+  message = ctx.attr.message
+  print(
+      'Using toolchain: rule message: "%s", toolchain is none: %s' %
+         (message, toolchain == None))
+  return []
+
+use_toolchains = rule(
+    implementation = _impl,
+    attrs = {
+        'message': attr.string(),
+    },
+    toolchains = [
+        config_common.toolchain_type('//${pkg}/toolchain:test_toolchain', mandatory = True),
+    ],
+)
+EOF
+
+  mkdir -p "${pkg}/demo"
+  cat > "${pkg}/demo/BUILD" <<EOF
+load('//${pkg}/toolchain:rule_use_toolchains.bzl', 'use_toolchains')
+# Use the toolchain.
+use_toolchains(
+    name = 'use',
+    message = 'this is the rule')
+EOF
+
+  bazel build "//${pkg}/demo:use" &> $TEST_log || fail "Build failed"
+  bazel build \
+    --toolchain_resolution_debug=.* \
+    "//${pkg}/demo:use" &> $TEST_log || fail "Build failed"
+  # Verify that a toolchain was provided.
+  expect_log 'Using toolchain: rule message: "this is the rule", toolchain is none: False'
+  # Verify that the exec platform is platform2.
+  expect_log "Selected execution platform //${pkg}/platforms:platform2"
+}
+
+function test_exec_platform_order_with_optional_toolchains {
+  local -r pkg="${FUNCNAME[0]}"
+
+  # Add two possible execution platforms.
+  mkdir -p "${pkg}/platforms"
+  cat > "${pkg}/platforms/BUILD" <<EOF
+package(default_visibility = ['//visibility:public'])
+constraint_setting(name = 'setting')
+constraint_value(name = 'value1', constraint_setting = ':setting')
+constraint_value(name = 'value2', constraint_setting = ':setting')
+
+platform(
+    name = 'platform1',
+    constraint_values = [':value1'],
+    visibility = ['//visibility:public'])
+platform(
+    name = 'platform2',
+    constraint_values = [':value2'],
+    visibility = ['//visibility:public'])
+EOF
+  # Register them in order.
+  cat >> WORKSPACE <<EOF
+register_execution_platforms("//${pkg}/platforms:platform1", "//${pkg}/platforms:platform2")
+EOF
+
+  # Create a toolchain that only works with platform2
+  write_test_toolchain "${pkg}" test_toolchain
+  write_register_toolchain "${pkg}" test_toolchain "['//${pkg}/platforms:value2']"
+
+  # The rule can optionally use the toolchain.
+  mkdir -p "${pkg}/toolchain"
+  cat > "${pkg}/toolchain/rule_use_toolchains.bzl" <<EOF
+def _impl(ctx):
+  toolchain = ctx.toolchains['//${pkg}/toolchain:test_toolchain']
+  message = ctx.attr.message
+  print(
+      'Using toolchain: rule message: "%s", toolchain is none: %s' %
+         (message, toolchain == None))
+  return []
+
+use_toolchains = rule(
+    implementation = _impl,
+    attrs = {
+        'message': attr.string(),
+    },
+    toolchains = [
+        config_common.toolchain_type('//${pkg}/toolchain:test_toolchain', mandatory = False),
+    ],
+)
+EOF
+
+  mkdir -p "${pkg}/demo"
+  cat > "${pkg}/demo/BUILD" <<EOF
+load('//${pkg}/toolchain:rule_use_toolchains.bzl', 'use_toolchains')
+# Use the toolchain.
+use_toolchains(
+    name = 'use',
+    message = 'this is the rule')
+EOF
+
+  bazel build "//${pkg}/demo:use" &> $TEST_log || fail "Build failed"
+  bazel build \
+    --toolchain_resolution_debug=.* \
+    "//${pkg}/demo:use" &> $TEST_log || fail "Build failed"
+  # Verify that no toolchain was provided.
+  expect_log 'Using toolchain: rule message: "this is the rule", toolchain is none: True'
+  # Verify that the exec platform is platform1.
+  expect_log "Selected execution platform //${pkg}/platforms:platform1"
+}
+
 # TODO(katre): Test using toolchain-provided make variables from a genrule.
 
 run_suite "toolchain tests"