Implement proto_common.experimental_should_generate_code.

Design doc: https://docs.google.com/document/d/1dY_jfRvnH8SjRXGIfg8av-vquyWsvIZydXJOywvaR1A/edit

PiperOrigin-RevId: 440109298
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java
index aa75f71..f920330 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java
@@ -38,11 +38,13 @@
 import javax.annotation.Nullable;
 import net.starlark.java.eval.EvalException;
 import net.starlark.java.eval.Module;
+import net.starlark.java.eval.Sequence;
 import net.starlark.java.eval.Starlark;
 import net.starlark.java.eval.StarlarkCallable;
 import net.starlark.java.eval.StarlarkFunction;
 import net.starlark.java.eval.StarlarkList;
 import net.starlark.java.eval.StarlarkThread;
+import net.starlark.java.eval.Tuple;
 
 /** Utility functions for proto_library and proto aspect implementations. */
 public class ProtoCommon {
@@ -206,4 +208,51 @@
             /* plugin_output */ pluginOutput == null ? Starlark.NONE : pluginOutput),
         ImmutableMap.of("experimental_progress_message", progressMessage));
   }
+
+  /** Asks the Starlark builtin whether code must be generated for {@code protoTarget}. */
+  public static boolean shouldGenerateCode(
+      RuleContext ruleContext,
+      ConfiguredTarget protoTarget,
+      ProtoLangToolchainProvider protoLangToolchainInfo,
+      String ruleName)
+      throws RuleErrorException, InterruptedException {
+    String name = "proto_common_experimental_should_generate_code";
+    StarlarkFunction builtin = (StarlarkFunction) ruleContext.getStarlarkDefinedBuiltin(name);
+    ruleContext.initStarlarkRuleContext();
+
+    // Positional arguments: proto_library_target, proto_lang_toolchain_info, rule_name.
+    Object answer =
+        ruleContext.callStarlarkOrThrowRuleError(
+            builtin,
+            ImmutableList.of(protoTarget, protoLangToolchainInfo, ruleName),
+            ImmutableMap.of());
+    return (Boolean) answer;
+  }
+
+  public static Sequence<Artifact> filterSources(
+      RuleContext ruleContext,
+      ConfiguredTarget protoTarget,
+      ProtoLangToolchainProvider protoLangToolchainInfo)
+      throws RuleErrorException, InterruptedException {
+    StarlarkFunction filterSources =
+        (StarlarkFunction)
+            ruleContext.getStarlarkDefinedBuiltin("proto_common_experimental_filter_sources");
+    ruleContext.initStarlarkRuleContext();
+    try {
+      return Sequence.cast(
+          ((Tuple)
+                  ruleContext.callStarlarkOrThrowRuleError(
+                      filterSources,
+                      ImmutableList.of(
+                          /* proto_library_target */ protoTarget,
+                          /* proto_lang_toolchain_info */ protoLangToolchainInfo),
+                      ImmutableMap.of()))
+              .get(0),
+          Artifact.class,
+          "included");
+    } catch (EvalException e) {
+
+      throw new RuleErrorException(e.getMessageWithStack());
+    }
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java
index 3e090ac..d32c1bb 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java
@@ -113,6 +113,13 @@
     return directProtoSources;
   }
 
+  @Override
+  public ImmutableList<ProtoSource> getDirectProtoSourcesForStarlark(StarlarkThread thread)
+      throws EvalException {
+    ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
+    return directSources;
+  }
+
   /**
    * The source root of the current library.
    *
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
index fe3ed09..dbe62d9 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
@@ -81,6 +81,10 @@
    * Returns a list of {@link ProtoSource}s that are already provided by the protobuf runtime (i.e.
    * for which {@code <lang>_proto_library} should not generate bindings.
    */
