[7.0.0] Give `WORKSPACE` toolchains and platforms precedence over non-root modules (#20430)

RELNOTES[INC]: Toolchains and execution platforms are now registered in
the following order with `--enable_bzlmod`:
1. root module's module file
2. `WORKSPACE` or `WORKSPACE.bzlmod`
3. non-root modules' module files
4. default toolchains registered by Bazel (skipped when using
`WORKSPACE.bzlmod`; never applies to execution platforms)

Fixes #20354

Closes #20407.

Commit
https://github.com/bazelbuild/bazel/commit/96b361205ee05dcacdcf5055ca9cc3e5ca5d126c#diff-a8d3aed419e661d4dbecb2dc6668444212d7b1707ff61330b7d8aae61e75d4df

PiperOrigin-RevId: 587826082
Change-Id: Ia98da6ef07b2fbf589ef369d986af2323af6f72a

Co-authored-by: Fabian Meumertzheim <fabian@meumertzhe.im>
Co-authored-by: Xudong Yang <wyv@google.com>
diff --git a/site/en/external/migration.md b/site/en/external/migration.md
index efd1cd1..c4eb0c4 100644
--- a/site/en/external/migration.md
+++ b/site/en/external/migration.md
@@ -453,6 +453,19 @@
     register_toolchains("@local_config_sh//:local_sh_toolchain")
     ```
 
+The toolchains and execution platforms registered in `WORKSPACE`,
+`WORKSPACE.bzlmod` and each Bazel module's `MODULE.bazel` file follow this
+order of precedence during toolchain selection (from highest to lowest):
+
+1. toolchains and execution platforms registered in the root module's
+   `MODULE.bazel` file.
+2. toolchains and execution platforms registered in the `WORKSPACE` or
+   `WORKSPACE.bzlmod` file.
+3. toolchains and execution platforms registered by modules that are
+   (transitive) dependencies of the root module.
+4. when not using `WORKSPACE.bzlmod`: toolchains registered in the `WORKSPACE`
+   [suffix](/external/migration#builtin-default-deps).
+
 [register_execution_platforms]: /rules/lib/globals/module#register_execution_platforms
 
 ### Introduce local repositories {:#introduce-local-deps}
diff --git a/src/main/java/com/google/devtools/build/lib/packages/Package.java b/src/main/java/com/google/devtools/build/lib/packages/Package.java
index 742c237..09dfa24 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/Package.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/Package.java
@@ -72,6 +72,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.TreeMap;
 import java.util.concurrent.Semaphore;
@@ -239,7 +240,7 @@
 
   private ImmutableList<TargetPattern> registeredExecutionPlatforms;
   private ImmutableList<TargetPattern> registeredToolchains;
-
+  private OptionalInt firstWorkspaceSuffixRegisteredToolchain;
   private long computationSteps;
 
   // These two fields are mutually exclusive. Which one is set depends on
@@ -402,6 +403,7 @@
     this.failureDetail = builder.getFailureDetail();
     this.registeredExecutionPlatforms = ImmutableList.copyOf(builder.registeredExecutionPlatforms);
     this.registeredToolchains = ImmutableList.copyOf(builder.registeredToolchains);
+    this.firstWorkspaceSuffixRegisteredToolchain = builder.firstWorkspaceSuffixRegisteredToolchain;
     this.repositoryMapping = Preconditions.checkNotNull(builder.repositoryMapping);
     this.mainRepositoryMapping = Preconditions.checkNotNull(builder.mainRepositoryMapping);
     ImmutableMap.Builder<RepositoryName, ImmutableMap<String, RepositoryName>>
@@ -722,6 +724,23 @@
     return registeredToolchains;
   }
 
+  public ImmutableList<TargetPattern> getUserRegisteredToolchains() {
+    return getRegisteredToolchains()
+        .subList(
+            0, firstWorkspaceSuffixRegisteredToolchain.orElse(getRegisteredToolchains().size()));
+  }
+
+  public ImmutableList<TargetPattern> getWorkspaceSuffixRegisteredToolchains() {
+    return getRegisteredToolchains()
+        .subList(
+            firstWorkspaceSuffixRegisteredToolchain.orElse(getRegisteredToolchains().size()),
+            getRegisteredToolchains().size());
+  }
+
+  OptionalInt getFirstWorkspaceSuffixRegisteredToolchain() {
+    return firstWorkspaceSuffixRegisteredToolchain;
+  }
+
   @Override
   public String toString() {
     return "Package("
@@ -936,8 +955,16 @@
     private final List<TargetPattern> registeredToolchains = new ArrayList<>();
 
     /**
-     * True iff the "package" function has already been called in this package.
+     * Tracks the index within {@link #registeredToolchains} of the first toolchain registered from
+     * the WORKSPACE suffixes rather than the WORKSPACE file (if any).
+     *
+     * <p>This is needed to distinguish between these toolchains during resolution: toolchains
+     * registered in WORKSPACE have precedence over those defined in non-root Bazel modules, which
+     * in turn have precedence over those from the WORKSPACE suffixes.
      */
+    private OptionalInt firstWorkspaceSuffixRegisteredToolchain = OptionalInt.empty();
+
+    /** True iff the "package" function has already been called in this package. */
     private boolean packageFunctionUsed;
 
     /**
@@ -1620,10 +1647,18 @@
       this.registeredExecutionPlatforms.addAll(platforms);
     }
 
-    void addRegisteredToolchains(List<TargetPattern> toolchains) {
+    void addRegisteredToolchains(List<TargetPattern> toolchains, boolean forWorkspaceSuffix) {
+      if (forWorkspaceSuffix && firstWorkspaceSuffixRegisteredToolchain.isEmpty()) {
+        firstWorkspaceSuffixRegisteredToolchain = OptionalInt.of(registeredToolchains.size());
+      }
       this.registeredToolchains.addAll(toolchains);
     }
 
+    void setFirstWorkspaceSuffixRegisteredToolchain(
+        OptionalInt firstWorkspaceSuffixRegisteredToolchain) {
+      this.firstWorkspaceSuffixRegisteredToolchain = firstWorkspaceSuffixRegisteredToolchain;
+    }
+
     @CanIgnoreReturnValue
     private Builder beforeBuild(boolean discoverAssumedInputFiles) throws NoSuchPackageException {
       Preconditions.checkNotNull(pkg);
diff --git a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java
index d590119..095fd30 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactory.java
@@ -185,7 +185,10 @@
       builder.setFailureDetailOverride(aPackage.getFailureDetail());
     }
     builder.addRegisteredExecutionPlatforms(aPackage.getRegisteredExecutionPlatforms());
-    builder.addRegisteredToolchains(aPackage.getRegisteredToolchains());
+    builder.addRegisteredToolchains(
+        aPackage.getRegisteredToolchains(), /* forWorkspaceSuffix= */ false);
+    builder.setFirstWorkspaceSuffixRegisteredToolchain(
+        aPackage.getFirstWorkspaceSuffixRegisteredToolchain());
     builder.addRepositoryMappings(aPackage);
     for (Rule rule : aPackage.getTargets(Rule.class)) {
       try {
diff --git a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactoryHelper.java b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactoryHelper.java
index 1f514a4..0f28474 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactoryHelper.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceFactoryHelper.java
@@ -36,6 +36,13 @@
 /** A helper for the {@link WorkspaceFactory} to create repository rules */
 public final class WorkspaceFactoryHelper {
 
+  public static final String DEFAULT_WORKSPACE_SUFFIX_FILE = "/DEFAULT.WORKSPACE.SUFFIX";
+
+  public static boolean originatesInWorkspaceSuffix(
+      ImmutableList<StarlarkThread.CallStackEntry> callstack) {
+    return callstack.get(0).location.file().equals(DEFAULT_WORKSPACE_SUFFIX_FILE);
+  }
+
   @CanIgnoreReturnValue
   public static Rule createAndAddRepositoryRule(
       Package.Builder pkg,
@@ -70,7 +77,7 @@
         throw new LabelSyntaxException(e.getMessage());
       }
     }
-    pkg.addRegisteredToolchains(toolchains.build());
+    pkg.addRegisteredToolchains(toolchains.build(), originatesInWorkspaceSuffix(callstack));
     return rule;
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceGlobals.java b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceGlobals.java
index 6125563..1ef44d4 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/WorkspaceGlobals.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/WorkspaceGlobals.java
@@ -14,6 +14,7 @@
 
 package com.google.devtools.build.lib.packages;
 
+import static com.google.devtools.build.lib.packages.WorkspaceFactoryHelper.originatesInWorkspaceSuffix;
 import static net.starlark.java.eval.Starlark.NONE;
 
 import com.google.common.collect.ImmutableList;
@@ -114,7 +115,9 @@
     // Add to the package definition for later.
     Package.Builder builder = PackageFactory.getContext(thread).pkgBuilder;
     List<String> patterns = Sequence.cast(toolchainLabels, String.class, "toolchain_labels");
-    builder.addRegisteredToolchains(parsePatterns(patterns, builder, thread));
+    builder.addRegisteredToolchains(
+        parsePatterns(patterns, builder, thread),
+        originatesInWorkspaceSuffix(thread.getCallStack()));
   }
 
   @Override
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/WorkspaceFileFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/WorkspaceFileFunction.java
index 8ae2ad8..6d61865 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/WorkspaceFileFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/WorkspaceFileFunction.java
@@ -14,6 +14,7 @@
 
 package com.google.devtools.build.lib.skyframe;
 
+import static com.google.devtools.build.lib.packages.WorkspaceFactoryHelper.DEFAULT_WORKSPACE_SUFFIX_FILE;
 import static com.google.devtools.build.lib.rules.repository.ResolvedFileValue.ATTRIBUTES;
 import static com.google.devtools.build.lib.rules.repository.ResolvedFileValue.NATIVE;
 import static com.google.devtools.build.lib.rules.repository.ResolvedFileValue.REPOSITORIES;
@@ -220,7 +221,7 @@
       StarlarkFile file =
           StarlarkFile.parse(
               ParserInput.fromString(
-                  ruleClassProvider.getDefaultWorkspaceSuffix(), "/DEFAULT.WORKSPACE.SUFFIX"),
+                  ruleClassProvider.getDefaultWorkspaceSuffix(), DEFAULT_WORKSPACE_SUFFIX_FILE),
               // The DEFAULT.WORKSPACE.SUFFIX file breaks through the usual privacy mechanism.
               options.toBuilder().allowLoadPrivateSymbols(true).build());
       if (!file.ok()) {
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD
index 15c3c70..e08f783 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD
@@ -90,6 +90,7 @@
         "//src/main/java/com/google/devtools/build/lib/analysis:platform_configuration",
         "//src/main/java/com/google/devtools/build/lib/analysis/platform",
         "//src/main/java/com/google/devtools/build/lib/analysis/platform:utils",
+        "//src/main/java/com/google/devtools/build/lib/bazel/bzlmod:common",
         "//src/main/java/com/google/devtools/build/lib/bazel/bzlmod:resolution",
         "//src/main/java/com/google/devtools/build/lib/cmdline",
         "//src/main/java/com/google/devtools/build/lib/packages",
@@ -160,6 +161,7 @@
         "//src/main/java/com/google/devtools/build/lib/analysis:platform_configuration",
         "//src/main/java/com/google/devtools/build/lib/analysis/platform",
         "//src/main/java/com/google/devtools/build/lib/analysis/platform:utils",
+        "//src/main/java/com/google/devtools/build/lib/bazel/bzlmod:common",
         "//src/main/java/com/google/devtools/build/lib/bazel/bzlmod:exception",
         "//src/main/java/com/google/devtools/build/lib/bazel/bzlmod:resolution",
         "//src/main/java/com/google/devtools/build/lib/cmdline",
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunction.java
index 4fb87c2..671f9f9 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunction.java
@@ -26,6 +26,7 @@
 import com.google.devtools.build.lib.analysis.platform.PlatformProviderUtils;
 import com.google.devtools.build.lib.bazel.bzlmod.BazelDepGraphValue;
 import com.google.devtools.build.lib.bazel.bzlmod.Module;
+import com.google.devtools.build.lib.bazel.bzlmod.ModuleKey;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.cmdline.LabelConstants;
 import com.google.devtools.build.lib.cmdline.RepositoryName;
@@ -104,21 +105,31 @@
       }
     }
 
-    // Get registered execution platforms from bzlmod.
-    ImmutableList<TargetPattern> bzlmodExecutionPlatforms =
-        getBzlmodExecutionPlatforms(starlarkSemantics, env);
-    if (bzlmodExecutionPlatforms == null) {
+    // Get registered execution platforms from the root Bazel module.
+    ImmutableList<TargetPattern> bzlmodRootModuleExecutionPlatforms =
+        getBzlmodExecutionPlatforms(starlarkSemantics, env, /* forRootModule= */ true);
+    if (bzlmodRootModuleExecutionPlatforms == null) {
       return null;
     }
-    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodExecutionPlatforms));
+    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodRootModuleExecutionPlatforms));
 
     // Get the registered execution platforms from the WORKSPACE.
+    // The WORKSPACE suffixes don't register any execution platforms, so we can register all
+    // platforms in WORKSPACE before those in non-root Bazel modules.
     ImmutableList<TargetPattern> workspaceExecutionPlatforms = getWorkspaceExecutionPlatforms(env);
     if (workspaceExecutionPlatforms == null) {
       return null;
     }
     targetPatternBuilder.addAll(TargetPatternUtil.toSigned(workspaceExecutionPlatforms));
 
+    // Get registered execution platforms from the non-root Bazel modules.
+    ImmutableList<TargetPattern> bzlmodNonRootModuleExecutionPlatforms =
+        getBzlmodExecutionPlatforms(starlarkSemantics, env, /* forRootModule= */ false);
+    if (bzlmodNonRootModuleExecutionPlatforms == null) {
+      return null;
+    }
+    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodNonRootModuleExecutionPlatforms));
+
     // Expand target patterns.
     ImmutableList<Label> platformLabels;
     try {
@@ -164,7 +175,7 @@
 
   @Nullable
   private static ImmutableList<TargetPattern> getBzlmodExecutionPlatforms(
-      StarlarkSemantics semantics, Environment env)
+      StarlarkSemantics semantics, Environment env, boolean forRootModule)
       throws InterruptedException, RegisteredExecutionPlatformsFunctionException {
     if (!semantics.getBool(BuildLanguageOptions.ENABLE_BZLMOD)) {
       return ImmutableList.of();
@@ -176,6 +187,9 @@
     }
     ImmutableList.Builder<TargetPattern> executionPlatforms = ImmutableList.builder();
     for (Module module : bazelDepGraphValue.getDepGraph().values()) {
+      if (forRootModule != module.getKey().equals(ModuleKey.ROOT)) {
+        continue;
+      }
       TargetPattern.Parser parser =
           new TargetPattern.Parser(
               PathFragment.EMPTY_FRAGMENT,
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunction.java
index cec918f..c01b3e6 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunction.java
@@ -28,6 +28,7 @@
 import com.google.devtools.build.lib.bazel.bzlmod.BazelDepGraphValue;
 import com.google.devtools.build.lib.bazel.bzlmod.ExternalDepsException;
 import com.google.devtools.build.lib.bazel.bzlmod.Module;
+import com.google.devtools.build.lib.bazel.bzlmod.ModuleKey;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.cmdline.LabelConstants;
 import com.google.devtools.build.lib.cmdline.RepositoryName;
@@ -95,19 +96,37 @@
           new InvalidToolchainLabelException(e), Transience.PERSISTENT);
     }
 
-    // Get registered toolchains from bzlmod.
-    ImmutableList<TargetPattern> bzlmodToolchains = getBzlmodToolchains(starlarkSemantics, env);
-    if (bzlmodToolchains == null) {
+    // Get registered toolchains from the root Bazel module.
+    ImmutableList<TargetPattern> bzlmodRootModuleToolchains =
+        getBzlmodToolchains(starlarkSemantics, env, /* forRootModule= */ true);
+    if (bzlmodRootModuleToolchains == null) {
       return null;
     }
-    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodToolchains));
+    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodRootModuleToolchains));
 
-    // Get the registered toolchains from the WORKSPACE.
-    ImmutableList<TargetPattern> workspaceToolchains = getWorkspaceToolchains(env);
-    if (workspaceToolchains == null) {
+    // Get the toolchains from the user-supplied WORKSPACE file.
+    ImmutableList<TargetPattern> userRegisteredWorkspaceToolchains =
+        getWorkspaceToolchains(env, /* userRegistered= */ true);
+    if (userRegisteredWorkspaceToolchains == null) {
       return null;
     }
-    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(workspaceToolchains));
+    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(userRegisteredWorkspaceToolchains));
+
+    // Get registered toolchains from non-root Bazel modules.
+    ImmutableList<TargetPattern> bzlmodNonRootModuleToolchains =
+        getBzlmodToolchains(starlarkSemantics, env, /* forRootModule= */ false);
+    if (bzlmodNonRootModuleToolchains == null) {
+      return null;
+    }
+    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodNonRootModuleToolchains));
+
+    // Get the toolchains from the Bazel-supplied WORKSPACE suffix.
+    ImmutableList<TargetPattern> workspaceSuffixToolchains =
+        getWorkspaceToolchains(env, /* userRegistered= */ false);
+    if (workspaceSuffixToolchains == null) {
+      return null;
+    }
+    targetPatternBuilder.addAll(TargetPatternUtil.toSigned(workspaceSuffixToolchains));
 
     // Expand target patterns.
     ImmutableList<Label> toolchainLabels;
@@ -140,8 +159,8 @@
    */
   @Nullable
   @VisibleForTesting
-  public static ImmutableList<TargetPattern> getWorkspaceToolchains(Environment env)
-      throws InterruptedException {
+  public static ImmutableList<TargetPattern> getWorkspaceToolchains(
+      Environment env, boolean userRegistered) throws InterruptedException {
     PackageValue externalPackageValue =
         (PackageValue) env.getValue(LabelConstants.EXTERNAL_PACKAGE_IDENTIFIER);
     if (externalPackageValue == null) {
@@ -149,12 +168,16 @@
     }
 
     Package externalPackage = externalPackageValue.getPackage();
-    return externalPackage.getRegisteredToolchains();
+    if (userRegistered) {
+      return externalPackage.getUserRegisteredToolchains();
+    } else {
+      return externalPackage.getWorkspaceSuffixRegisteredToolchains();
+    }
   }
 
   @Nullable
   private static ImmutableList<TargetPattern> getBzlmodToolchains(
-      StarlarkSemantics semantics, Environment env)
+      StarlarkSemantics semantics, Environment env, boolean forRootModule)
       throws InterruptedException, RegisteredToolchainsFunctionException {
     if (!semantics.getBool(BuildLanguageOptions.ENABLE_BZLMOD)) {
       return ImmutableList.of();
@@ -166,6 +189,9 @@
     }
     ImmutableList.Builder<TargetPattern> toolchains = ImmutableList.builder();
     for (Module module : bazelDepGraphValue.getDepGraph().values()) {
+      if (forRootModule != module.getKey().equals(ModuleKey.ROOT)) {
+        continue;
+      }
       TargetPattern.Parser parser =
           new TargetPattern.Parser(
               PathFragment.EMPTY_FRAGMENT,
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java b/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java
index bdcf861..ba0fbd9 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java
@@ -501,7 +501,8 @@
         "def rules_java_dependencies():",
         "    pass",
         "def rules_java_toolchains():",
-        "    pass");
+        "    native.register_toolchains('//java/toolchains/runtime:all')",
+        "    native.register_toolchains('//java/toolchains/javac:all')");
 
     config.create(
         "rules_java_workspace/java/toolchains/runtime/BUILD",
diff --git a/src/test/java/com/google/devtools/build/lib/packages/util/LoadingMock.java b/src/test/java/com/google/devtools/build/lib/packages/util/LoadingMock.java
index 9cc9715..bc49f4c 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/util/LoadingMock.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/util/LoadingMock.java
@@ -39,4 +39,8 @@
   public ConfiguredRuleClassProvider createRuleClassProvider() {
     return TestRuleClassProvider.getRuleClassProviderWithClearedSuffix();
   }
+
+  public ConfiguredRuleClassProvider createRuleClassProviderWithSuffix() {
+    return TestRuleClassProvider.getRuleClassProvider();
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/repository/ExternalPackageHelperTest.java b/src/test/java/com/google/devtools/build/lib/repository/ExternalPackageHelperTest.java
index 3e6388a..289d61b 100644
--- a/src/test/java/com/google/devtools/build/lib/repository/ExternalPackageHelperTest.java
+++ b/src/test/java/com/google/devtools/build/lib/repository/ExternalPackageHelperTest.java
@@ -82,6 +82,7 @@
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Stream;
 import javax.annotation.Nullable;
 import org.junit.Before;
 import org.junit.Test;
@@ -411,13 +412,19 @@
     @Override
     public SkyValue compute(SkyKey skyKey, Environment env)
         throws SkyFunctionException, InterruptedException {
-      List<TargetPattern> registeredToolchains =
-          RegisteredToolchainsFunction.getWorkspaceToolchains(env);
-      if (registeredToolchains == null) {
+      ImmutableList<TargetPattern> userRegisteredToolchains =
+          RegisteredToolchainsFunction.getWorkspaceToolchains(env, /* userRegistered= */ true);
+      if (userRegisteredToolchains == null) {
+        return null;
+      }
+      ImmutableList<TargetPattern> workspaceSuffixRegisteredToolchains =
+          RegisteredToolchainsFunction.getWorkspaceToolchains(env, /* userRegistered= */ false);
+      if (workspaceSuffixRegisteredToolchains == null) {
         return null;
       }
       return GetRegisteredToolchainsValue.create(
-          registeredToolchains.stream()
+          Stream.concat(
+                  userRegisteredToolchains.stream(), workspaceSuffixRegisteredToolchains.stream())
               .map(TargetPattern::getOriginalPattern)
               .collect(toImmutableList()));
     }
diff --git a/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java b/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
index 78e06b2..3e21c1e 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
@@ -224,7 +224,8 @@
     scratch.file(
         "toolchain/BUILD",
         "toolchain_type(name = 'test_toolchain')",
-        "toolchain_type(name = 'optional_toolchain')");
+        "toolchain_type(name = 'optional_toolchain')",
+        "toolchain_type(name = 'workspace_suffix_toolchain')");
 
     testToolchainTypeLabel = Label.parseCanonicalUnchecked("//toolchain:test_toolchain");
     testToolchainType = ToolchainTypeRequirement.create(testToolchainTypeLabel);
@@ -247,6 +248,22 @@
         ImmutableList.of("//constraints:mac"),
         ImmutableList.of("//constraints:linux"),
         "bar");
+    Label suffixToolchainTypeLabel =
+        Label.parseCanonicalUnchecked("//toolchain:workspace_suffix_toolchain");
+    addToolchain(
+        "toolchain",
+        "suffix_toolchain_1",
+        suffixToolchainTypeLabel,
+        ImmutableList.of(),
+        ImmutableList.of(),
+        "suffix1");
+    addToolchain(
+        "toolchain",
+        "suffix_toolchain_2",
+        suffixToolchainTypeLabel,
+        ImmutableList.of(),
+        ImmutableList.of(),
+        "suffix2");
   }
 
   protected EvaluationResult<RegisteredToolchainsValue> requestToolchainsFromSkyframe(
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD
index 943c9dd..08252c3 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/BUILD
@@ -66,14 +66,10 @@
     deps = [
         "//src/main/java/com/google/devtools/build/lib/analysis:view_creation_failed_exception",
         "//src/main/java/com/google/devtools/build/lib/analysis/platform",
-        "//src/main/java/com/google/devtools/build/lib/bazel/bzlmod:resolution_impl",
-        "//src/main/java/com/google/devtools/build/lib/bazel/repository:repository_options",
         "//src/main/java/com/google/devtools/build/lib/cmdline",
         "//src/main/java/com/google/devtools/build/lib/skyframe:configured_target_key",
-        "//src/main/java/com/google/devtools/build/lib/skyframe:precomputed_value",
         "//src/main/java/com/google/devtools/build/lib/skyframe/toolchains:platform_lookup_util",
         "//src/main/java/com/google/devtools/build/lib/skyframe/toolchains:registered_execution_platforms_value",
-        "//src/main/java/com/google/devtools/build/lib/vfs",
         "//src/main/java/com/google/devtools/build/skyframe",
         "//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
         "//src/test/java/com/google/devtools/build/lib/bazel/bzlmod:util",
@@ -92,17 +88,16 @@
     name = "RegisteredToolchainsFunctionTest",
     srcs = ["RegisteredToolchainsFunctionTest.java"],
     deps = [
+        "//src/main/java/com/google/devtools/build/lib/analysis:analysis_cluster",
         "//src/main/java/com/google/devtools/build/lib/analysis/platform",
-        "//src/main/java/com/google/devtools/build/lib/bazel/bzlmod:resolution_impl",
-        "//src/main/java/com/google/devtools/build/lib/bazel/repository:repository_options",
         "//src/main/java/com/google/devtools/build/lib/cmdline",
-        "//src/main/java/com/google/devtools/build/lib/skyframe:precomputed_value",
         "//src/main/java/com/google/devtools/build/lib/skyframe/toolchains:registered_toolchains_value",
         "//src/main/java/com/google/devtools/build/lib/vfs",
         "//src/main/java/com/google/devtools/build/skyframe",
         "//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
         "//src/test/java/com/google/devtools/build/lib/bazel/bzlmod:util",
         "//src/test/java/com/google/devtools/build/lib/rules/platform:testutil",
+        "//src/test/java/com/google/devtools/build/lib/testutil",
         "//src/test/java/com/google/devtools/build/skyframe:testutil",
         "//third_party:guava",
         "//third_party:guava-testlib",
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunctionTest.java
index 5d9688f..91b3981 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredExecutionPlatformsFunctionTest.java
@@ -339,8 +339,13 @@
           "platform(name='plat')");
     }
     scratch.overwriteFile(
-        "BUILD", "platform(name='plat')", "platform(name='dev_plat')", "platform(name='wsplat')");
-    rewriteWorkspace("register_execution_platforms('//:wsplat')");
+        "BUILD",
+        "platform(name='plat')",
+        "platform(name='dev_plat')",
+        "platform(name='wsplat')",
+        "platform(name='wsplat2')");
+    rewriteWorkspace(
+        "register_execution_platforms('//:wsplat')", "register_execution_platforms('//:wsplat2')");
 
     SkyKey executionPlatformsKey = RegisteredExecutionPlatformsValue.key(targetConfigKey);
     EvaluationResult<RegisteredExecutionPlatformsValue> result =
@@ -354,13 +359,17 @@
     // WORKSPACE registrations.
     assertExecutionPlatformLabels(result.get(executionPlatformsKey))
         .containsExactly(
+            // Root module platforms
             Label.parseCanonical("//:plat"),
             Label.parseCanonical("//:dev_plat"),
+            // WORKSPACE platforms
+            Label.parseCanonical("//:wsplat"),
+            Label.parseCanonical("//:wsplat2"),
+            // Other modules' platforms
             Label.parseCanonical("@@bbb~1.0//:plat"),
             Label.parseCanonical("@@ccc~1.1//:plat"),
             Label.parseCanonical("@@eee~1.0//:plat"),
-            Label.parseCanonical("@@ddd~1.1//:plat"),
-            Label.parseCanonical("//:wsplat"))
+            Label.parseCanonical("@@ddd~1.1//:plat"))
         .inOrder();
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunctionTest.java
index 0892a30..15b934e 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/toolchains/RegisteredToolchainsFunctionTest.java
@@ -20,15 +20,18 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.testing.EqualsTester;
+import com.google.devtools.build.lib.analysis.ConfiguredRuleClassProvider;
 import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo;
 import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.cmdline.PackageIdentifier;
 import com.google.devtools.build.lib.rules.platform.ToolchainTestCase;
+import com.google.devtools.build.lib.testutil.TestRuleClassProvider;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.skyframe.EvaluationResult;
 import com.google.devtools.build.skyframe.SkyKey;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -37,6 +40,17 @@
 @RunWith(JUnit4.class)
 public class RegisteredToolchainsFunctionTest extends ToolchainTestCase {
 
+  @Override
+  protected ConfiguredRuleClassProvider createRuleClassProvider() {
+    // testRegisteredToolchains_bzlmod uses the WORKSPACE suffixes.
+    ConfiguredRuleClassProvider.Builder builder = new ConfiguredRuleClassProvider.Builder();
+    TestRuleClassProvider.addStandardRules(builder);
+    builder.clearWorkspaceFileSuffixForTesting();
+    builder.addWorkspaceFileSuffix(
+        "register_toolchains('//toolchain:suffix_toolchain_1', '//toolchain:suffix_toolchain_2')");
+    return builder.build();
+  }
+
   @Test
   public void testRegisteredToolchains() throws Exception {
     // Request the toolchains.
@@ -423,8 +437,24 @@
         "load('@toolchain_def//:toolchain_def.bzl', 'declare_toolchain')",
         "declare_toolchain(name='dev_tool')",
         "declare_toolchain(name='tool')",
-        "declare_toolchain(name='wstool')");
-    rewriteWorkspace("register_toolchains('//:wstool')");
+        "declare_toolchain(name='wstool')",
+        "declare_toolchain(name='wstool2')");
+    scratch.overwriteFile(
+        "WORKSPACE",
+        Stream.concat(
+                analysisMock.getWorkspaceContents(mockToolsConfig).stream()
+                    // The register_toolchains calls usually live in the WORKSPACE suffixes.
+                    // BazelAnalysisMock moves the mock registrations to the actual WORKSPACE file
+                    // as most Java tests don't run with the suffixes. This test class does, so we
+                    // skip over the "unnatural" registrations.
+                    .filter(line -> !line.startsWith("register_toolchains(")),
+                // Register a toolchain explicitly that is also registered in the WORKSPACE suffix.
+                Stream.of(
+                    "register_toolchains('//:wstool')",
+                    "register_toolchains('//toolchain:suffix_toolchain_2')",
+                    "register_toolchains('//:wstool2')"))
+            .toArray(String[]::new));
+    invalidatePackages();
 
     SkyKey toolchainsKey = RegisteredToolchainsValue.key(targetConfigKey);
     EvaluationResult<RegisteredToolchainsValue> result =
@@ -438,13 +468,20 @@
     // registrations.
     assertToolchainLabels(result.get(toolchainsKey))
         .containsAtLeast(
+            // Root module toolchains
             Label.parseCanonical("//:tool_impl"),
             Label.parseCanonical("//:dev_tool_impl"),
+            // WORKSPACE toolchains
+            Label.parseCanonical("//:wstool_impl"),
+            Label.parseCanonical("//toolchain:suffix_toolchain_2_impl"),
+            Label.parseCanonical("//:wstool2_impl"),
+            // Other modules' toolchains
             Label.parseCanonical("@@bbb~1.0//:tool_impl"),
             Label.parseCanonical("@@ccc~1.1//:tool_impl"),
             Label.parseCanonical("@@eee~1.0//:tool_impl"),
             Label.parseCanonical("@@ddd~1.1//:tool_impl"),
-            Label.parseCanonical("//:wstool_impl"))
+            // WORKSPACE suffix toolchains
+            Label.parseCanonical("//toolchain:suffix_toolchain_1_impl"))
         .inOrder();
   }