[6.5.0] Proto toolchainisation cherrypicks (#20925)

Fixes: https://github.com/bazelbuild/bazel/issues/20921

List of cherrypicks:
5a24c8a3c6 ilist@google.com Mon Oct 9 02:27:10 2023 -0700 Support
automatic exec groups in proto_common.compile
3e1e06164f ilist@google.com Tue Oct 3 03:32:01 2023 -0700 Use proto
toolchains in cc_proto_library
8c38be3230 ilist@google.com Mon Oct 2 07:08:55 2023 -0700 Use proto
toolchains in java_lite_proto_library
51970d25d9 ilist@google.com Mon Oct 2 07:05:16 2023 -0700 Use proto
toolchains in java_proto_library
20bc11facc ilist@google.com Mon Oct 2 06:56:47 2023 -0700 Decouple
java_lite_proto_library from java_proto_library
3b18d3fbe8 ilist@google.com Mon Oct 2 04:11:24 2023 -0700 Refactor proto
toolchainsation support utilities
42800a8292 ilist@google.com Fri Sep 22 03:04:18 2023 -0700 Use proto
compiler from proto_toolchain rule
d435c6dd6e ilist@google.com Thu Sep 21 07:05:40 2023 -0700 Use
proto_toolchain in proto_library
f5fb2f6a5f ilist@google.com Wed Sep 20 05:50:59 2023 -0700 Remove protoc
from proto_lang_toolchain rule
108ef553d7 ilist@google.com Tue Sep 19 08:40:00 2023 -0700 Use
MockProtoSupport.setup where protos are used
11cf1b756b ilist@google.com Sun Sep 17 21:57:54 2023 -0700 Implement
incompatible_enable_proto_toolchain_resolution
diff --git a/src/main/java/com/google/devtools/build/lib/packages/semantics/BuildLanguageOptions.java b/src/main/java/com/google/devtools/build/lib/packages/semantics/BuildLanguageOptions.java
index 4f9d088..f75c7bf 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/semantics/BuildLanguageOptions.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/semantics/BuildLanguageOptions.java
@@ -696,6 +696,18 @@
               + " Label.relative) can be used.")
   public boolean enableDeprecatedLabelApis;
 
+  // Flip when dependencies to rules_* repos are upgraded and protobuf registers toolchains
+  @Option(
+      name = "incompatible_enable_proto_toolchain_resolution",
+      defaultValue = "false",
+      documentationCategory = OptionDocumentationCategory.TOOLCHAIN,
+      effectTags = {OptionEffectTag.LOADING_AND_ANALYSIS},
+      metadataTags = {OptionMetadataTag.INCOMPATIBLE_CHANGE},
+      help =
+          "If true, proto lang rules define toolchains from rules_proto, rules_java, rules_cc"
+              + " repositories.")
+  public boolean incompatibleEnableProtoToolchainResolution;
+
   /**
    * An interner to reduce the number of StarlarkSemantics instances. A single Blaze instance should
    * never accumulate a large number of these and being able to shortcut on object identity makes a
@@ -795,6 +807,9 @@
                 INCOMPATIBLE_DISABLE_OBJC_LIBRARY_TRANSITION,
                 incompatibleDisableObjcLibraryTransition)
             .setBool(INCOMPATIBLE_ENABLE_DEPRECATED_LABEL_APIS, enableDeprecatedLabelApis)
+            .setBool(
+                INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION,
+                incompatibleEnableProtoToolchainResolution)
             .build();
     return INTERNER.intern(semantics);
   }
@@ -891,6 +906,8 @@
       "-incompatible_disable_objc_library_transition";
   public static final String INCOMPATIBLE_ENABLE_DEPRECATED_LABEL_APIS =
       "+incompatible_enable_deprecated_label_apis";
+  public static final String INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION =
+      "-incompatible_enable_proto_toolchain_resolution";
 
   // non-booleans
   public static final StarlarkSemantics.Key<String> EXPERIMENTAL_BUILTINS_BZL_PATH =
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/BUILD b/src/main/java/com/google/devtools/build/lib/rules/proto/BUILD
index af13958..d79f083 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/BUILD
@@ -40,6 +40,7 @@
         "//src/main/java/com/google/devtools/build/lib/collect/nestedset",
         "//src/main/java/com/google/devtools/build/lib/concurrent",
         "//src/main/java/com/google/devtools/build/lib/packages",
+        "//src/main/java/com/google/devtools/build/lib/packages/semantics",
         "//src/main/java/com/google/devtools/build/lib/starlarkbuildapi",
         "//src/main/java/com/google/devtools/build/lib/starlarkbuildapi/proto",
         "//src/main/java/com/google/devtools/build/lib/util:filetype",
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java b/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java
index e413803..77ac10a 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java
@@ -22,6 +22,10 @@
 import net.starlark.java.annot.StarlarkMethod;
 import net.starlark.java.eval.EvalException;
 import net.starlark.java.eval.StarlarkList;
+import com.google.common.collect.ImmutableSet;
+import com.google.devtools.build.lib.packages.BuiltinRestriction;
+import com.google.devtools.build.lib.packages.semantics.BuildLanguageOptions;
+import com.google.devtools.build.lib.starlarkbuildapi.proto.ProtoCommonApi;
 import net.starlark.java.eval.StarlarkThread;
 
 /** Protocol buffers support for Starlark. */
@@ -87,4 +91,20 @@
         Depset.cast(transitiveDescriptorSets, Artifact.class, "transitive_descriptor_set"),
         Depset.cast(exportedSources, ProtoSource.class, "exported_sources"));
   }
+
+  /**
+   * Returns the value of {@code --incompatible_enable_proto_toolchain_resolution}.
+   *
+   * <p>Restricted by {@code ProtoCommon.checkPrivateStarlarkificationAllowlist}.
+   */
+  @StarlarkMethod(
+      name = "incompatible_enable_proto_toolchain_resolution",
+      useStarlarkThread = true,
+      documented = false)
+  public boolean getDefineProtoToolchains(StarlarkThread thread) throws EvalException {
+    ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
+    return thread
+        .getSemantics()
+        .getBool(BuildLanguageOptions.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION);
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java
index c9d77f3..9582ed9 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java
@@ -31,6 +31,7 @@
 import com.google.devtools.common.options.OptionEffectTag;
 import com.google.devtools.common.options.OptionMetadataTag;
 import java.util.List;
+import javax.annotation.Nullable;
 import net.starlark.java.annot.StarlarkMethod;
 import net.starlark.java.eval.EvalException;
 import net.starlark.java.eval.StarlarkThread;
@@ -65,13 +66,12 @@
     public List<String> protocOpts;
 
     @Option(
-      name = "experimental_proto_extra_actions",
-      defaultValue = "false",
-      documentationCategory = OptionDocumentationCategory.OUTPUT_SELECTION,
-      effectTags = {OptionEffectTag.AFFECTS_OUTPUTS, OptionEffectTag.LOADING_AND_ANALYSIS},
-      metadataTags = {OptionMetadataTag.EXPERIMENTAL},
-      help = "Run extra actions for alternative Java api versions in a proto_library."
-    )
+        name = "experimental_proto_extra_actions",
+        defaultValue = "false",
+        documentationCategory = OptionDocumentationCategory.OUTPUT_SELECTION,
+        effectTags = {OptionEffectTag.AFFECTS_OUTPUTS, OptionEffectTag.LOADING_AND_ANALYSIS},
+        metadataTags = {OptionMetadataTag.EXPERIMENTAL},
+        help = "Run extra actions for alternative Java api versions in a proto_library.")
     public boolean experimentalProtoExtraActions;
 
     @Option(
@@ -240,8 +240,8 @@
   }
 
   /**
-   * Returns true if we will run extra actions for actions that are not run by default. If this
-   * is enabled, e.g. all extra_actions for alternative api-versions or language-flavours of a
+   * Returns true if we will run extra actions for actions that are not run by default. If this is
+   * enabled, e.g. all extra_actions for alternative api-versions or language-flavours of a
    * proto_library target are run.
    */
   public boolean runExperimentalProtoExtraActions() {
@@ -252,6 +252,7 @@
       name = "proto_compiler",
       doc = "Label for the proto compiler.",
       defaultLabel = ProtoConstants.DEFAULT_PROTOC_LABEL)
