Use proto toolchains in java_proto_library

Retrieve proto_lang_toolchain using toolchains in java_proto_library when proto toolchain resolution is enabled.
Add mocks for proto_lang_toolchain from rules_proto.
Return ToolchainInfo from proto_lang_toolchain rule. That's needed in the resolution. Also returning ProtoLangToolchainInfo directly, to support legacy mechanism.

Issue: https://github.com/bazelbuild/rules_proto/issues/179
PiperOrigin-RevId: 570055332
Change-Id: Ieb0510d48778900b8576a71ddb20abfdcda220be
diff --git a/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl b/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl
index 4d96aff..72577fc 100644
--- a/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl
+++ b/src/main/starlark/builtins_bzl/common/java/java_semantics.bzl
@@ -73,4 +73,5 @@
     get_default_resource_path = _get_default_resource_path,
     compatible_javac_options = _compatible_javac_options,
     LAUNCHER_FLAG_LABEL = Label("@bazel_tools//tools/jdk:launcher_flag_alias"),
+    JAVA_PROTO_TOOLCHAIN = "@rules_java//java/proto:toolchain_type",
 )
diff --git a/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl b/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl
index 540ba88..0261aef 100644
--- a/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl
+++ b/src/main/starlark/builtins_bzl/common/java/proto/java_proto_library.bzl
@@ -15,7 +15,7 @@
 """The implementation of the `java_proto_library` rule and its aspect."""
 
 load(":common/java/java_semantics.bzl", "semantics")
-load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", proto_common = "proto_common_do_not_use")
+load(":common/proto/proto_common.bzl", "toolchains", proto_common = "proto_common_do_not_use")
 load(":common/proto/proto_info.bzl", "ProtoInfo")
 load(":common/java/java_info.bzl", "JavaInfo", _merge_private_for_builtins = "merge")
 load(
@@ -49,7 +49,7 @@
       runtime jars.
     """
 
-    proto_toolchain_info = ctx.attr._aspect_java_proto_toolchain[ProtoLangToolchainInfo]
+    proto_toolchain_info = toolchains.find_toolchain(ctx, "_aspect_java_proto_toolchain", semantics.JAVA_PROTO_TOOLCHAIN)
     source_jar = None
     if proto_common.experimental_should_generate_code(target[ProtoInfo], proto_toolchain_info, "java_proto_library", target.label):
         # Generate source jar using proto compiler.
@@ -131,12 +131,12 @@
 
 bazel_java_proto_aspect = aspect(
     implementation = _bazel_java_proto_aspect_impl,
-    attrs = {
+    attrs = toolchains.if_legacy_toolchain({
         "_aspect_java_proto_toolchain": attr.label(
             default = configuration_field(fragment = "proto", name = "proto_toolchain_for_java"),
         ),
-    },
-    toolchains = [semantics.JAVA_TOOLCHAIN],
+    }),
+    toolchains = [semantics.JAVA_TOOLCHAIN] + toolchains.use_toolchain(semantics.JAVA_PROTO_TOOLCHAIN),
     attr_aspects = ["deps", "exports"],
     required_providers = [ProtoInfo],
     provides = [JavaInfo, JavaProtoAspectInfo],
@@ -151,9 +151,10 @@
     Returns:
       ([JavaInfo, DefaultInfo, OutputGroupInfo])
     """
-    proto_toolchain = ctx.attr._aspect_java_proto_toolchain[ProtoLangToolchainInfo]
+    proto_toolchain = toolchains.find_toolchain(ctx, "_aspect_java_proto_toolchain", semantics.JAVA_PROTO_TOOLCHAIN)
     for dep in ctx.attr.deps:
         proto_common.check_collocated(ctx.label, dep[ProtoInfo], proto_toolchain)
+
     java_info = _merge_private_for_builtins([dep[JavaInfo] for dep in ctx.attr.deps], merge_java_outputs = False)
 
     transitive_src_and_runtime_jars = depset(transitive = [dep[JavaProtoAspectInfo].jars for dep in ctx.attr.deps])
@@ -174,9 +175,11 @@
         "deps": attr.label_list(providers = [ProtoInfo], aspects = [bazel_java_proto_aspect]),
         "licenses": attr.license() if hasattr(attr, "license") else attr.string_list(),
         "distribs": attr.string_list(),
+    } | toolchains.if_legacy_toolchain({
         "_aspect_java_proto_toolchain": attr.label(
             default = configuration_field(fragment = "proto", name = "proto_toolchain_for_java"),
         ),
-    },
+    }),
     provides = [JavaInfo],