+  @StarlarkMethod(
+      name = "provided_proto_sources",
+      doc = "Proto sources provided by the toolchain.",
+      structField = true)
   public abstract ImmutableList<ProtoSource> providedProtoSources();
 
   @StarlarkMethod(name = "proto_compiler", doc = "Proto compiler.", structField = true)
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java
index 8b93079..27f13b4 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java
@@ -56,7 +56,14 @@
   }
 
   /** Returns the original source file. Only for forbidding protos! */
+  @StarlarkMethod(name = "original_source_file", documented = false, useStarlarkThread = true)
+  public Artifact getOriginalSourceFileForStarlark(StarlarkThread thread) throws EvalException {
+    ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
+    return originalSourceFile;
+  }
+
+  /** Returns the original source file. Only for forbidding protos! */
   @Deprecated
   Artifact getOriginalSourceFile() {
     return originalSourceFile;
   }
diff --git a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java
index d031958..a4e3fd5 100644
--- a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java
+++ b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java
@@ -64,6 +64,9 @@
       structField = true)
   ImmutableList<FileT> getDirectProtoSources();
 
+  @StarlarkMethod(name = "direct_proto_sources", documented = false, useStarlarkThread = true)
+  ImmutableList<?> getDirectProtoSourcesForStarlark(StarlarkThread thread) throws EvalException;
+
   @StarlarkMethod(
       name = "check_deps_sources",
       doc =
diff --git a/src/main/starlark/builtins_bzl/common/exports.bzl b/src/main/starlark/builtins_bzl/common/exports.bzl
index c48c512..966a597 100755
--- a/src/main/starlark/builtins_bzl/common/exports.bzl
+++ b/src/main/starlark/builtins_bzl/common/exports.bzl
@@ -60,5 +60,7 @@
 exported_to_java = {
     "register_compile_and_archive_actions_for_j2objc": compilation_support.register_compile_and_archive_actions_for_j2objc,
     "proto_common_compile": proto_common_do_not_use.compile,
+    "proto_common_experimental_should_generate_code": proto_common_do_not_use.experimental_should_generate_code,
+    "proto_common_experimental_filter_sources": proto_common_do_not_use.experimental_filter_sources,
     "link_multi_arch_static_library": linking_support.link_multi_arch_static_library,
 }
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
index a04d15a..c3625b3 100644
--- a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
@@ -167,11 +167,84 @@
         resource_set = resource_set,
     )
 