+  @Nullable
   public Label protoCompiler() {
     return options.protoCompiler;
   }
@@ -260,10 +261,12 @@
       name = "proto_toolchain_for_java",
       doc = "Label for the java proto toolchains.",
       defaultLabel = ProtoConstants.DEFAULT_JAVA_PROTO_LABEL)
+  @Nullable
   public Label protoToolchainForJava() {
     return options.protoToolchainForJava;
   }
 
+  @Nullable
   public Label protoToolchainForJ2objc() {
     return options.protoToolchainForJ2objc;
   }
@@ -272,6 +275,7 @@
       name = "proto_toolchain_for_java_lite",
       doc = "Label for the java lite proto toolchains.",
       defaultLabel = ProtoConstants.DEFAULT_JAVA_LITE_PROTO_LABEL)
+  @Nullable
   public Label protoToolchainForJavaLite() {
     return options.protoToolchainForJavaLite;
   }
@@ -280,6 +284,7 @@
       name = "proto_toolchain_for_cc",
       doc = "Label for the cc proto toolchains.",
       defaultLabel = ProtoConstants.DEFAULT_CC_PROTO_LABEL)
+  @Nullable
   public Label protoToolchainForCc() {
     return options.protoToolchainForCc;
   }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java
index 80ee27b..e08307d 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConstants.java
@@ -17,7 +17,7 @@
 /** Constants used in Proto rules. */
 public final class ProtoConstants {
   /** Default label for proto compiler. */
-  static final String DEFAULT_PROTOC_LABEL =  "@bazel_tools//tools/proto:protoc";
+  public static final String DEFAULT_PROTOC_LABEL = "@bazel_tools//tools/proto:protoc";
 
   /** Default label for java proto toolchains. */
   static final String DEFAULT_JAVA_PROTO_LABEL = "@bazel_tools//tools/proto:java_toolchain";
diff --git a/src/main/starlark/builtins_bzl/common/cc/cc_proto_library.bzl b/src/main/starlark/builtins_bzl/common/cc/cc_proto_library.bzl
index fa7f896..956b975 100644
--- a/src/main/starlark/builtins_bzl/common/cc/cc_proto_library.bzl
+++ b/src/main/starlark/builtins_bzl/common/cc/cc_proto_library.bzl
@@ -15,7 +15,8 @@
 """Starlark implementation of cc_proto_library"""
 
 load(":common/cc/cc_helper.bzl", "cc_helper")
-load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", proto_common = "proto_common_do_not_use")
+load(":common/proto/proto_common.bzl", "toolchains", "ProtoLangToolchainInfo", proto_common = "proto_common_do_not_use")
+load(":common/cc/semantics.bzl", "semantics")
 
 ProtoInfo = _builtins.toplevel.ProtoInfo
 CcInfo = _builtins.toplevel.CcInfo
@@ -24,15 +25,15 @@
 ProtoCcFilesInfo = provider(fields = ["files"], doc = "Provide cc proto files.")
 ProtoCcHeaderInfo = provider(fields = ["headers"], doc = "Provide cc proto headers.")
 
-def _are_srcs_excluded(ctx, target):
-    return not proto_common.experimental_should_generate_code(target[ProtoInfo], ctx.attr._aspect_cc_proto_toolchain[ProtoLangToolchainInfo], "cc_proto_library", target.label)
+def _are_srcs_excluded(proto_toolchain, target):
+    return not proto_common.experimental_should_generate_code(target[ProtoInfo], proto_toolchain, "cc_proto_library", target.label)
 
-def _get_feature_configuration(ctx, target, cc_toolchain, proto_info):
+def _get_feature_configuration(ctx, target, cc_toolchain, proto_info, proto_toolchain):
     requested_features = list(ctx.features)
     unsupported_features = list(ctx.disabled_features)
     unsupported_features.append("parse_headers")
     unsupported_features.append("layering_check")
-    if not _are_srcs_excluded(ctx, target) and len(proto_info.direct_sources) != 0:
+    if not _are_srcs_excluded(proto_toolchain, target) and len(proto_info.direct_sources) != 0:
         requested_features.append("header_modules")
     else:
         unsupported_features.append("header_modules")
@@ -78,12 +79,13 @@
 
 def _aspect_impl(target, ctx):
     cc_toolchain = cc_helper.find_cpp_toolchain(ctx)
+    proto_toolchain = toolchains.find_toolchain(ctx, "_aspect_cc_proto_toolchain", semantics.CC_PROTO_TOOLCHAIN)
     proto_info = target[ProtoInfo]
-    feature_configuration = _get_feature_configuration(ctx, target, cc_toolchain, proto_info)
+    feature_configuration = _get_feature_configuration(ctx, target, cc_toolchain, proto_info, proto_toolchain)
     proto_configuration = ctx.fragments.proto
 
     deps = []
-    runtime = ctx.attr._aspect_cc_proto_toolchain[ProtoLangToolchainInfo].runtime
+    runtime = proto_toolchain.runtime
     if runtime != None:
         deps.append(runtime)
     deps.extend(getattr(ctx.rule.attr, "deps", []))
@@ -100,7 +102,7 @@
     textual_hdrs = []
     additional_exported_hdrs = []
 
-    if _are_srcs_excluded(ctx, target):
+    if _are_srcs_excluded(proto_toolchain, target):
         header_provider = None
 
         # Hack: This is a proto_library for descriptor.proto or similar.
@@ -155,7 +157,7 @@
     proto_common.compile(
         actions = ctx.actions,
         proto_info = proto_info,
-        proto_lang_toolchain_info = ctx.attr._aspect_cc_proto_toolchain[ProtoLangToolchainInfo],
+        proto_lang_toolchain_info = proto_toolchain,
         generated_files = outputs,
         experimental_output_files = "multiple",
     )
@@ -250,12 +252,11 @@
     required_providers = [ProtoInfo],
     provides = [CcInfo],
     attrs = {
-        "_aspect_cc_proto_toolchain": attr.label(
-            default = configuration_field(fragment = "proto", name = "proto_toolchain_for_cc"),
-        ),
         "_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"),
-    },
-    toolchains = cc_helper.use_cpp_toolchain(),
+    } | toolchains.if_legacy_toolchain({"_aspect_cc_proto_toolchain": attr.label(
+        default = configuration_field(fragment = "proto", name = "proto_toolchain_for_cc"),
+    )}),
+    toolchains = cc_helper.use_cpp_toolchain() + toolchains.use_toolchain(semantics.CC_PROTO_TOOLCHAIN),
 )
 
 def _impl(ctx):