+    toolchains = toolchains.use_toolchain(semantics.JAVA_PROTO_TOOLCHAIN),
 )
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
index 3e609ef..894c400 100644
--- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
@@ -39,24 +39,24 @@
         proto_compiler = ctx.attr._proto_compiler.files_to_run
         protoc_opts = ctx.fragments.proto.experimental_protoc_opts
 
+    proto_lang_toolchain_info = ProtoLangToolchainInfo(
+        out_replacement_format_flag = flag,
+        output_files = ctx.attr.output_files,
+        plugin_format_flag = ctx.attr.plugin_format_flag,
+        plugin = plugin,
+        runtime = ctx.attr.runtime,
+        provided_proto_sources = provided_proto_sources,
+        proto_compiler = proto_compiler,
+        protoc_opts = protoc_opts,
+        progress_message = ctx.attr.progress_message,
+        mnemonic = ctx.attr.mnemonic,
+        allowlist_different_package = ctx.attr.allowlist_different_package,
+    )
     return [
-        DefaultInfo(
-            files = depset(),
-            runfiles = ctx.runfiles(),
-        ),
-        ProtoLangToolchainInfo(
-            out_replacement_format_flag = flag,
-            output_files = ctx.attr.output_files,
-            plugin_format_flag = ctx.attr.plugin_format_flag,
-            plugin = plugin,
-            runtime = ctx.attr.runtime,
-            provided_proto_sources = provided_proto_sources,
-            proto_compiler = proto_compiler,
-            protoc_opts = protoc_opts,
-            progress_message = ctx.attr.progress_message,
-            mnemonic = ctx.attr.mnemonic,
-            allowlist_different_package = ctx.attr.allowlist_different_package,
-        ),
+        DefaultInfo(files = depset(), runfiles = ctx.runfiles()),
+        _builtins.toplevel.platform_common.ToolchainInfo(proto = proto_lang_toolchain_info),
+        # TODO(b/300592942): remove when --incompatible_enable_proto_toolchains is flipped and removed
+        proto_lang_toolchain_info,
     ]
 
 proto_lang_toolchain = rule(
diff --git a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java
index 71344ac..75d8d41 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java
@@ -40,9 +40,10 @@
 
   private static void registerProtoToolchain(MockToolsConfig config) throws IOException {
     config.append("WORKSPACE", "register_toolchains('tools/proto/toolchains:all')");
-    config.append(
+    config.create(
         "tools/proto/toolchains/BUILD",
         TestConstants.LOAD_PROTO_TOOLCHAIN,
+        TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
         "proto_toolchain(name = 'protoc_sources',"
             + "proto_compiler = '"
             + ProtoConstants.DEFAULT_PROTOC_LABEL
@@ -220,6 +221,7 @@
         "toolchain_type(name = 'toolchain_type', visibility = ['//visibility:public'])");
     config.create(
         "third_party/bazel_rules/rules_proto/proto/defs.bzl",
+        "load(':proto_lang_toolchain.bzl', _proto_lang_toolchain = 'proto_lang_toolchain')",
         "def _add_tags(kargs):",
         "    if 'tags' in kargs:",
         "        kargs['tags'] += ['__PROTO_RULES_MIGRATION_DO_NOT_USE_WILL_BREAK__']",
@@ -228,7 +230,7 @@
         "    return kargs",
         "",
         "def proto_library(**kargs): native.proto_library(**_add_tags(kargs))",
-        "def proto_lang_toolchain(**kargs): native.proto_lang_toolchain(**_add_tags(kargs))");
+        "def proto_lang_toolchain(**kargs): _proto_lang_toolchain(**_add_tags(kargs))");
     config.create(
         "third_party/bazel_rules/rules_proto/proto/proto_toolchain.bzl",
         "load(':proto_toolchain_rule.bzl', _proto_toolchain_rule = 'proto_toolchain')",
@@ -282,5 +284,17 @@
         "  provides = [platform_common.ToolchainInfo],",
         "  fragments = ['proto'],",
         ")");
+    config.create(
+        "third_party/bazel_rules/rules_proto/proto/proto_lang_toolchain.bzl",
+        "def proto_lang_toolchain(*, name, toolchain_type = None, exec_compatible_with = [],",
+        "         target_compatible_with = [], **attrs):",
+        "  native.proto_lang_toolchain(name = name, **attrs)",
+        "  if toolchain_type:",
+        "    native.toolchain(",
+        "      name = name + '_toolchain',",
+        "      toolchain_type = toolchain_type,",
+        "      exec_compatible_with = exec_compatible_with,",
+        "      target_compatible_with = target_compatible_with,",
+        "      toolchain = name)");
   }
 }