Starlarkify proto_lang_toolchain and ProtoLangToolchainInfo provider

Added StarlarkProtoLangToolchainTest which uses starlarkified rule for verification.

I've deleted blacklisted_protos and forbidden_protos since they are not needed anymore.

PiperOrigin-RevId: 444223497
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
index dbe62d9..ed53f4d 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
@@ -16,11 +16,8 @@
 
 import com.google.auto.value.AutoValue;
 import com.google.common.collect.ImmutableList;
-import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.analysis.FilesToRunProvider;
 import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
-import com.google.devtools.build.lib.collect.nestedset.NestedSet;
-import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
 import com.google.devtools.build.lib.packages.BuiltinProvider;
 import com.google.devtools.build.lib.packages.NativeInfo;
 import javax.annotation.Nullable;
@@ -112,19 +109,6 @@
       structField = true)
   public abstract String mnemonic();
 
-  /**
-   * This makes the blacklisted_protos member available in the provider. It can be removed after
-   * users are migrated and a sufficient time for Bazel rules to migrate has elapsed.
-   */
-  @Deprecated
-  public NestedSet<Artifact> blacklistedProtos() {
-    return forbiddenProtos();
-  }
-
-  // TODO(yannic): Remove after migrating all users to `providedProtoSources()`.
-  @Deprecated
-  public abstract NestedSet<Artifact> forbiddenProtos();
-
   public static ProtoLangToolchainProvider create(
       String outReplacementFormatFlag,
       String pluginFormatFlag,
@@ -135,10 +119,6 @@
       ImmutableList<String> protocOpts,
       String progressMessage,
       String mnemonic) {
-    NestedSetBuilder<Artifact> blacklistedProtos = NestedSetBuilder.stableOrder();
-    for (ProtoSource protoSource : providedProtoSources) {
-      blacklistedProtos.add(protoSource.getOriginalSourceFile());
-    }
     return new AutoValue_ProtoLangToolchainProvider(
         outReplacementFormatFlag,
         pluginFormatFlag,
@@ -148,7 +128,6 @@
         protoc,
         protocOpts,
         progressMessage,
-        mnemonic,
-        blacklistedProtos.build());
+        mnemonic);
   }
 }
diff --git a/src/main/starlark/builtins_bzl/common/exports.bzl b/src/main/starlark/builtins_bzl/common/exports.bzl
index 442c6ae1..954036f 100755
--- a/src/main/starlark/builtins_bzl/common/exports.bzl
+++ b/src/main/starlark/builtins_bzl/common/exports.bzl
@@ -25,6 +25,7 @@
 load("@_builtins//:common/objc/linking_support.bzl", "linking_support")
 load("@_builtins//:common/proto/proto_common.bzl", "proto_common_do_not_use")
 load("@_builtins//:common/proto/proto_library.bzl", "proto_library")
+load("@_builtins//:common/proto/proto_lang_toolchain.bzl", "proto_lang_toolchain")
 load("@_builtins//:common/java/proto/java_lite_proto_library.bzl", "java_lite_proto_library")
 load("@_builtins//:common/cc/cc_library.bzl", "cc_library")
 
@@ -54,6 +55,7 @@
     "+cc_binary": cc_binary,
     "+cc_test": cc_test,
     "-cc_library": cc_library,