@@ -283,4 +284,5 @@
         ),
     },
     provides = [CcInfo],
+    toolchains = toolchains.use_toolchain(semantics.CC_PROTO_TOOLCHAIN),
 )
diff --git a/src/main/starlark/builtins_bzl/common/cc/semantics.bzl b/src/main/starlark/builtins_bzl/common/cc/semantics.bzl
index 14cb4f4..9f8e833 100644
--- a/src/main/starlark/builtins_bzl/common/cc/semantics.bzl
+++ b/src/main/starlark/builtins_bzl/common/cc/semantics.bzl
@@ -198,4 +198,5 @@
     get_coverage_env = _get_coverage_env,
     get_proto_aspects = _get_proto_aspects,
     incompatible_disable_objc_library_transition = _incompatible_disable_objc_library_transition,
+    CC_PROTO_TOOLCHAIN = "@rules_cc//cc/proto:toolchain_type",
 )
diff --git a/src/main/starlark/builtins_bzl/common/exports.bzl b/src/main/starlark/builtins_bzl/common/exports.bzl
index 42ab667..a2db972 100755
--- a/src/main/starlark/builtins_bzl/common/exports.bzl
+++ b/src/main/starlark/builtins_bzl/common/exports.bzl
@@ -24,7 +24,7 @@
 load("@_builtins//:common/objc/linking_support.bzl", "linking_support")
 load("@_builtins//:common/proto/proto_common.bzl", "proto_common_do_not_use")
 load("@_builtins//:common/proto/proto_library.bzl", "proto_library")
-load("@_builtins//:common/proto/proto_lang_toolchain_wrapper.bzl", "proto_lang_toolchain")
+load("@_builtins//:common/proto/proto_lang_toolchain.bzl", "proto_lang_toolchain")
 load("@_builtins//:common/python/py_runtime_macro.bzl", "py_runtime")
 load("@_builtins//:common/python/providers.bzl", "PyInfo", "PyRuntimeInfo")
 load("@_builtins//:common/java/proto/java_lite_proto_library.bzl", "java_lite_proto_library")
diff --git a/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl b/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl
index 923af60..f86cf31 100644
--- a/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl
+++ b/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl
@@ -73,4 +73,6 @@
     check_proto_registry_collision = _check_proto_registry_collision,
     get_coverage_runner = _get_coverage_runner,
     add_constraints = _add_constraints,
+    JAVA_PROTO_TOOLCHAIN = "@rules_java//java/proto:toolchain_type",
+    JAVA_LITE_PROTO_TOOLCHAIN = "@rules_java//java/proto:lite_toolchain_type",
 )
diff --git a/src/main/starlark/builtins_bzl/common/java/proto/java_lite_proto_library.bzl b/src/main/starlark/builtins_bzl/common/java/proto/java_lite_proto_library.bzl
index 56afa3f..6f08afd 100644
--- a/src/main/starlark/builtins_bzl/common/java/proto/java_lite_proto_library.bzl
+++ b/src/main/starlark/builtins_bzl/common/java/proto/java_lite_proto_library.bzl
@@ -15,8 +15,8 @@
 """A Starlark implementation of the java_lite_proto_library rule."""
 
 load(":common/java/java_semantics.bzl", "semantics")
-load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", proto_common = "proto_common_do_not_use")
-load(":common/java/proto/java_proto_library.bzl", "JavaProtoAspectInfo", "bazel_java_proto_library_rule", "java_compile_for_protos")
+load(":common/java/proto/java_proto_library.bzl", "JavaProtoAspectInfo", "java_compile_for_protos")
+load(":common/proto/proto_common.bzl", "toolchains", "ProtoLangToolchainInfo", proto_common = "proto_common_do_not_use")
 
 PROTO_TOOLCHAIN_ATTR = "_aspect_proto_toolchain_for_javalite"
 
@@ -40,7 +40,11 @@
 
     deps = [dep[JavaInfo] for dep in ctx.rule.attr.deps]
     exports = [exp[JavaInfo] for exp in ctx.rule.attr.exports]
-    proto_toolchain_info = ctx.attr._aspect_proto_toolchain_for_javalite[ProtoLangToolchainInfo]
+    proto_toolchain_info = toolchains.find_toolchain(
+        ctx,
+        "_aspect_proto_toolchain_for_javalite",
+        semantics.JAVA_LITE_PROTO_TOOLCHAIN,
+    )
     source_jar = None
 
     if proto_common.experimental_should_generate_code(target[ProtoInfo], proto_toolchain_info, "java_lite_proto_library", target.label):
@@ -74,15 +78,16 @@
 java_lite_proto_aspect = aspect(
     implementation = _aspect_impl,
     attr_aspects = ["deps", "exports"],
-    attrs = {
+    attrs = toolchains.if_legacy_toolchain({
         PROTO_TOOLCHAIN_ATTR: attr.label(
             default = configuration_field(fragment = "proto", name = "proto_toolchain_for_java_lite"),
         ),
-    },
+    }),
     fragments = ["java"],
     required_providers = [ProtoInfo],
     provides = [JavaInfo, JavaProtoAspectInfo],
-    toolchains = [semantics.JAVA_TOOLCHAIN],
+    toolchains = [semantics.JAVA_TOOLCHAIN] +
+                 toolchains.use_toolchain(semantics.JAVA_LITE_PROTO_TOOLCHAIN),
 )
 
 def _rule_impl(ctx):