+_BAZEL_TOOLS_PREFIX = "external/bazel_tools/"
+
+def _experimental_filter_sources(proto_library_target, proto_lang_toolchain_info):
+    """Splits the direct sources of a proto_library by whether the toolchain provides them.
+
+    Args:
+      proto_library_target: (Target) The proto_library whose direct sources are split.
+      proto_lang_toolchain_info: (ProtoLangToolchainInfo) Supplies `provided_proto_sources`.
+
+    Returns:
+      (list[File], list[File]) The included files, then the excluded (provided) ones.
+    """
+    proto_info = proto_library_target[_builtins.toplevel.ProtoInfo]
+    if not proto_info.direct_sources:
+        return [], []
+
+    # Paths of the provided protos, used as a set. Protos bundled with the Bazel tools
+    # repository have exec paths under external/bazel_tools/; that prefix is dropped because
+    # the protos in user repositories do not carry it.
+    provided_paths = {}
+    for provided in proto_lang_toolchain_info.provided_proto_sources:
+        path = provided.original_source_file().path
+        if path.startswith(_BAZEL_TOOLS_PREFIX):
+            path = path[len(_BAZEL_TOOLS_PREFIX):]
+        provided_paths[path] = None
+
+    # Partition the direct sources, keeping their original order within each list.
+    files = [src.original_source_file() for src in proto_info.direct_proto_sources()]
+    included = [f for f in files if f.path not in provided_paths]
+    excluded = [f for f in files if f.path in provided_paths]
+    return included, excluded
+
+def _experimental_should_generate_code(
+        proto_library_target,
+        proto_lang_toolchain_info,
+        rule_name):
+    """Decides whether code has to be generated for the given proto_library.
+
+    No code is generated when the toolchain already provides every direct
+    source to the language through its runtime dependency, or when the
+    proto_library has no direct sources at all.
+
+    A proto_library mixing sources that need generated code with sources that
+    do not is rejected with a failure.
+
+    Args:
+      proto_library_target:
+        (Target) The proto_library to generate the sources for.
+        Obtained as the `target` parameter from an aspect's implementation.
+      proto_lang_toolchain_info:
+        (ProtoLangToolchainInfo) The proto lang toolchain info.
+        Obtained from a `proto_lang_toolchain` target or constructed ad-hoc.
+      rule_name: (str) Name of the rule, quoted in the failure message.
+
+    Returns:
+      (bool) True when the code should be generated.
+    """
+    included, excluded = _experimental_filter_sources(proto_library_target, proto_lang_toolchain_info)
+
+    # A mix of both kinds cannot be served by one proto_library; ask for a split.
+    if included and excluded:
+        label = proto_library_target.label
+        excluded_paths = ", ".join([f.short_path for f in excluded])
+        included_paths = ", ".join([f.short_path for f in included])
+        fail(("The 'srcs' attribute of '%s' contains protos for which '%s' " +
+              "shouldn't generate code (%s), in addition to protos for which it should (%s).\n" +
+              "Separate '%s' into 2 proto_library rules.") %
+             (label, rule_name, excluded_paths, included_paths, label))
+
+    return bool(included)
+
 proto_common = struct(
     create_proto_compile_action = _create_proto_compile_action,
 )
 
 proto_common_do_not_use = struct(
     compile = _compile,
+    experimental_should_generate_code = _experimental_should_generate_code,
+    experimental_filter_sources = _experimental_filter_sources,
     ProtoLangToolchainInfo = _builtins.internal.ProtoLangToolchainInfo,
 )
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD b/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
index b831767..6ad4c0f 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
@@ -19,6 +19,8 @@
         "//src/main/java/com/google/devtools/build/lib/actions:localhost_capacity",
         "//src/main/java/com/google/devtools/build/lib/analysis:analysis_cluster",
         "//src/main/java/com/google/devtools/build/lib/analysis:configured_target",
+        "//src/main/java/com/google/devtools/build/lib/cmdline",
+        "//src/main/java/com/google/devtools/build/lib/packages",
         "//src/main/java/com/google/devtools/build/lib/util:os",
         "//src/test/java/com/google/devtools/build/lib/actions/util",
         "//src/test/java/com/google/devtools/build/lib/analysis/util",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
index 2f5bd52..e853d39 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
@@ -22,6 +22,10 @@
 import com.google.devtools.build.lib.analysis.ConfiguredTarget;
 import com.google.devtools.build.lib.analysis.actions.SpawnAction;
 import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
+import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.packages.StarlarkInfo;
+import com.google.devtools.build.lib.packages.StarlarkProvider;
+import com.google.devtools.build.lib.packages.StarlarkProviderIdentifier;
 import com.google.devtools.build.lib.packages.util.MockProtoSupport;
 import com.google.devtools.build.lib.testutil.TestConstants;
 import com.google.devtools.build.lib.util.OS;
@@ -39,6 +43,11 @@
   private static final Correspondence<String, String> MATCHES_REGEX =
       Correspondence.from((a, b) -> Pattern.matches(b, a), "matches");
 
+  private static final StarlarkProviderIdentifier boolProviderId =
+      StarlarkProviderIdentifier.forKey(
+          new StarlarkProvider.Key(
+              Label.parseAbsoluteUnchecked("//foo:should_generate.bzl"), "BoolProvider"));
+
   @Before
   public final void setup() throws Exception {
     MockProtoSupport.setupWorkspace(scratch);
@@ -53,6 +62,8 @@
         "cc_library(name = 'runtime', srcs = ['runtime.cc'])",
         "filegroup(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])",
         "filegroup(name = 'any', srcs = ['any.proto'])",
+        "filegroup(name = 'something', srcs = ['something.proto'])",
+        "proto_library(name = 'mixed', srcs = [':descriptors', ':something'])",
         "proto_library(name = 'denied', srcs = [':descriptors', ':any'])");
     scratch.file(
         "foo/BUILD",
@@ -115,6 +126,21 @@
         "     'use_resource_set': attr.bool(),",
         "     'progress_message': attr.string(),",
         "  })");
