Implement proto_common.experimental_should_generate_code.
Design doc: https://docs.google.com/document/d/1dY_jfRvnH8SjRXGIfg8av-vquyWsvIZydXJOywvaR1A/edit
PiperOrigin-RevId: 440109298
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java
index aa75f71..f920330 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoCommon.java
@@ -38,11 +38,13 @@
import javax.annotation.Nullable;
import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Module;
+import net.starlark.java.eval.Sequence;
import net.starlark.java.eval.Starlark;
import net.starlark.java.eval.StarlarkCallable;
import net.starlark.java.eval.StarlarkFunction;
import net.starlark.java.eval.StarlarkList;
import net.starlark.java.eval.StarlarkThread;
+import net.starlark.java.eval.Tuple;
/** Utility functions for proto_library and proto aspect implementations. */
public class ProtoCommon {
@@ -206,4 +208,51 @@
/* plugin_output */ pluginOutput == null ? Starlark.NONE : pluginOutput),
ImmutableMap.of("experimental_progress_message", progressMessage));
}
+
+ public static boolean shouldGenerateCode(
+ RuleContext ruleContext,
+ ConfiguredTarget protoTarget,
+ ProtoLangToolchainProvider protoLangToolchainInfo,
+ String ruleName)
+ throws RuleErrorException, InterruptedException {
+ StarlarkFunction shouldGenerateCode =
+ (StarlarkFunction)
+ ruleContext.getStarlarkDefinedBuiltin("proto_common_experimental_should_generate_code");
+ ruleContext.initStarlarkRuleContext();
+ return (Boolean)
+ ruleContext.callStarlarkOrThrowRuleError(
+ shouldGenerateCode,
+ ImmutableList.of(
+ /* proto_library_target */ protoTarget,
+ /* proto_lang_toolchain_info */ protoLangToolchainInfo,
+ /* rule_name */ ruleName),
+ ImmutableMap.of());
+ }
+
+ public static Sequence<Artifact> filterSources(
+ RuleContext ruleContext,
+ ConfiguredTarget protoTarget,
+ ProtoLangToolchainProvider protoLangToolchainInfo)
+ throws RuleErrorException, InterruptedException {
+ StarlarkFunction filterSources =
+ (StarlarkFunction)
+ ruleContext.getStarlarkDefinedBuiltin("proto_common_experimental_filter_sources");
+ ruleContext.initStarlarkRuleContext();
+ try {
+ return Sequence.cast(
+ ((Tuple)
+ ruleContext.callStarlarkOrThrowRuleError(
+ filterSources,
+ ImmutableList.of(
+ /* proto_library_target */ protoTarget,
+ /* proto_lang_toolchain_info */ protoLangToolchainInfo),
+ ImmutableMap.of()))
+ .get(0),
+ Artifact.class,
+ "included");
+ } catch (EvalException e) {
+
+ throw new RuleErrorException(e.getMessageWithStack());
+ }
+ }
}
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java
index 3e090ac..d32c1bb 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java
@@ -113,6 +113,13 @@
return directProtoSources;
}
+ @Override
+ public ImmutableList<ProtoSource> getDirectProtoSourcesForStarlark(StarlarkThread thread)
+ throws EvalException {
+ ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
+ return directSources;
+ }
+
/**
* The source root of the current library.
*
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
index fe3ed09..dbe62d9 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoLangToolchainProvider.java
@@ -81,6 +81,10 @@
* Returns a list of {@link ProtoSource}s that are already provided by the protobuf runtime (i.e.
* for which {@code <lang>_proto_library} should not generate bindings.
*/
+ @StarlarkMethod(
+ name = "provided_proto_sources",
+ doc = "Proto sources provided by the toolchain.",
+ structField = true)
public abstract ImmutableList<ProtoSource> providedProtoSources();
@StarlarkMethod(name = "proto_compiler", doc = "Proto compiler.", structField = true)
diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java
index 8b93079..27f13b4 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java
@@ -56,7 +56,12 @@
}
/** Returns the original source file. Only for forbidding protos! */
- @Deprecated
+ @StarlarkMethod(name = "original_source_file", documented = false, useStarlarkThread = true)
+ public Artifact getOriginalSourceFileForStarlark(StarlarkThread thread) throws EvalException {
+ ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
+ return originalSourceFile;
+ }
+
Artifact getOriginalSourceFile() {
return originalSourceFile;
}
diff --git a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java
index d031958..a4e3fd5 100644
--- a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java
+++ b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java
@@ -64,6 +64,9 @@
structField = true)
ImmutableList<FileT> getDirectProtoSources();
+ @StarlarkMethod(name = "direct_proto_sources", documented = false, useStarlarkThread = true)
+ ImmutableList<?> getDirectProtoSourcesForStarlark(StarlarkThread thread) throws EvalException;
+
@StarlarkMethod(
name = "check_deps_sources",
doc =
diff --git a/src/main/starlark/builtins_bzl/common/exports.bzl b/src/main/starlark/builtins_bzl/common/exports.bzl
index c48c512..966a597 100755
--- a/src/main/starlark/builtins_bzl/common/exports.bzl
+++ b/src/main/starlark/builtins_bzl/common/exports.bzl
@@ -60,5 +60,7 @@
exported_to_java = {
"register_compile_and_archive_actions_for_j2objc": compilation_support.register_compile_and_archive_actions_for_j2objc,
"proto_common_compile": proto_common_do_not_use.compile,
+ "proto_common_experimental_should_generate_code": proto_common_do_not_use.experimental_should_generate_code,
+ "proto_common_experimental_filter_sources": proto_common_do_not_use.experimental_filter_sources,
"link_multi_arch_static_library": linking_support.link_multi_arch_static_library,
}
diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
index a04d15a..c3625b3 100644
--- a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
+++ b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
@@ -167,11 +167,84 @@
resource_set = resource_set,
)
+_BAZEL_TOOLS_PREFIX = "external/bazel_tools/"
+
+def _experimental_filter_sources(proto_library_target, proto_lang_toolchain_info):
+ proto_info = proto_library_target[_builtins.toplevel.ProtoInfo]
+ if not proto_info.direct_sources:
+ return [], []
+
+ # Collect a set of provided protos
+ provided_proto_sources = proto_lang_toolchain_info.provided_proto_sources
+ provided_paths = {}
+ for src in provided_proto_sources:
+ path = src.original_source_file().path
+
+ # For listed protos bundled with the Bazel tools repository, their exec paths start
+ # with external/bazel_tools/. This prefix needs to be removed first, because the protos in
+ # user repositories will not have that prefix.
+ if path.startswith(_BAZEL_TOOLS_PREFIX):
+ provided_paths[path[len(_BAZEL_TOOLS_PREFIX):]] = None
+ else:
+ provided_paths[path] = None
+
+ # Filter proto files
+ proto_files = [src.original_source_file() for src in proto_info.direct_proto_sources()]
+ excluded = []
+ included = []
+ for proto_file in proto_files:
+ if proto_file.path in provided_paths:
+ excluded.append(proto_file)
+ else:
+ included.append(proto_file)
+ return included, excluded
+
+def _experimental_should_generate_code(
+ proto_library_target,
+ proto_lang_toolchain_info,
+ rule_name):
+ """Checks if the code should be generated for the given proto_library.
+
+ The code shouldn't be generated only when the toolchain already provides it
+ to the language through its runtime dependency.
+
+ It fails when the proto_library contains mixed proto files, that should and
+ shouldn't generate code.
+
+ Args:
+ proto_library_target:
+ (Target) The proto_library to generate the sources for.
+ Obtained as the `target` parameter from an aspect's implementation.
+ proto_lang_toolchain_info:
+ (ProtoLangToolchainInfo) The proto lang toolchain info.
+ Obtained from a `proto_lang_toolchain` target or constructed ad-hoc.
+ rule_name: (str) Name of the rule used in the failure message.
+
+ Returns:
+ (bool) True when the code should be generated.
+ """
+ included, excluded = _experimental_filter_sources(proto_library_target, proto_lang_toolchain_info)
+
+ if included and excluded:
+ fail(("The 'srcs' attribute of '%s' contains protos for which '%s' " +
+ "shouldn't generate code (%s), in addition to protos for which it should (%s).\n" +
+ "Separate '%s' into 2 proto_library rules.") % (
+ proto_library_target.label,
+ rule_name,
+ ", ".join([f.short_path for f in excluded]),
+ ", ".join([f.short_path for f in included]),
+ proto_library_target.label,
+ ))
+
+ return bool(included)
+
proto_common = struct(
create_proto_compile_action = _create_proto_compile_action,
)
proto_common_do_not_use = struct(
compile = _compile,
+ experimental_should_generate_code = _experimental_should_generate_code,
+ experimental_filter_sources = _experimental_filter_sources,
ProtoLangToolchainInfo = _builtins.internal.ProtoLangToolchainInfo,
)
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD b/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
index b831767..6ad4c0f 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
@@ -19,6 +19,8 @@
"//src/main/java/com/google/devtools/build/lib/actions:localhost_capacity",
"//src/main/java/com/google/devtools/build/lib/analysis:analysis_cluster",
"//src/main/java/com/google/devtools/build/lib/analysis:configured_target",
+ "//src/main/java/com/google/devtools/build/lib/cmdline",
+ "//src/main/java/com/google/devtools/build/lib/packages",
"//src/main/java/com/google/devtools/build/lib/util:os",
"//src/test/java/com/google/devtools/build/lib/actions/util",
"//src/test/java/com/google/devtools/build/lib/analysis/util",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
index 2f5bd52..e853d39 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommonTest.java
@@ -22,6 +22,10 @@
import com.google.devtools.build.lib.analysis.ConfiguredTarget;
import com.google.devtools.build.lib.analysis.actions.SpawnAction;
import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
+import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.packages.StarlarkInfo;
+import com.google.devtools.build.lib.packages.StarlarkProvider;
+import com.google.devtools.build.lib.packages.StarlarkProviderIdentifier;
import com.google.devtools.build.lib.packages.util.MockProtoSupport;
import com.google.devtools.build.lib.testutil.TestConstants;
import com.google.devtools.build.lib.util.OS;
@@ -39,6 +43,11 @@
private static final Correspondence<String, String> MATCHES_REGEX =
Correspondence.from((a, b) -> Pattern.matches(b, a), "matches");
+ private static final StarlarkProviderIdentifier boolProviderId =
+ StarlarkProviderIdentifier.forKey(
+ new StarlarkProvider.Key(
+ Label.parseAbsoluteUnchecked("//foo:should_generate.bzl"), "BoolProvider"));
+
@Before
public final void setup() throws Exception {
MockProtoSupport.setupWorkspace(scratch);
@@ -53,6 +62,8 @@
"cc_library(name = 'runtime', srcs = ['runtime.cc'])",
"filegroup(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])",
"filegroup(name = 'any', srcs = ['any.proto'])",
+ "filegroup(name = 'something', srcs = ['something.proto'])",
+ "proto_library(name = 'mixed', srcs = [':descriptors', ':something'])",
"proto_library(name = 'denied', srcs = [':descriptors', ':any'])");
scratch.file(
"foo/BUILD",
@@ -115,6 +126,21 @@
" 'use_resource_set': attr.bool(),",
" 'progress_message': attr.string(),",
" })");
+
+ scratch.file(
+ "foo/should_generate.bzl",
+ "BoolProvider = provider()",
+ "def _impl(ctx):",
+ " result = proto_common_do_not_use.experimental_should_generate_code(",
+ " ctx.attr.proto_dep,",
+ " ctx.attr.toolchain[proto_common_do_not_use.ProtoLangToolchainInfo],",
+ " 'MyRule')",
+ " return [BoolProvider(value = result)]",
+ "should_generate_rule = rule(_impl,",
+ " attrs = {",
+ " 'proto_dep': attr.label(),",
+ " 'toolchain': attr.label(default = '//foo:toolchain'),",
+ " })");
}
/** Verifies basic usage of <code>proto_common.generate_code</code>. */
@@ -489,4 +515,55 @@
assertThat(spawnAction.getMnemonic()).isEqualTo("MyMnemonic");
assertThat(spawnAction.getProgressMessage()).isEqualTo("My //bar:simple");
}
+
+ /** Verifies <code>proto_common.should_generate_code</code> call. */
+ @Test
+ public void shouldGenerateCode_basic() throws Exception {
+ scratch.file(
+ "bar/BUILD",
+ TestConstants.LOAD_PROTO_LIBRARY,
+ "load('//foo:should_generate.bzl', 'should_generate_rule')",
+ "proto_library(name = 'proto', srcs = ['A.proto'])",
+ "should_generate_rule(name = 'simple', proto_dep = ':proto')");
+
+ ConfiguredTarget target = getConfiguredTarget("//bar:simple");
+
+ StarlarkInfo boolProvider = (StarlarkInfo) target.get(boolProviderId);
+ assertThat(boolProvider.getValue("value", Boolean.class)).isTrue();
+ }
+
+ /** Verifies <code>proto_common.should_generate_code</code> call. */
+ @Test
+ public void shouldGenerateCode_dontGenerate() throws Exception {
+ scratch.file(
+ "bar/BUILD",
+ TestConstants.LOAD_PROTO_LIBRARY,
+ "load('//foo:should_generate.bzl', 'should_generate_rule')",
+ "should_generate_rule(name = 'simple', proto_dep = '//third_party/x:denied')");
+
+ ConfiguredTarget target = getConfiguredTarget("//bar:simple");
+
+ StarlarkInfo boolProvider = (StarlarkInfo) target.get(boolProviderId);
+ assertThat(boolProvider.getValue("value", Boolean.class)).isFalse();
+ }
+
+ /** Verifies <code>proto_common.should_generate_code</code> call. */
+ @Test
+ public void shouldGenerateCode_mixed() throws Exception {
+ scratch.file(
+ "bar/BUILD",
+ TestConstants.LOAD_PROTO_LIBRARY,
+ "load('//foo:should_generate.bzl', 'should_generate_rule')",
+ "should_generate_rule(name = 'simple', proto_dep = '//third_party/x:mixed')");
+
+ reporter.removeHandler(failFastHandler);
+ getConfiguredTarget("//bar:simple");
+
+ assertContainsEvent(
+ "The 'srcs' attribute of '//third_party/x:mixed' contains protos for which 'MyRule'"
+ + " shouldn't generate code (third_party/x/metadata.proto,"
+ + " third_party/x/descriptor.proto), in addition to protos for which it should"
+ + " (third_party/x/something.proto).\n"
+ + "Separate '//third_party/x:mixed' into 2 proto_library rules.");
+ }
}