@@ -98,7 +103,11 @@
       ([JavaInfo, DefaultInfo, OutputGroupInfo, ProguardSpecProvider])
     """
 
-    proto_toolchain_info = ctx.attr._aspect_proto_toolchain_for_javalite[ProtoLangToolchainInfo]
+    proto_toolchain_info = toolchains.find_toolchain(
+        ctx,
+        "_aspect_proto_toolchain_for_javalite",
+        semantics.JAVA_LITE_PROTO_TOOLCHAIN,
+    )
     runtime = proto_toolchain_info.runtime
 
     if runtime:
@@ -106,13 +115,20 @@
     else:
         proguard_provider_specs = ProguardSpecProvider(depset())
 
-    java_info, DefaultInfo, OutputGroupInfo = bazel_java_proto_library_rule(ctx)
+    java_info = java_common.merge([dep[JavaInfo] for dep in ctx.attr.deps], merge_java_outputs = False)
+
+    transitive_src_and_runtime_jars = depset(transitive = [dep[JavaProtoAspectInfo].jars for dep in ctx.attr.deps])
+    transitive_runtime_jars = depset(transitive = [java_info.transitive_runtime_jars])
+
     java_info = semantics.add_constraints(java_info, ["android"])
 
     return [
         java_info,
-        DefaultInfo,
-        OutputGroupInfo,
+        DefaultInfo(
+            files = transitive_src_and_runtime_jars,
+            runfiles = ctx.runfiles(transitive_files = transitive_runtime_jars),
+        ),
+        OutputGroupInfo(default = depset()),
         proguard_provider_specs,
     ]
 
@@ -120,9 +136,11 @@
     implementation = _rule_impl,
     attrs = {
         "deps": attr.label_list(providers = [ProtoInfo], aspects = [java_lite_proto_aspect]),
+    } | toolchains.if_legacy_toolchain({
         PROTO_TOOLCHAIN_ATTR: attr.label(
             default = configuration_field(fragment = "proto", name = "proto_toolchain_for_java_lite"),
         ),
-    },
+    }),
     provides = [JavaInfo],
+    toolchains = toolchains.use_toolchain(semantics.JAVA_LITE_PROTO_TOOLCHAIN),
 )
diff --git a/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl b/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl
index 09690f0..f2634ec 100644
--- a/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl
+++ b/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl
@@ -15,7 +15,7 @@
 """The implementation of the `java_proto_library` rule and its aspect."""
 
 load(":common/java/java_semantics.bzl", "semantics")
-load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", proto_common = "proto_common_do_not_use")
+load(":common/proto/proto_common.bzl", "toolchains", "ProtoLangToolchainInfo", proto_common = "proto_common_do_not_use")
 
 java_common = _builtins.toplevel.java_common
 JavaInfo = _builtins.toplevel.JavaInfo
@@ -47,7 +47,7 @@
       runtime jars.
     """
 
-    proto_toolchain_info = ctx.attr._aspect_java_proto_toolchain[ProtoLangToolchainInfo]
+    proto_toolchain_info = toolchains.find_toolchain(ctx, "_aspect_java_proto_toolchain", semantics.JAVA_PROTO_TOOLCHAIN)
     source_jar = None
     if proto_common.experimental_should_generate_code(target[ProtoInfo], proto_toolchain_info, "java_proto_library", target.label):
         # Generate source jar using proto compiler.