+
+    scratch.file(
+        "foo/should_generate.bzl",
+        "BoolProvider = provider()",
+        "def _impl(ctx):",
+        "  result = proto_common_do_not_use.experimental_should_generate_code(",
+        "    ctx.attr.proto_dep,",
+        "    ctx.attr.toolchain[proto_common_do_not_use.ProtoLangToolchainInfo],",
+        "    'MyRule')",
+        "  return [BoolProvider(value = result)]",
+        "should_generate_rule = rule(_impl,",
+        "  attrs = {",
+        "     'proto_dep': attr.label(),",
+        "     'toolchain': attr.label(default = '//foo:toolchain'),",
+        "  })");
   }
 
   /** Verifies basic usage of <code>proto_common.generate_code</code>. */
@@ -489,4 +515,55 @@
     assertThat(spawnAction.getMnemonic()).isEqualTo("MyMnemonic");
     assertThat(spawnAction.getProgressMessage()).isEqualTo("My //bar:simple");
   }
+
+  /** Verifies <code>experimental_should_generate_code</code> is True for srcs = ['A.proto']. */
+  @Test
+  public void shouldGenerateCode_basic() throws Exception {
+    scratch.file(
+        "bar/BUILD",
+        TestConstants.LOAD_PROTO_LIBRARY,
+        "load('//foo:should_generate.bzl', 'should_generate_rule')",
+        "proto_library(name = 'proto', srcs = ['A.proto'])",
+        "should_generate_rule(name = 'simple', proto_dep = ':proto')");
+
+    // BoolProvider.value carries the result computed by should_generate_rule.
+    StarlarkInfo info = (StarlarkInfo) getConfiguredTarget("//bar:simple").get(boolProviderId);
+    Boolean shouldGenerate = info.getValue("value", Boolean.class);
+    assertThat(shouldGenerate).isTrue();
+  }
+
+  /** Verifies <code>experimental_should_generate_code</code> is False for the 'denied' target. */
+  @Test
+  public void shouldGenerateCode_dontGenerate() throws Exception {
+    scratch.file(
+        "bar/BUILD",
+        TestConstants.LOAD_PROTO_LIBRARY,
+        "load('//foo:should_generate.bzl', 'should_generate_rule')",
+        "should_generate_rule(name = 'simple', proto_dep = '//third_party/x:denied')");
+
+    // BoolProvider.value carries the result computed by should_generate_rule.
+    StarlarkInfo info = (StarlarkInfo) getConfiguredTarget("//bar:simple").get(boolProviderId);
+    Boolean shouldGenerate = info.getValue("value", Boolean.class);
+    assertThat(shouldGenerate).isFalse();
+  }
+
+  /** Verifies <code>proto_common.should_generate_code</code> call. */
+  @Test
+  public void shouldGenerateCode_mixed() throws Exception {
+    scratch.file(
+        "bar/BUILD",
+        TestConstants.LOAD_PROTO_LIBRARY,
+        "load('//foo:should_generate.bzl', 'should_generate_rule')",
+        "should_generate_rule(name = 'simple', proto_dep = '//third_party/x:mixed')");
+
+    reporter.removeHandler(failFastHandler);
+    getConfiguredTarget("//bar:simple");
+
+    assertContainsEvent(
+        "The 'srcs' attribute of '//third_party/x:mixed' contains protos for which 'MyRule'"
+            + " shouldn't generate code (third_party/x/metadata.proto,"
+            + " third_party/x/descriptor.proto), in addition to protos for which it should"
+            + " (third_party/x/something.proto).\n"
+            + "Separate '//third_party/x:mixed' into 2 proto_library rules.");
+  }
 }