+    "-proto_lang_toolchain": proto_lang_toolchain,
 }
 
 # A list of Starlark functions callable from native rules implementation.
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
new file mode 100644
index 0000000..d8960b4
--- /dev/null
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl
@@ -0,0 +1,85 @@
+# Copyright 2021 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Starlark implementation of the proto_lang_toolchain rule."""
+
+load(":common/proto/providers.bzl", "ProtoLangToolchainInfo")
+load(":common/proto/proto_semantics.bzl", "semantics")
+
+ProtoInfo = _builtins.toplevel.ProtoInfo
+proto_common = _builtins.toplevel.proto_common
+
+def _rule_impl(ctx):
+    provided_proto_sources = []
+    transitive_files = depset(transitive = [bp[ProtoInfo].transitive_sources for bp in ctx.attr.blacklisted_protos])
+    for file in transitive_files.to_list():
+        source_root = file.root.path
+        provided_proto_sources.append(proto_common.ProtoSource(file, file, source_root))
+
+    flag = ctx.attr.command_line
+    if flag.find("$(PLUGIN_OUT)") > -1:
+        fail("in attribute 'command_line': Placeholder '$(PLUGIN_OUT)' is not supported.")
+    flag = flag.replace("$(OUT)", "%s")
+
+    plugin = None
+    if ctx.attr.plugin != None:
+        plugin = ctx.attr.plugin[DefaultInfo].files_to_run
+
+    return [
+        DefaultInfo(
+            files = depset(),
+            runfiles = ctx.runfiles(),
+        ),
+        ProtoLangToolchainInfo(
+            out_replacement_format_flag = flag,
+            plugin_format_flag = ctx.attr.plugin_format_flag,
+            plugin = plugin,
+            runtime = ctx.attr.runtime,
+            provided_proto_sources = provided_proto_sources,
+            proto_compiler = ctx.attr._proto_compiler.files_to_run,
+            protoc_opts = ctx.fragments.proto.experimental_protoc_opts,
+            progress_message = ctx.attr.progress_message,
+            mnemonic = ctx.attr.mnemonic,
+        ),
+    ]
+
+proto_lang_toolchain = rule(
+    implementation = _rule_impl,
+    attrs = {
+        "progress_message": attr.string(default = "Generating proto_library %{label}"),
+        "mnemonic": attr.string(default = "GenProto"),
+        "command_line": attr.string(mandatory = True),
+        "plugin_format_flag": attr.string(),
+        "plugin": attr.label(
+            executable = True,
+            cfg = "exec",
+            allow_files = True,
+        ),
+        "runtime": attr.label(
+            allow_files = True,
+        ),
+        "blacklisted_protos": attr.label_list(
+            allow_files = True,
+            providers = [ProtoInfo],
+        ),
+        "_proto_compiler": attr.label(
+            cfg = "exec",
+            executable = True,
+            allow_files = True,
+            default = configuration_field("proto", "proto_compiler"),
+        ),
+    },
+    provides = [ProtoLangToolchainInfo],
+    fragments = ["proto"] + semantics.EXTRA_FRAGMENTS,
+)
diff --git a/src/main/starlark/builtins_bzl/common/proto/providers.bzl b/src/main/starlark/builtins_bzl/common/proto/providers.bzl
new file mode 100644
index 0000000..aba5fe0
--- /dev/null
+++ b/src/main/starlark/builtins_bzl/common/proto/providers.bzl
@@ -0,0 +1,30 @@
+# Copyright 2021 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bazel providers for proto rules."""
+
+ProtoLangToolchainInfo = provider(
+    doc = "Specifies how to generate language-specific code from .proto files. Used by LANG_proto_library rules.",
+    fields = dict(
+        out_replacement_format_flag = "(str) Format string used when passing output to the plugin used by proto compiler.",
+        plugin_format_flag = "(str) Format string used when passing plugin to proto compiler.",
+        plugin = "(FilesToRunProvider) Proto compiler plugin.",
+        runtime = "(Target) Runtime.",
+        provided_proto_sources = "(list[ProtoSource]) Proto sources provided by the toolchain.",
+        proto_compiler = "(FilesToRunProvider) Proto compiler.",
+        protoc_opts = "(list[str]) Options to pass to proto compiler.",
+        progress_message = "(str) Progress message to set on the proto compiler action.",
+        mnemonic = "(str) Mnemonic to set on the proto compiler action.",
+    ),
+)
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD b/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
index ec54d98..5602a5b 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
@@ -40,7 +40,6 @@
         "//src/main/java/com/google/devtools/build/lib/analysis:transitive_info_collection",
         "//src/main/java/com/google/devtools/build/lib/cmdline",
         "//src/main/java/com/google/devtools/build/lib/rules/proto",