@@ -129,12 +129,12 @@
 
 bazel_java_proto_aspect = aspect(
     implementation = _bazel_java_proto_aspect_impl,
-    attrs = {
+    attrs = toolchains.if_legacy_toolchain({
         "_aspect_java_proto_toolchain": attr.label(
             default = configuration_field(fragment = "proto", name = "proto_toolchain_for_java"),
         ),
-    },
-    toolchains = [semantics.JAVA_TOOLCHAIN],
+    }),
+    toolchains = [semantics.JAVA_TOOLCHAIN] + toolchains.use_toolchain(semantics.JAVA_PROTO_TOOLCHAIN),
     attr_aspects = ["deps", "exports"],
     required_providers = [ProtoInfo],
     provides = [JavaInfo, JavaProtoAspectInfo],
@@ -149,7 +149,6 @@
     Returns:
       ([JavaInfo, DefaultInfo, OutputGroupInfo])
     """
-
     java_info = java_common.merge([dep[JavaInfo] for dep in ctx.attr.deps], merge_java_outputs = False)
 
     transitive_src_and_runtime_jars = depset(transitive = [dep[JavaProtoAspectInfo].jars for dep in ctx.attr.deps])
@@ -172,4 +171,5 @@
         "distribs": attr.string_list(),
     },
     provides = [JavaInfo],
+    toolchains = toolchains.use_toolchain(semantics.JAVA_PROTO_TOOLCHAIN),
 )
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
index 64036dd..062dad3 100644
--- a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
@@ -30,6 +30,7 @@
         protoc_opts = "(list[str]) Options to pass to proto compiler.",
         progress_message = "(str) Progress message to set on the proto compiler action.",
         mnemonic = "(str) Mnemonic to set on the proto compiler action.",
+        toolchain_type = """(Label) Toolchain type that was used to obtain this info""",
     ),
 )
 
@@ -154,6 +155,7 @@
         use_default_shell_env = True,
         resource_set = resource_set,
         exec_group = experimental_exec_group,
+        toolchain = getattr(proto_lang_toolchain_info, "toolchain_type", None),
     )
 
 _BAZEL_TOOLS_PREFIX = "external/bazel_tools/"
@@ -265,10 +267,70 @@
 
     return outputs
 
+# Evaluated once at load time. The same value drives rule/aspect definitions
+# (use_toolchain, if_legacy_toolchain) and analysis-time lookups (find_toolchain),
+# so the two cannot disagree about which toolchain mechanism is in effect.
+_INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution()
+
+def _find_toolchain(ctx, legacy_attr, toolchain_type):
+    """Returns the ProtoLangToolchainInfo a rule or aspect should use.
+
+    Args:
+      ctx: (RuleContext) The rule or aspect context.
+      legacy_attr: (str) Name of the label attribute carrying ProtoLangToolchainInfo;
+        read when toolchain resolution is disabled.
+      toolchain_type: (str) Toolchain type looked up when toolchain resolution is enabled.
+
+    Returns:
+      (ProtoLangToolchainInfo) The proto lang toolchain info.
+    """
+    if not _INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
+        return getattr(ctx.attr, legacy_attr)[ProtoLangToolchainInfo]
+    toolchain = ctx.toolchains[toolchain_type]
+    if not toolchain:
+        # _use_toolchain declares the type with mandatory = False, so it may be unresolved.
+        fail("No toolchains registered for '%s'." % toolchain_type)
+    return toolchain.proto
+
+def _use_toolchain(toolchain_type):
+    """Returns the `toolchains` entries to declare for a proto lang toolchain type.
+
+    Args:
+      toolchain_type: (str) The toolchain type label.
+
+    Returns:
+      (list) A single optional toolchain_type when toolchain resolution is enabled,
+      otherwise an empty list.
+    """
+    if _INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
+        return [_builtins.toplevel.config_common.toolchain_type(toolchain_type, mandatory = False)]
+    return []
+
+def _if_legacy_toolchain(legacy_attr_dict):
+    """Returns `legacy_attr_dict` only when toolchain resolution is disabled.
+
+    Args:
+      legacy_attr_dict: (dict) Attributes that point at legacy toolchain targets.
+
+    Returns:
+      (dict) `legacy_attr_dict`, or an empty dict when toolchain resolution is enabled.
+    """
+    if _INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
+        return {}
+    return legacy_attr_dict
+
+toolchains = struct(
+    use_toolchain = _use_toolchain,
+    find_toolchain = _find_toolchain,
+    if_legacy_toolchain = _if_legacy_toolchain,
+)
+
 proto_common_do_not_use = struct(
     compile = _compile,
     declare_generated_files = _declare_generated_files,
     experimental_should_generate_code = _experimental_should_generate_code,
     experimental_filter_sources = _experimental_filter_sources,
     ProtoLangToolchainInfo = ProtoLangToolchainInfo,
+    INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION,
+    INCOMPATIBLE_PASS_TOOLCHAIN_TYPE = True,
 )
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
index 3d2a9d0..6996369 100644
--- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
@@ -14,7 +14,7 @@
 
 """A Starlark implementation of the proto_lang_toolchain rule."""
 
-load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo")
+load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains", proto_common = "proto_common_do_not_use")
 load(":common/proto/proto_semantics.bzl", "semantics")
 
 ProtoInfo = _builtins.toplevel.ProtoInfo
@@ -31,61 +31,60 @@
     if ctx.attr.plugin != None:
         plugin = ctx.attr.plugin[DefaultInfo].files_to_run
 
-    proto_compiler = getattr(ctx.attr, "proto_compiler", None)
-    proto_compiler = getattr(ctx.attr, "_proto_compiler", proto_compiler)
+    if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
+        proto_compiler = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.proto_compiler
+        protoc_opts = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.protoc_opts
+    else:
+        proto_compiler = ctx.attr._proto_compiler.files_to_run
+        protoc_opts = ctx.fragments.proto.experimental_protoc_opts
 
+    proto_lang_toolchain_info = ProtoLangToolchainInfo(
+        out_replacement_format_flag = flag,
+        output_files = ctx.attr.output_files,
+        plugin_format_flag = ctx.attr.plugin_format_flag,
+        plugin = plugin,
+        runtime = ctx.attr.runtime,
+        provided_proto_sources = provided_proto_sources,
+        proto_compiler = proto_compiler,
+        protoc_opts = protoc_opts,
+        progress_message = ctx.attr.progress_message,
+        mnemonic = ctx.attr.mnemonic,
+        toolchain_type = ctx.attr.toolchain_type.label if ctx.attr.toolchain_type else None,
+    )
     return [
-        DefaultInfo(
-            files = depset(),
-            runfiles = ctx.runfiles(),
-        ),
-        ProtoLangToolchainInfo(
-            out_replacement_format_flag = flag,
-            output_files = ctx.attr.output_files,
-            plugin_format_flag = ctx.attr.plugin_format_flag,
-            plugin = plugin,
-            runtime = ctx.attr.runtime,
-            provided_proto_sources = provided_proto_sources,
-            proto_compiler = proto_compiler.files_to_run,
-            protoc_opts = ctx.fragments.proto.experimental_protoc_opts,
-            progress_message = ctx.attr.progress_message,
-            mnemonic = ctx.attr.mnemonic,
-        ),
+        DefaultInfo(files = depset(), runfiles = ctx.runfiles()),
+        _builtins.toplevel.platform_common.ToolchainInfo(proto = proto_lang_toolchain_info),
+        # TODO(b/300592942): remove when --incompatible_enable_proto_toolchain_resolution is flipped and removed
+        proto_lang_toolchain_info,
     ]
 
-def make_proto_lang_toolchain(custom_proto_compiler):
-    return rule(
-        _rule_impl,
-        attrs = dict(
-            {
-                "progress_message": attr.string(default = "Generating proto_library %{label}"),
-                "mnemonic": attr.string(default = "GenProto"),
-                "command_line": attr.string(mandatory = True),
-                "output_files": attr.string(values = ["single", "multiple", "legacy"], default = "legacy"),
-                "plugin_format_flag": attr.string(),
-                "plugin": attr.label(
-                    executable = True,
-                    cfg = "exec",
-                ),
-                "runtime": attr.label(),
-                "blacklisted_protos": attr.label_list(
-                    providers = [ProtoInfo],
-                ),
-            },
-            **({
-                "proto_compiler": attr.label(
-                    cfg = "exec",
-                    executable = True,
-                ),
-            } if custom_proto_compiler else {
-                "_proto_compiler": attr.label(
-                    cfg = "exec",
-                    executable = True,
-                    allow_files = True,
-                    default = configuration_field("proto", "proto_compiler"),
-                ),
-            })
-        ),
-        provides = [ProtoLangToolchainInfo],
-        fragments = ["proto"] + semantics.EXTRA_FRAGMENTS,
-    )
+proto_lang_toolchain = rule(
+    _rule_impl,
+    attrs =
+        {
+            "progress_message": attr.string(default = "Generating proto_library %{label}"),
+            "mnemonic": attr.string(default = "GenProto"),
+            "command_line": attr.string(mandatory = True),
+            "output_files": attr.string(values = ["single", "multiple", "legacy"], default = "legacy"),
+            "plugin_format_flag": attr.string(),
+            "plugin": attr.label(
+                executable = True,
+                cfg = "exec",
+            ),
+            "runtime": attr.label(),
+            "blacklisted_protos": attr.label_list(
+                providers = [ProtoInfo],
+            ),
+            "toolchain_type": attr.label(),
+        } | ({} if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else {
+            "_proto_compiler": attr.label(
+                cfg = "exec",
+                executable = True,
+                allow_files = True,
+                default = configuration_field("proto", "proto_compiler"),
+            ),
+        }),
+    provides = [ProtoLangToolchainInfo],
+    fragments = ["proto"],
+    toolchains = toolchains.use_toolchain(semantics.PROTO_TOOLCHAIN),  # Used to obtain protoc
+)
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_custom_protoc.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_custom_protoc.bzl
deleted file mode 100644
index b4157f6..0000000
--- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_custom_protoc.bzl
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright 2021 The Bazel Authors. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Defines a proto_lang_toolchain rule class with custom proto compiler.
-
-There are two physical rule classes for proto_lang_toolchain and we want both of them
-to have a name string of "proto_lang_toolchain".
-"""
-
-load(":common/proto/proto_lang_toolchain.bzl", "make_proto_lang_toolchain")
-
-proto_lang_toolchain = make_proto_lang_toolchain(custom_proto_compiler = True)
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_default_protoc.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_default_protoc.bzl
deleted file mode 100644
index 7345bc3..0000000
--- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_default_protoc.bzl
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright 2021 The Bazel Authors. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Defines a proto_lang_toolchain rule class with default proto compiler.
-
-There are two physical rule classes for proto_lang_toolchain and we want both of them
-to have a name string of "proto_lang_toolchain".
-"""
-
-load(":common/proto/proto_lang_toolchain.bzl", "make_proto_lang_toolchain")
-
-proto_lang_toolchain = make_proto_lang_toolchain(custom_proto_compiler = False)
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_wrapper.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_wrapper.bzl
deleted file mode 100644
index 075e6d7..0000000
--- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain_wrapper.bzl
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright 2021 The Bazel Authors. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Macro encapsulating the proto_lang_toolchain implementation.
-
-This is needed since proto compiler can be defined, or used as a default one.
-There are two implementations of proto_lang_toolchain - one with public proto_compiler attribute, and the other one with private compiler.
-"""
-
-load(":common/proto/proto_lang_toolchain_default_protoc.bzl", toolchain_default_protoc = "proto_lang_toolchain")
-load(":common/proto/proto_lang_toolchain_custom_protoc.bzl", toolchain_custom_protoc = "proto_lang_toolchain")
-
-def proto_lang_toolchain(
-        proto_compiler = None,
-        **kwargs):
-    if proto_compiler != None:
-        toolchain_custom_protoc(
-            proto_compiler = proto_compiler,
-            **kwargs
-        )
-    else:
-        toolchain_default_protoc(
-            **kwargs
-        )
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl
index 5ce86c3..d9943d7 100644
--- a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl
@@ -16,8 +16,8 @@
 Definition of proto_library rule.
 """
 
+load(":common/proto/proto_common.bzl", "toolchains", proto_common = "proto_common_do_not_use")
 load(":common/proto/proto_semantics.bzl", "semantics")
-load(":common/proto/proto_common.bzl", proto_common = "proto_common_do_not_use")
 load(":common/paths.bzl", "paths")
 
 ProtoInfo = _builtins.toplevel.ProtoInfo
@@ -251,15 +251,22 @@
             args.add("--allowed_public_imports=")
         else:
             args.add_joined("--allowed_public_imports", public_import_protos, map_each = _get_import_path, join_with = ":")
-    proto_lang_toolchain_info = proto_common.ProtoLangToolchainInfo(
-        out_replacement_format_flag = "--descriptor_set_out=%s",
-        output_files = "single",
-        mnemonic = "GenProtoDescriptorSet",
-        progress_message = "Generating Descriptor Set proto_library %{label}",
-        proto_compiler = ctx.executable._proto_compiler,
-        protoc_opts = ctx.fragments.proto.experimental_protoc_opts,
-        plugin = None,
-    )
+    if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
+        toolchain = ctx.toolchains[semantics.PROTO_TOOLCHAIN]
+        if not toolchain:
+            fail("Protocol compiler toolchain could not be resolved.")
+        proto_lang_toolchain_info = toolchain.proto
+    else:
+        proto_lang_toolchain_info = proto_common.ProtoLangToolchainInfo(
+            out_replacement_format_flag = "--descriptor_set_out=%s",
+            output_files = "single",
+            mnemonic = "GenProtoDescriptorSet",
+            progress_message = "Generating Descriptor Set proto_library %{label}",
+            proto_compiler = ctx.executable._proto_compiler,
+            protoc_opts = ctx.fragments.proto.experimental_protoc_opts,
+            plugin = None,
+        )
+
     proto_common.compile(
         ctx.actions,
         proto_info,
@@ -271,7 +278,7 @@
 
 proto_library = rule(
     _proto_library_impl,
-    attrs = dict({
+    attrs = {
         "srcs": attr.label_list(
             allow_files = [".proto", ".protodevel"],
             flags = ["DIRECT_COMPILE_TIME_INPUT"],
@@ -288,14 +295,16 @@
             flags = ["SKIP_CONSTRAINTS_OVERRIDE"],
         ),
         "licenses": attr.license() if hasattr(attr, "license") else attr.string_list(),
+    } | toolchains.if_legacy_toolchain({
         "_proto_compiler": attr.label(
             cfg = "exec",
             executable = True,
             allow_files = True,
             default = configuration_field("proto", "proto_compiler"),
         ),
-    }, **semantics.EXTRA_ATTRIBUTES),
+    }) | semantics.EXTRA_ATTRIBUTES,
     fragments = ["proto"] + semantics.EXTRA_FRAGMENTS,
     provides = [ProtoInfo],
     output_to_genfiles = True,  # TODO(b/204266604) move to bin dir
+    toolchains = toolchains.use_toolchain(semantics.PROTO_TOOLCHAIN),
 )
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl
index 8634dd6..88ef131 100644
--- a/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl
@@ -20,6 +20,7 @@
     pass
 
 semantics = struct(
+    PROTO_TOOLCHAIN = "@rules_proto//proto:toolchain_type",
     PROTO_COMPILER_LABEL = "@bazel_tools//tools/proto:protoc",
     EXTRA_ATTRIBUTES = {
         "import_prefix": attr.string(),
diff --git a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java
index c78e740..6b6d717 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java
@@ -16,6 +16,7 @@
 
 import com.google.devtools.build.lib.rules.proto.ProtoCommon;
+import com.google.devtools.build.lib.rules.proto.ProtoConstants;
 import com.google.devtools.build.lib.testutil.Scratch;
 import com.google.devtools.build.lib.testutil.TestConstants;
 import java.io.IOException;
 
@@ -39,12 +40,23 @@
    */
   public static void setup(MockToolsConfig config) throws IOException {
     createNetProto2(config);
+    setupWorkspace(config);
+    registerProtoToolchain(config);
   }
 
-  /**
-   * Create a dummy "net/proto2 compiler and proto APIs for all languages
-   * and versions.
-   */
+  private static void registerProtoToolchain(MockToolsConfig config) throws IOException {
+    config.append("WORKSPACE", "register_toolchains('//tools/proto/toolchains:all')");
+    config.create(
+        "tools/proto/toolchains/BUILD",
+        TestConstants.LOAD_PROTO_TOOLCHAIN,
+        TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
+        "proto_toolchain(name = 'protoc_sources', "
+            + "proto_compiler = '"
+            + ProtoConstants.DEFAULT_PROTOC_LABEL
+            + "')");
+  }
+
+  /** Create a dummy "net/proto2 compiler and proto APIs for all languages and versions. */
   private static void createNetProto2(MockToolsConfig config) throws IOException {
     config.create(
         "net/proto2/compiler/public/BUILD",
@@ -198,17 +210,24 @@
         "           srcs = [ 'metadata.go' ])");
   }
 
-  public static void setupWorkspace(Scratch scratch) throws Exception {
-    scratch.appendFile(
-        "WORKSPACE",
-        "local_repository(",
-        "    name = 'rules_proto',",
-        "    path = 'third_party/rules_proto',",
-        ")");
-    scratch.file("third_party/rules_proto/WORKSPACE");
-    scratch.file("third_party/rules_proto/proto/BUILD", "licenses(['notice'])");
-    scratch.file(
-        "third_party/rules_proto/proto/defs.bzl",
+  public static void setupWorkspace(MockToolsConfig config) throws IOException {
+    if (TestConstants.PRODUCT_NAME.equals("bazel")) {
+      config.append(
+          "WORKSPACE",
+          "local_repository(",
+          "    name = 'rules_proto',",
+          "    path = 'third_party/bazel_rules/rules_proto',",
+          ")");
+    }
+
+    config.create("third_party/bazel_rules/rules_proto/WORKSPACE");
+    config.create(
+        "third_party/bazel_rules/rules_proto/proto/BUILD",
+        "licenses(['notice'])",
+        "toolchain_type(name = 'toolchain_type', visibility = ['//visibility:public'])");
+    config.create(
+        "third_party/bazel_rules/rules_proto/proto/defs.bzl",
+        "load(':proto_lang_toolchain.bzl', _proto_lang_toolchain = 'proto_lang_toolchain')",
         "def _add_tags(kargs):",
         "    if 'tags' in kargs:",
         "        kargs['tags'] += ['__PROTO_RULES_MIGRATION_DO_NOT_USE_WILL_BREAK__']",
@@ -217,6 +236,71 @@
         "    return kargs",
         "",
         "def proto_library(**kargs): native.proto_library(**_add_tags(kargs))",
-        "def proto_lang_toolchain(**kargs): native.proto_lang_toolchain(**_add_tags(kargs))");
+        "def proto_lang_toolchain(**kargs): _proto_lang_toolchain(**_add_tags(kargs))");
+    config.create(
+        "third_party/bazel_rules/rules_proto/proto/proto_toolchain.bzl",
+        "load(':proto_toolchain_rule.bzl', _proto_toolchain_rule = 'proto_toolchain')",
+        "def proto_toolchain(*, name, proto_compiler, exec_compatible_with = []):",
+        "  _proto_toolchain_rule(name = name, proto_compiler = proto_compiler)",
+        "  native.toolchain(",
+        "    name = name + '_toolchain',",
+        "    toolchain_type = '" + TestConstants.PROTO_TOOLCHAIN + "',",
+        "    exec_compatible_with = exec_compatible_with,",
+        "    target_compatible_with = [],",
+        "    toolchain = name,",
+        "  )");
+    config.create(
+        "third_party/bazel_rules/rules_proto/proto/proto_toolchain_rule.bzl",
+        "ProtoLangToolchainInfo = proto_common_do_not_use.ProtoLangToolchainInfo",
+        "def _impl(ctx):",
+        "  return [",
+        "    DefaultInfo(",
+        "      files = depset(),",
+        "      runfiles = ctx.runfiles(),",
+        "    ),",
+        "    platform_common.ToolchainInfo(",
+        "      proto = ProtoLangToolchainInfo(",
+        "        out_replacement_format_flag = ctx.attr.command_line,",
+        "        output_files = ctx.attr.output_files,",
+        "        plugin = None,",
+        "        runtime = None,",
+        "        proto_compiler = ctx.attr.proto_compiler.files_to_run,",
+        "        protoc_opts = ctx.fragments.proto.experimental_protoc_opts,",
+        "        progress_message = ctx.attr.progress_message,",
+        "        mnemonic = ctx.attr.mnemonic,",
+        "        toolchain_type = '" + TestConstants.PROTO_TOOLCHAIN + "'",
+        "      ),",
+        "    ),",
+        "  ]",
+        "proto_toolchain = rule(",
+        "  _impl,",
+        "  attrs = {",
+        "    'progress_message': attr.string(default = ",
+        "        'Generating Descriptor Set proto_library %{label}'),",
+        "    'mnemonic': attr.string(default = 'GenProtoDescriptorSet'),",
+        "    'command_line': attr.string(default = '--descriptor_set_out=%s'),",
+        "    'output_files': attr.string(values = ['single', 'multiple', 'legacy'],",
+        "        default = 'single'),",
+        "    'proto_compiler': attr.label(",
+        "        cfg = 'exec',",
+        "        executable = True,",
+        "        allow_files = True,",
+        "    ),",
+        "  },",
+        "  provides = [platform_common.ToolchainInfo],",
+        "  fragments = ['proto'],",
+        ")");
+    config.create(
+        "third_party/bazel_rules/rules_proto/proto/proto_lang_toolchain.bzl",
+        "def proto_lang_toolchain(*, name, toolchain_type = None, exec_compatible_with = [],",
+        "         target_compatible_with = [], **attrs):",
+        "  native.proto_lang_toolchain(name = name, toolchain_type = toolchain_type, **attrs)",
+        "  if toolchain_type:",
+        "    native.toolchain(",
+        "      name = name + '_toolchain',",
+        "      toolchain_type = toolchain_type,",
+        "      exec_compatible_with = exec_compatible_with,",
+        "      target_compatible_with = target_compatible_with,",
+        "      toolchain = name)");
   }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/rules/cpp/CcStarlarkApiProviderTest.java b/src/test/java/com/google/devtools/build/lib/rules/cpp/CcStarlarkApiProviderTest.java
index 18401af..4c2d5c8 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/cpp/CcStarlarkApiProviderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/cpp/CcStarlarkApiProviderTest.java
@@ -41,7 +41,7 @@
 
   @Before
   public void setUp() throws Exception {
-    MockProtoSupport.setupWorkspace(scratch);
+    MockProtoSupport.setupWorkspace(mockToolsConfig);
     invalidatePackages();
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoLibraryTest.java b/src/test/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoLibraryTest.java
index 271b591..2511c8d 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoLibraryTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoLibraryTest.java
@@ -55,6 +55,10 @@
 
   @Before
   public void setUp() throws Exception {
+    MockProtoSupport.setup(mockToolsConfig);
+    scratch.file(
+        "third_party/bazel_rules/rules_cc/cc/proto/BUILD",
+        "toolchain_type(name = 'toolchain_type', visibility = ['//visibility:public'])");
     scratch.file("protobuf/WORKSPACE");
     scratch.overwriteFile(
         "protobuf/BUILD",
@@ -78,11 +82,35 @@
         "    name = 'com_google_protobuf',",
         "    path = 'protobuf',",
         ")");
-    MockProtoSupport.setupWorkspace(scratch);
     invalidatePackages(); // A dash of magic to re-evaluate the WORKSPACE file.
   }
 
   @Test
+  public void protoToolchainResolution_enabled() throws Exception {
+    setBuildLanguageOptions("--incompatible_enable_proto_toolchain_resolution");
+    getAnalysisMock()
+        .ccSupport()
+        .setupCcToolchainConfig(
+            mockToolsConfig,
+            CcToolchainConfig.builder()
+                .withFeatures(
+                    CppRuleClasses.SUPPORTS_DYNAMIC_LINKER,
+                    CppRuleClasses.SUPPORTS_INTERFACE_SHARED_LIBRARIES));
+    scratch.file(
+        "x/BUILD",
+        TestConstants.LOAD_PROTO_LIBRARY,
+        "cc_proto_library(name = 'foo_cc_proto', deps = ['foo_proto'])",
+        "proto_library(name = 'foo_proto', srcs = ['foo.proto'])");
+    assertThat(prettyArtifactNames(getFilesToBuild(getConfiguredTarget("//x:foo_cc_proto"))))
+        .containsExactly(
+            "x/foo.pb.h",
+            "x/foo.pb.cc",
+            "x/libfoo_proto.a",
+            "x/libfoo_proto.ifso",
+            "x/libfoo_proto.so");
+  }
+
+  @Test
   public void basic() throws Exception {
     getAnalysisMock()
         .ccSupport()
diff --git a/src/test/java/com/google/devtools/build/lib/rules/java/proto/BUILD b/src/test/java/com/google/devtools/build/lib/rules/java/proto/BUILD
index 530babc..c77b322 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/java/proto/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/rules/java/proto/BUILD
@@ -27,6 +27,7 @@
         "//src/main/java/com/google/devtools/build/lib/rules/java:java-rules",
         "//src/test/java/com/google/devtools/build/lib/actions/util",
         "//src/test/java/com/google/devtools/build/lib/analysis/util",
+        "//src/test/java/com/google/devtools/build/lib/packages:testutil",
         "//src/test/java/com/google/devtools/build/lib/rules/java:java_compile_action_test_helper",
         "//src/test/java/com/google/devtools/build/lib/testutil:JunitUtils",
         "//third_party:guava",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java b/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java
index 45cb965..1767b8c 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java
@@ -35,6 +35,7 @@
 import com.google.devtools.build.lib.packages.Provider;
 import com.google.devtools.build.lib.packages.StarlarkProvider;
 import com.google.devtools.build.lib.packages.StructImpl;
+import com.google.devtools.build.lib.packages.util.MockProtoSupport;
 import com.google.devtools.build.lib.rules.java.JavaCompilationArgsProvider;
 import com.google.devtools.build.lib.rules.java.JavaCompileAction;
 import com.google.devtools.build.lib.rules.java.JavaInfo;
@@ -59,25 +60,20 @@
     useConfiguration(
         "--proto_compiler=//proto:compiler",
         "--proto_toolchain_for_javalite=//tools/proto/toolchains:javalite");
+    MockProtoSupport.setup(mockToolsConfig);
 
     scratch.file("proto/BUILD", "licenses(['notice'])", "exports_files(['compiler'])");
 
     mockToolchains();
+    invalidatePackages();
 
     actionsTestUtil = actionsTestUtil();
   }
 
-  @Before
-  public final void setupStarlarkRule() throws Exception {
-    setBuildLanguageOptions(
-        "--experimental_builtins_injection_override=+java_lite_proto_library",
-        "--experimental_google_legacy_api");
-  }
-
   private void mockToolchains() throws IOException {
     mockRuntimes();
 
-    scratch.file(
+    scratch.appendFile(
         "tools/proto/toolchains/BUILD",
         "package(default_visibility=['//visibility:public'])",
         "proto_lang_toolchain(",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/objc/J2ObjcLibraryTest.java b/src/test/java/com/google/devtools/build/lib/rules/objc/J2ObjcLibraryTest.java
index 754e22f..91556b3 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/objc/J2ObjcLibraryTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/objc/J2ObjcLibraryTest.java
@@ -76,7 +76,7 @@
 
     useConfiguration("--proto_toolchain_for_java=//tools/proto/toolchains:java");
 
-    mockToolsConfig.create(
+    mockToolsConfig.append(
         "tools/proto/toolchains/BUILD",
         TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
         "package(default_visibility=['//visibility:public'])",
@@ -86,7 +86,6 @@
         "proto_lang_toolchain(name='java_stubby_compatible13_immutable', "
             + "command_line = 'dont_care')");
 
-    MockProtoSupport.setupWorkspace(scratch);
     invalidatePackages();
   }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
index 39dc331..e4b1530 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
@@ -51,10 +51,8 @@
 
   @Before
   public final void setup() throws Exception {
-    MockProtoSupport.setupWorkspace(scratch);
-    invalidatePackages();
-
     MockProtoSupport.setup(mockToolsConfig);
+    invalidatePackages();
 
     scratch.file(
         "third_party/x/BUILD",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java
index 6fecb44..42e5e56 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java
@@ -31,6 +31,7 @@
 import com.google.devtools.build.lib.packages.util.MockProtoSupport;
 import com.google.devtools.build.lib.testutil.TestConstants;
 import com.google.devtools.build.lib.vfs.FileSystemUtils;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
 import java.util.List;
 import org.junit.Before;
 import org.junit.Ignore;
@@ -48,13 +49,26 @@
   @Before
   public void setUp() throws Exception {
     useConfiguration("--proto_compiler=//proto:compiler");
+    MockProtoSupport.setup(mockToolsConfig);
     scratch.file("proto/BUILD", "licenses(['notice'])", "exports_files(['compiler'])");
 
-    MockProtoSupport.setupWorkspace(scratch);
     invalidatePackages();
   }
 
   @Test
+  public void protoToolchainResolution_enabled() throws Exception {
+    setBuildLanguageOptions("--incompatible_enable_proto_toolchain_resolution");
+    scratch.file(
+        "x/BUILD",
+        TestConstants.LOAD_PROTO_LIBRARY,
+        "proto_library(name='foo', srcs=['foo.proto'])");
+
+    assertThat(getDescriptorOutput("//x:foo")).isNotNull();
+
+    assertNoEvents();
+  }
+
+  @Test
   public void createsDescriptorSets() throws Exception {
     scratch.file(
         "x/BUILD",
@@ -1011,6 +1025,7 @@
         .contains("-Iy/z/q.proto=" + genfiles + "/external/foo/x/y/_virtual_imports/q/y/z/q.proto");
   }
 
+  @CanIgnoreReturnValue
   private Artifact getDescriptorOutput(String label) throws Exception {
     return getFirstArtifactEndingWith(getFilesToBuild(getConfiguredTarget(label)), ".proto.bin");
   }
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoInfoStarlarkApiTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoInfoStarlarkApiTest.java
index 9c4c4bc..77796f7 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoInfoStarlarkApiTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoInfoStarlarkApiTest.java
@@ -42,8 +42,6 @@
     scratch.file("myinfo/myinfo.bzl", "MyInfo = provider()");
     scratch.file("myinfo/BUILD");
     MockProtoSupport.setup(mockToolsConfig);
-
-    MockProtoSupport.setupWorkspace(scratch);
     invalidatePackages();
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java
index 23daa19..8374ed8 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java
@@ -36,7 +36,6 @@
 public class ProtoLangToolchainTest extends BuildViewTestCase {
   @Before
   public void setUp() throws Exception {
-    MockProtoSupport.setupWorkspace(scratch);
     MockProtoSupport.setup(mockToolsConfig);
     useConfiguration("--protocopt=--myflag");
     invalidatePackages();
@@ -98,7 +97,8 @@
   }
 
   @Test
-  public void protoToolchain_setProtoCompiler() throws Exception {
+  public void protoToolchainResolution_enabled() throws Exception {
+    setBuildLanguageOptions("--incompatible_enable_proto_toolchain_resolution");
     scratch.file(
         "third_party/x/BUILD",
         "licenses(['unencumbered'])",
@@ -106,9 +106,7 @@
         "cc_library(name = 'runtime', srcs = ['runtime.cc'])",
         "filegroup(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])",
         "filegroup(name = 'any', srcs = ['any.proto'])",
-        "proto_library(name = 'denied', srcs = [':descriptors', ':any'])",
-        "cc_binary(name = 'compiler')");
-
+        "proto_library(name = 'denied', srcs = [':descriptors', ':any'])");
     scratch.file(
         "foo/BUILD",
         TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
@@ -121,14 +119,14 @@
         "    runtime = '//third_party/x:runtime',",
         "    progress_message = 'Progress Message %{label}',",
         "    mnemonic = 'MyMnemonic',",
-        "    proto_compiler = '//third_party/x:compiler',",
         ")");
 
+    update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());
     ProtoLangToolchainProvider toolchain =
         ProtoLangToolchainProvider.get(getConfiguredTarget("//foo:toolchain"));
 
     validateProtoLangToolchain(toolchain);
-    validateProtoCompiler(toolchain, "//third_party/x:compiler");
+    validateProtoCompiler(toolchain, ProtoConstants.DEFAULT_PROTOC_LABEL);
   }
 
   @Test
diff --git a/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java b/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java
index c206106..6044384 100644
--- a/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java
+++ b/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java
@@ -26,6 +26,9 @@
 
   public static final String LOAD_PROTO_LIBRARY =
       "load('@rules_proto//proto:defs.bzl', 'proto_library')";
+  public static final String PROTO_TOOLCHAIN = "@rules_proto//proto:toolchain_type";
+  public static final String LOAD_PROTO_TOOLCHAIN =
+      "load('@rules_proto//proto:proto_toolchain.bzl', 'proto_toolchain')";
   public static final String LOAD_PROTO_LANG_TOOLCHAIN =
       "load('@rules_proto//proto:defs.bzl', 'proto_lang_toolchain')";