-        "//src/test/java/com/google/devtools/build/lib/actions/util",
         "//src/test/java/com/google/devtools/build/lib/analysis/util",
         "//src/test/java/com/google/devtools/build/lib/packages:testutil",
         "//src/test/java/com/google/devtools/build/lib/testutil:TestConstants",
@@ -51,6 +50,25 @@
 )
 
 java_test(
+    name = "StarlarkProtoLangToolchainTest",
+    srcs = ["StarlarkProtoLangToolchainTest.java"],
+    deps = [
+        ":ProtoLangToolchainTest",
+        "//src/main/java/com/google/devtools/build/lib/analysis:analysis_cluster",
+        "//src/main/java/com/google/devtools/build/lib/analysis:transitive_info_collection",
+        "//src/main/java/com/google/devtools/build/lib/cmdline",
+        "//src/main/java/com/google/devtools/build/lib/packages",
+        "//src/main/java/com/google/devtools/build/lib/rules/proto",
+        "//src/main/java/net/starlark/java/eval",
+        "//src/test/java/com/google/devtools/build/lib/analysis/util",
+        "//src/test/java/com/google/devtools/build/lib/testutil:TestConstants",
+        "//third_party:guava",
+        "//third_party:junit4",
+        "//third_party:truth",
+    ],
+)
+
+java_test(
     name = "BazelProtoLibraryTest",
     srcs = ["BazelProtoLibraryTest.java"],
     deps = [
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java
index 96084c0..8126f30 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainTest.java
@@ -15,7 +15,6 @@
 package com.google.devtools.build.lib.rules.proto;
 
 import static com.google.common.truth.Truth.assertThat;
-import static com.google.devtools.build.lib.actions.util.ActionsTestUtil.prettyArtifactNames;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -30,8 +29,15 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
+/** Unit tests for {@code proto_lang_toolchain}. */
 @RunWith(JUnit4.class)
 public class ProtoLangToolchainTest extends BuildViewTestCase {
+
+  @Before
+  public void setupStarlarkRule() throws Exception {
+    setBuildLanguageOptions("--experimental_builtins_injection_override=-proto_lang_toolchain");
+  }
+
   @Before
   public void setUp() throws Exception {
     MockProtoSupport.setupWorkspace(scratch);
@@ -50,12 +56,6 @@
     assertThat(runtimes.getLabel())
         .isEqualTo(Label.parseAbsolute("//third_party/x:runtime", ImmutableMap.of()));
 
-    assertThat(prettyArtifactNames(toolchain.forbiddenProtos()))
-        .containsExactly(
-            "third_party/x/metadata.proto",
-            "third_party/x/descriptor.proto",
-            "third_party/x/any.proto");
-
     assertThat(toolchain.protocOpts()).containsExactly("--myflag");
     Label protoc = Label.parseAbsoluteUnchecked(ProtoConstants.DEFAULT_PROTOC_LABEL);
     assertThat(toolchain.protoc().getExecutable().prettyPrint())
@@ -176,8 +176,6 @@
 
     assertThat(toolchain.pluginExecutable()).isNull();
     assertThat(toolchain.runtime()).isNull();
-    assertThat(toolchain.blacklistedProtos().toList()).isEmpty();
-    assertThat(toolchain.forbiddenProtos().toList()).isEmpty();
     assertThat(toolchain.mnemonic()).isEqualTo("GenProto");
   }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/StarlarkProtoLangToolchainTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/StarlarkProtoLangToolchainTest.java
new file mode 100644
index 0000000..de33c38
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/StarlarkProtoLangToolchainTest.java
@@ -0,0 +1,197 @@
+// Copyright 2016 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.devtools.build.lib.rules.proto;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.eventbus.EventBus;
+import com.google.devtools.build.lib.analysis.FilesToRunProvider;
+import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
+import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.cmdline.LabelSyntaxException;
+import com.google.devtools.build.lib.packages.Provider;
+import com.google.devtools.build.lib.packages.StarlarkInfo;
+import com.google.devtools.build.lib.packages.StarlarkProvider;
+import com.google.devtools.build.lib.testutil.TestConstants;
+import net.starlark.java.eval.Starlark;
+import net.starlark.java.eval.StarlarkList;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@code proto_lang_toolchain}. */
+@RunWith(JUnit4.class)
+public class StarlarkProtoLangToolchainTest extends ProtoLangToolchainTest {
+
+  @Override
+  @Before
+  public void setupStarlarkRule() throws Exception {
+    setBuildLanguageOptions("--experimental_builtins_injection_override=+proto_lang_toolchain");
+  }
+
+  Provider.Key getStarlarkProtoLangToolchainInfoKey() throws LabelSyntaxException {
+    return new StarlarkProvider.Key(
+        Label.parseAbsolute("@_builtins//:common/proto/providers.bzl", ImmutableMap.of()),
+        "ProtoLangToolchainInfo");
+  }
+
+  @SuppressWarnings("unchecked")
+  private void validateStarlarkProtoLangToolchain(StarlarkInfo toolchain) throws Exception {
+    assertThat(toolchain.getValue("out_replacement_format_flag")).isEqualTo("cmd-line:%s");
+    assertThat(toolchain.getValue("plugin_format_flag")).isEqualTo("--plugin=%s");
+    assertThat(toolchain.getValue("progress_message")).isEqualTo("Progress Message %{label}");
+    assertThat(toolchain.getValue("mnemonic")).isEqualTo("MyMnemonic");
+    assertThat(ImmutableList.copyOf((StarlarkList<String>) toolchain.getValue("protoc_opts")))
+        .containsExactly("--myflag");
+    assertThat(
+            ((FilesToRunProvider) toolchain.getValue("plugin"))
+                .getExecutable()
+                .getRootRelativePathString())
+        .isEqualTo("third_party/x/plugin");
+
+    TransitiveInfoCollection runtimes = (TransitiveInfoCollection) toolchain.getValue("runtime");
+    assertThat(runtimes.getLabel())
+        .isEqualTo(Label.parseAbsolute("//third_party/x:runtime", ImmutableMap.of()));
+
+    Label protoc = Label.parseAbsoluteUnchecked(ProtoConstants.DEFAULT_PROTOC_LABEL);
+    assertThat(
+            ((FilesToRunProvider) toolchain.getValue("proto_compiler"))
+                .getExecutable()
+                .prettyPrint())
+        .isEqualTo(protoc.toPathFragment().getPathString());
+  }
+
+  @Override
+  @Test
+  public void protoToolchain() throws Exception {
+    scratch.file(
+        "third_party/x/BUILD",
+        "licenses(['unencumbered'])",
+        "cc_binary(name = 'plugin', srcs = ['plugin.cc'])",
+        "cc_library(name = 'runtime', srcs = ['runtime.cc'])",
+        "filegroup(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])",
+        "filegroup(name = 'any', srcs = ['any.proto'])",
+        "proto_library(name = 'denied', srcs = [':descriptors', ':any'])");
+
+    scratch.file(
+        "foo/BUILD",
+        TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
+        "licenses(['unencumbered'])",
+        "proto_lang_toolchain(",
+        "    name = 'toolchain',",
+        "    command_line = 'cmd-line:$(OUT)',",
+        "    plugin_format_flag = '--plugin=%s',",
+        "    plugin = '//third_party/x:plugin',",
+        "    runtime = '//third_party/x:runtime',",
+        "    progress_message = 'Progress Message %{label}',",
+        "    mnemonic = 'MyMnemonic',",
+        ")");
+
+    update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());
+
+    validateStarlarkProtoLangToolchain(
+        (StarlarkInfo)
+            getConfiguredTarget("//foo:toolchain").get(getStarlarkProtoLangToolchainInfoKey()));
+  }
+
+  @Override
+  @Test
+  public void protoToolchainBlacklistProtoLibraries() throws Exception {
+    scratch.file(
+        "third_party/x/BUILD",
+        TestConstants.LOAD_PROTO_LIBRARY,
+        "licenses(['unencumbered'])",
+        "cc_binary(name = 'plugin', srcs = ['plugin.cc'])",
+        "cc_library(name = 'runtime', srcs = ['runtime.cc'])",
+        "proto_library(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])",
+        "proto_library(name = 'any', srcs = ['any.proto'], strip_import_prefix = '/third_party')");
+
+    scratch.file(
+        "foo/BUILD",
+        TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
+        "proto_lang_toolchain(",
+        "    name = 'toolchain',",
+        "    command_line = 'cmd-line:$(OUT)',",
+        "    plugin_format_flag = '--plugin=%s',",
+        "    plugin = '//third_party/x:plugin',",
+        "    runtime = '//third_party/x:runtime',",
+        "    progress_message = 'Progress Message %{label}',",
+        "    mnemonic = 'MyMnemonic',",
+        ")");
+
+    update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());
+
+    validateStarlarkProtoLangToolchain(
+        (StarlarkInfo)
+            getConfiguredTarget("//foo:toolchain").get(getStarlarkProtoLangToolchainInfoKey()));
+  }
+
+  @Override
+  @Test
+  public void protoToolchainBlacklistTransitiveProtos() throws Exception {
+    scratch.file(
+        "third_party/x/BUILD",
+        TestConstants.LOAD_PROTO_LIBRARY,
+        "licenses(['unencumbered'])",
+        "cc_binary(name = 'plugin', srcs = ['plugin.cc'])",
+        "cc_library(name = 'runtime', srcs = ['runtime.cc'])",
+        "proto_library(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])",
+        "proto_library(name = 'any', srcs = ['any.proto'], deps = [':descriptors'])");
+
+    scratch.file(
+        "foo/BUILD",
+        TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
+        "proto_lang_toolchain(",
+        "    name = 'toolchain',",
+        "    command_line = 'cmd-line:$(OUT)',",
+        "    plugin_format_flag = '--plugin=%s',",
+        "    plugin = '//third_party/x:plugin',",
+        "    runtime = '//third_party/x:runtime',",
+        "    progress_message = 'Progress Message %{label}',",
+        "    mnemonic = 'MyMnemonic',",
+        ")");
+
+    update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());
+
+    validateStarlarkProtoLangToolchain(
+        (StarlarkInfo)
+            getConfiguredTarget("//foo:toolchain").get(getStarlarkProtoLangToolchainInfoKey()));
+  }
+
+  @Override
+  @Test
+  public void optionalFieldsAreEmpty() throws Exception {
+    scratch.file(
+        "foo/BUILD",
+        TestConstants.LOAD_PROTO_LANG_TOOLCHAIN,
+        "proto_lang_toolchain(",
+        "    name = 'toolchain',",
+        "    command_line = 'cmd-line:$(OUT)',",
+        ")");
+
+    update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());
+
+    StarlarkInfo toolchain =
+        (StarlarkInfo)
+            getConfiguredTarget("//foo:toolchain").get(getStarlarkProtoLangToolchainInfoKey());
+
+    assertThat(toolchain.getValue("plugin")).isEqualTo(Starlark.NONE);
+    assertThat(toolchain.getValue("runtime")).isEqualTo(Starlark.NONE);
+    assertThat(toolchain.getValue("mnemonic")).isEqualTo("GenProto");
+  }
+}