Custom init callback for providers
Implements https://github.com/bazelbuild/bazel/issues/14392
RELNOTES: provider() has a new parameter: init, a callback for performing
pre-processing and validation of field values. Iff this parameter is set,
provider() returns a tuple of 2 elements: the usual provider symbol (which,
when called, invokes init) and a raw constructor (which bypasses init).
PiperOrigin-RevId: 422702617
diff --git a/site/docs/skylark/rules.md b/site/docs/skylark/rules.md
index 86040e9..daedd16 100644
--- a/site/docs/skylark/rules.md
+++ b/site/docs/skylark/rules.md
@@ -570,6 +570,89 @@
]
```
+##### Custom initialization of providers
+
+It's possible to guard the instantiation of a provider with custom
+preprocessing and validation logic. This can be used to ensure that all
+provider instances obey certain invariants, or to give users a cleaner API for
+obtaining an instance.
+
+This is done by passing an `init` callback to the
+[`provider`](lib/globals.html#provider) function. If this callback is given, the
+return type of `provider()` changes to be a tuple of two values: the provider
+symbol that is the ordinary return value when `init` is not used, and a "raw
+constructor".
+
+In this case, when the provider symbol is called, instead of directly returning
+a new instance, it will forward the arguments along to the `init` callback. The
+callback's return value must be a dict mapping field names (strings) to values;
+this is used to initialize the fields of the new instance. Note that the
+callback may have any signature, and if the arguments do not match the signature
+an error is reported as if the callback were invoked directly.
+
+The raw constructor, by contrast, will bypass the `init` callback.
+
+The following example uses `init` to preprocess and validate its arguments:
+
+```python
+# //pkg:exampleinfo.bzl
+
+_core_headers = [...] # private constant representing standard library files
+
+# It's possible to define an init accepting positional arguments, but
+# keyword-only arguments are preferred.
+def _exampleinfo_init(*, files_to_link, headers = None, allow_empty_files_to_link = False):
+ if not files_to_link and not allow_empty_files_to_link:
+ fail("files_to_link may not be empty")
+ all_headers = depset(_core_headers, transitive = headers)
+ return {'files_to_link': files_to_link, 'headers': all_headers}
+
+ExampleInfo, _new_exampleinfo = provider(
+ ...
+ init = _exampleinfo_init)
+
+export ExampleInfo
+```
+
+A rule implementation may then instantiate the provider as follows:
+
+```python
+ ExampleInfo(
+ files_to_link=my_files_to_link, # may not be empty
+ headers = my_headers, # will automatically include the core headers
+ )
+```
+
+The raw constructor can be used to define alternative public factory functions
+that do not go through the `init` logic. For example, in exampleinfo.bzl we
+could define:
+
+```python
+def make_barebones_exampleinfo(headers):
+ """Returns an ExampleInfo with no files_to_link and only the specified headers."""
+ return _new_exampleinfo(files_to_link = depset(), headers = all_headers)
+```
+
+Typically, the raw constructor is bound to a variable whose name begins with an
+underscore (`_new_exampleinfo` above), so that user code cannot load it and
+generate arbitrary provider instances.
+
+Another use for `init` is to simply prevent the user from calling the provider
+symbol altogether, and force them to use a factory function instead:
+
+```python
+def _exampleinfo_init_banned(*args, **kwargs):
+ fail("Do not call ExampleInfo(). Use make_exampleinfo() instead.")
+
+ExampleInfo, _new_exampleinfo = provider(
+ ...
+ init = _exampleinfo_init_banned)
+
+def make_exampleinfo(...):
+ ...
+ return _new_exampleinfo(...)
+```
+
<a name="executable-rules"></a>
## Executable rules and test rules
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java b/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
index 9c38cce..07a3a45 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
@@ -77,7 +77,6 @@
import com.google.devtools.build.lib.packages.Package.NameConflictException;
import com.google.devtools.build.lib.packages.PackageFactory.PackageContext;
import com.google.devtools.build.lib.packages.PredicateWithMessage;
-import com.google.devtools.build.lib.packages.Provider;
import com.google.devtools.build.lib.packages.RuleClass;
import com.google.devtools.build.lib.packages.RuleClass.Builder.RuleClassType;
import com.google.devtools.build.lib.packages.RuleClass.ToolchainTransitionMode;
@@ -272,14 +271,25 @@
}
@Override
- public Provider provider(String doc, Object fields, StarlarkThread thread) throws EvalException {
+ public Object provider(String doc, Object fields, Object init, StarlarkThread thread)
+ throws EvalException {
StarlarkProvider.Builder builder = StarlarkProvider.builder(thread.getCallerLocation());
if (fields instanceof Sequence) {
builder.setSchema(Sequence.cast(fields, String.class, "fields"));
} else if (fields instanceof Dict) {
builder.setSchema(Dict.cast(fields, String.class, String.class, "fields").keySet());
}
- return builder.build();
+ if (init == Starlark.NONE) {
+ return builder.build();
+ } else {
+ if (init instanceof StarlarkCallable) {
+ builder.setInit((StarlarkCallable) init);
+ } else {
+ throw Starlark.errorf("got %s for init, want callable value", Starlark.type(init));
+ }
+ StarlarkProvider provider = builder.build();
+ return Tuple.of(provider, provider.createRawConstructor());
+ }
}
// TODO(bazel-team): implement attribute copy and other rule properties
diff --git a/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java b/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java
index 7a973de..70041f2 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java
@@ -126,9 +126,9 @@
List<String> unexpected = unexpectedKeys(schema, table, n);
if (unexpected != null) {
throw Starlark.errorf(
- "unexpected keyword%s %s in call to instantiate provider %s",
+ "got unexpected field%s '%s' in call to instantiate provider %s",
unexpected.size() > 1 ? "s" : "",
- Joiner.on(", ").join(unexpected),
+ Joiner.on("', '").join(unexpected),
provider.getPrintableName());
}
}
diff --git a/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java b/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
index e0ad043..1dbe2d7 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
@@ -14,14 +14,17 @@
package com.google.devtools.build.lib.packages;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.events.EventHandler;
import com.google.devtools.build.lib.util.Fingerprint;
import java.util.Collection;
+import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
+import net.starlark.java.eval.Dict;
import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Printer;
import net.starlark.java.eval.Starlark;
@@ -39,6 +42,11 @@
* providers can have any set of fields on them, whereas instances of schemaful providers may have
* only the fields that are named in the schema.
*
+ * <p>{@code StarlarkProvider} may have a custom initializer callback, which might perform
+ * preprocessing or validation of field values. This callback (if defined) is automatically invoked
+ * when the provider is called. To create instances of the provider without calling the initializer
+ * callback, use the callable returned by {@code StarlarkProvider#createRawConstructor}.
+ *
* <p>Exporting a {@code StarlarkProvider} creates a key that is used to uniquely identify it.
* Usually a provider is exported by calling {@link #export}, but a test may wish to just create a
* pre-exported provider directly. Exported providers use only their key for {@link #equals} and
@@ -53,6 +61,11 @@
// as it lets us verify table ⊆ schema in O(n) time without temporaries.
@Nullable private final ImmutableList<String> schema;
+ // Optional custom initializer callback. If present, it is invoked with the same positional and
+ // keyword arguments as were passed to the provider constructor. The return value must be a
+ // Starlark dict mapping field names (string keys) to their values.
+ @Nullable private final StarlarkCallable init;
+
/** Null iff this provider has not yet been exported. Mutated by {@link export}. */
@Nullable private Key key;
@@ -78,6 +91,8 @@
@Nullable private ImmutableList<String> schema;
+ @Nullable private StarlarkCallable init;
+
@Nullable private Key key;
private Builder(Location location) {
@@ -93,6 +108,24 @@
return this;
}
+ /**
+ * Sets the custom initializer callback for instances of the provider built by this builder.
+ *
+ * <p>The initializer callback will be automatically invoked when the provider is called. To
+ * bypass the custom initializer callback, use the callable returned by {@link
+ * StarlarkProvider#createRawConstructor}.
+ *
+ * @param init A callback that accepts the arguments passed to the provider constructor, and
+ * which returns a dict mapping field names to their values. The resulting provider instance
+ * is created as though the dict were passed as **kwargs to the raw constructor. In
+ * particular, for a schemaful provider, the dict may not contain keys not listed in the
+ * schema.
+ */
+ public Builder setInit(StarlarkCallable init) {
+ this.init = init;
+ return this;
+ }
+
/** Sets the provider built by this builder to be exported with the given key. */
public Builder setExported(Key key) {
this.key = key;
@@ -101,25 +134,56 @@
/** Builds a StarlarkProvider. */
public StarlarkProvider build() {
- return new StarlarkProvider(location, schema, key);
+ return new StarlarkProvider(location, schema, init, key);
}
}
/**
* Constructs the provider.
*
- * <p>If {@code key} is null, the provider is unexported. If {@code schema} is null, the provider
- * is schemaless.
+ * <p>If {@code schema} is null, the provider is schemaless. If {@code init} is null, no custom
+ * initializer callback will be used (i.e., calling the provider is the same as simply calling the
+ * raw constructor). If {@code key} is null, the provider is unexported.
*/
private StarlarkProvider(
- Location location, @Nullable ImmutableList<String> schema, @Nullable Key key) {
+ Location location,
+ @Nullable ImmutableList<String> schema,
+ @Nullable StarlarkCallable init,
+ @Nullable Key key) {
this.location = location;
this.schema = schema;
+ this.init = init;
this.key = key;
}
+ private static Object[] toNamedArgs(Object value, String descriptionForError)
+ throws EvalException {
+ Dict<String, Object> kwargs = Dict.cast(value, String.class, Object.class, descriptionForError);
+ Object[] named = new Object[2 * kwargs.size()];
+ int i = 0;
+ for (Map.Entry<String, Object> e : kwargs.entrySet()) {
+ named[i++] = e.getKey();
+ named[i++] = e.getValue();
+ }
+ return named;
+ }
+
@Override
public Object fastcall(StarlarkThread thread, Object[] positional, Object[] named)
+ throws InterruptedException, EvalException {
+ if (init == null) {
+ return fastcallRawConstructor(thread, positional, named);
+ }
+
+ Object initResult = Starlark.fastcall(thread, init, positional, named);
+ return StarlarkInfo.createFromNamedArgs(
+ this,
+ toNamedArgs(initResult, "return value of provider init()"),
+ schema,
+ thread.getCallerLocation());
+ }
+
+ private Object fastcallRawConstructor(StarlarkThread thread, Object[] positional, Object[] named)
throws EvalException {
if (positional.length > 0) {
throw Starlark.errorf("%s: unexpected positional arguments", getName());
@@ -127,6 +191,45 @@
return StarlarkInfo.createFromNamedArgs(this, named, schema, thread.getCallerLocation());
}
+ private static final class RawConstructor implements StarlarkCallable {
+ private final StarlarkProvider provider;
+
+ private RawConstructor(StarlarkProvider provider) {
+ this.provider = provider;
+ }
+
+ @Override
+ public Object fastcall(StarlarkThread thread, Object[] positional, Object[] named)
+ throws EvalException {
+ return provider.fastcallRawConstructor(thread, positional, named);
+ }
+
+ @Override
+ public String getName() {
+ StringBuilder name = new StringBuilder("<raw constructor");
+ if (provider.isExported()) {
+ name.append(" for ").append(provider.getName());
+ }
+ name.append(">");
+ return name.toString();
+ }
+
+ @Override
+ public Location getLocation() {
+ return provider.location;
+ }
+ }
+
+ public StarlarkCallable createRawConstructor() {
+ return new RawConstructor(this);
+ }
+
+ @Nullable
+ @VisibleForTesting
+ public StarlarkCallable getInit() {
+ return init;
+ }
+
@Override
public Location getLocation() {
return location;
diff --git a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
index 7365caa..408892b 100644
--- a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
+++ b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
@@ -19,7 +19,6 @@
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.packages.semantics.BuildLanguageOptions;
import com.google.devtools.build.lib.starlarkbuildapi.StarlarkConfigApi.BuildSettingApi;
-import com.google.devtools.build.lib.starlarkbuildapi.core.ProviderApi;
import net.starlark.java.annot.Param;
import net.starlark.java.annot.ParamType;
import net.starlark.java.annot.StarlarkMethod;
@@ -55,13 +54,27 @@
@StarlarkMethod(
name = "provider",
doc =
- "Creates a declared provider 'constructor'. The return value of this "
- + "function can be used to create \"struct-like\" values. Example:<br>"
- + "<pre class=\"language-python\">data = provider()\n"
- + "d = data(x = 2, y = 3)\n"
- + "print(d.x + d.y) # prints 5</pre>"
- + "<p>See <a href='../rules.$DOC_EXT#providers'>Rules (Providers)</a> for a "
- + "comprehensive guide on how to use providers.",
+ "Defines a provider symbol. The provider may be instantiated by calling it, or used"
+ + " directly as a key for retrieving an instance of that provider from a target."
+ + " Example:<br><pre class=\"language-python\">" //
+ + "MyInfo = provider()\n"
+ + "...\n"
+ + "def _my_library_impl(ctx):\n"
+ + " ...\n"
+ + " my_info = MyInfo(x = 2, y = 3)\n"
+ + " # my_info.x == 2\n"
+ + " # my_info.y == 3\n"
+ + " ..." //
+ + "</pre><p>See <a href='../rules.$DOC_EXT#providers'>Rules (Providers)</a> for a "
+ + "comprehensive guide on how to use providers." //
+ + "<p>Returns a <a href='Provider.html#Provider'><code>Provider</code></a> callable "
+ + "value if <code>init</code> is not specified." //
+ + "<p>If <code>init</code> is specified, returns a tuple of 2 elements: a <a"
+ + " href='Provider.html#Provider'><code>Provider</code></a> callable value and a"
+ + " <em>raw constructor</em> callable value. See <a"
+ + " href='../rules.html#custom-initialization-of-providers'>Rules (Custom"
+ + " initialization of custom providers)</a> and the discussion of the"
+ + " <code>init</code> parameter below for details.",
parameters = {
@Param(
name = "doc",
@@ -87,10 +100,85 @@
},
named = true,
positional = false,
- defaultValue = "None")
+ defaultValue = "None"),
+ @Param(
+ name = "init",
+ doc =
+ "An optional callback for preprocessing and validating the provider's field values"
+ + " during instantiation. If <code>init</code> is specified,"
+ + " <code>provider()</code> returns a tuple of 2 elements: the normal provider"
+ + " symbol and a <em>raw constructor</em>." //
+ + "<p>A precise description follows; see <a"
+ + " href='../rules.html#custom-initialization-of-providers'>Rules (Custom"
+ + " initialization of providers)</a> for an intuitive discussion and use"
+ + " cases." //
+ + "<p>Let <code>P</code> be the provider symbol created by calling"
+ + " <code>provider()</code>. Conceptually, an instance of <code>P</code> is"
+ + " generated by calling a default constructor function <code>c(*args,"
+ + " **kwargs)</code>, which does the following:" //
+ + "<ul>" //
+ + "<li>If <code>args</code> is non-empty, an error occurs.</li>" //
+ + "<li>If the <code>fields</code> parameter was specified when"
+ + " <code>provider()</code> was called, and if <code>kwargs</code> contains any"
+ + " key that was not listed in <code>fields</code>, an error occurs.</li>" //
+ + "<li>Otherwise, <code>c</code> returns a new instance that has, for each"
+ + " <code>k: v</code> entry in <code>kwargs</code>, a field named"
+ + " <code>k</code> with value <code>v</code>." //
+ + "</ul>" //
+ + "In the case where an <code>init</code> callback is <em>not</em> given, a"
+ + " call to the symbol <code>P</code> itself acts as a call to the default"
+ + " constructor function <code>c</code>; in other words, <code>P(*args,"
+ + " **kwargs)</code> returns <code>c(*args, **kwargs)</code>. For example," //
+ + "<pre class=\"language-python\">" //
+ + "MyInfo = provider()\n" //
+ + "m = MyInfo(foo = 1)" //
+ + "</pre>" //
+ + "will straightforwardly make it so that <code>m</code> is a"
+ + " <code>MyInfo</code> instance with <code>m.foo == 1</code>." //
+ + "<p>But in the case where <code>init</code> is specified, the call"
+ + " <code>P(*args, **kwargs)</code> will perform the following steps"
+ + " instead:" //
+ + "<ol>" //
+ + "<li>The callback is invoked as <code>init(*args, **kwargs)</code>, that is,"
+ + " with the exact same positional and keyword arguments as were passed to"
+ + " <code>P</code>.</li>" //
+ + "<li>The return value of <code>init</code> is expected to be a dictionary,"
+ + " <code>d</code>, whose keys are field name strings. If it is not, an error"
+ + " occurs.</li>" //
+ + "<li>A new instance of <code>P</code> is generated as if by calling the"
+ + " default constructor with <code>d</code>'s entries as keyword arguments, as"
+ + " in <code>c(**d)</code>.</li>" //
+ + "</ol>" //
+ + "<p>NB: the above steps imply that an error occurs if <code>*args</code> or"
+ + " <code>**kwargs</code> does not match <code>init</code>'s signature, or the"
+ + " evaluation of <code>init</code>'s body fails (perhaps intentionally via a"
+ + " call to <a href=\"#fail\"><code>fail()</code></a>), or if the return value"
+ + " of <code>init</code> is not a dictionary with the expected schema." //
+ + "<p>In this way, the <code>init</code> callback generalizes normal provider"
+ + " construction by allowing positional arguments and arbitrary logic for"
+ + " preprocessing and validation. It does <em>not</em> enable circumventing the"
+ + " list of allowed <code>fields</code>." //
+ + "<p>When <code>init</code> is specified, the return value of"
+ + " <code>provider()</code> becomes a tuple <code>(P, r)</code>, where"
+ + " <code>r</code> is the <em>raw constructor</em>. In fact, the behavior of"
+ + " <code>r</code> is exactly that of the default constructor function"
+ + " <code>c</code> discussed above. Typically, <code>r</code> is bound to a"
+ + " variable whose name is prefixed with an underscore, so that only the"
+ + " current .bzl file has direct access to it:" //
+ + "<pre class=\"language-python\">" //
+ + "MyInfo, _new_myinfo = provider(init = ...)" //
+ + "</pre>",
+ named = true,
+ allowedTypes = {
+ @ParamType(type = StarlarkCallable.class),
+ @ParamType(type = NoneType.class),
+ },
+ positional = false,
+ defaultValue = "None"),
},
useStarlarkThread = true)
- ProviderApi provider(String doc, Object fields, StarlarkThread thread) throws EvalException;
+ Object provider(String doc, Object fields, Object init, StarlarkThread thread)
+ throws EvalException;
@StarlarkMethod(
name = "rule",
diff --git a/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java b/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java
index 4a95d2d..f6e97f3 100644
--- a/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java
+++ b/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java
@@ -22,7 +22,6 @@
import com.google.devtools.build.lib.starlarkbuildapi.FileApi;
import com.google.devtools.build.lib.starlarkbuildapi.StarlarkAspectApi;
import com.google.devtools.build.lib.starlarkbuildapi.StarlarkRuleFunctionsApi;
-import com.google.devtools.build.lib.starlarkbuildapi.core.ProviderApi;
import com.google.devtools.build.skydoc.rendering.AspectInfoWrapper;
import com.google.devtools.build.skydoc.rendering.ProviderInfoWrapper;
import com.google.devtools.build.skydoc.rendering.RuleInfoWrapper;
@@ -45,6 +44,7 @@
import net.starlark.java.eval.StarlarkCallable;
import net.starlark.java.eval.StarlarkFunction;
import net.starlark.java.eval.StarlarkThread;
+import net.starlark.java.eval.Tuple;
import net.starlark.java.syntax.Location;
/**
@@ -84,7 +84,7 @@
}
@Override
- public ProviderApi provider(String doc, Object fields, StarlarkThread thread)
+ public Object provider(String doc, Object fields, Object init, StarlarkThread thread)
throws EvalException {
FakeProviderApi fakeProvider = new FakeProviderApi(null);
// Field documentation will be output preserving the order in which the fields are listed.
@@ -102,7 +102,11 @@
// fields is NONE, so there is no field information to add.
}
providerInfoList.add(forProviderInfo(fakeProvider, doc, providerFieldInfos.build()));
- return fakeProvider;
+ if (init == Starlark.NONE) {
+ return fakeProvider;
+ } else {
+ return Tuple.of(fakeProvider, FakeDeepStructure.create("<raw constructor>"));
+ }
}
/** Constructor for ProviderFieldInfo. */
diff --git a/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java b/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
index f176528..b88c0ee 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
@@ -16,16 +16,22 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyNoInteractions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.testing.EqualsTester;
import com.google.devtools.build.lib.cmdline.Label;
+import net.starlark.java.eval.Dict;
+import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Mutability;
import net.starlark.java.eval.Starlark;
+import net.starlark.java.eval.StarlarkCallable;
import net.starlark.java.eval.StarlarkInt;
import net.starlark.java.eval.StarlarkSemantics;
import net.starlark.java.eval.StarlarkThread;
+import net.starlark.java.eval.Tuple;
import net.starlark.java.syntax.Location;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -36,28 +42,28 @@
public final class StarlarkProviderTest {
@Test
- public void unexportedProvider_Accessors() {
+ public void unexportedProvider_accessors() {
StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).build();
assertThat(provider.isExported()).isFalse();
assertThat(provider.getName()).isEqualTo("<no name>");
assertThat(provider.getPrintableName()).isEqualTo("<no name>");
+ assertThat(provider.createRawConstructor().getName()).isEqualTo("<raw constructor>");
assertThat(provider.getErrorMessageForUnknownField("foo"))
.isEqualTo("'struct' value has no field or method 'foo'");
assertThat(provider.isImmutable()).isFalse();
assertThat(Starlark.repr(provider)).isEqualTo("<provider>");
- assertThrows(
- IllegalStateException.class,
- () -> provider.getKey());
+ assertThrows(IllegalStateException.class, provider::getKey);
}
@Test
- public void exportedProvider_Accessors() throws Exception {
+ public void exportedProvider_accessors() throws Exception {
StarlarkProvider.Key key =
new StarlarkProvider.Key(Label.parseAbsolute("//foo:bar.bzl", ImmutableMap.of()), "prov");
StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setExported(key).build();
assertThat(provider.isExported()).isTrue();
assertThat(provider.getName()).isEqualTo("prov");
assertThat(provider.getPrintableName()).isEqualTo("prov");
+ assertThat(provider.createRawConstructor().getName()).isEqualTo("<raw constructor for prov>");
assertThat(provider.getErrorMessageForUnknownField("foo"))
.isEqualTo("'prov' value has no field or method 'foo'");
assertThat(provider.isImmutable()).isTrue();
@@ -66,30 +72,166 @@
}
@Test
- public void schemalessProvider_Instantiation() throws Exception {
+ public void basicInstantiation() throws Exception {
StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).build();
- StarlarkInfo info = instantiateWithA1B2C3(provider);
- assertHasExactlyValuesA1B2C3(info);
+ StarlarkInfo infoFromNormalConstructor = instantiateWithA1B2C3(provider);
+ assertHasExactlyValuesA1B2C3(infoFromNormalConstructor);
+ assertThat(infoFromNormalConstructor.getProvider()).isEqualTo(provider);
+
+ StarlarkInfo infoFromRawConstructor = instantiateWithA1B2C3(provider.createRawConstructor());
+ assertHasExactlyValuesA1B2C3(infoFromRawConstructor);
+ assertThat(infoFromRawConstructor.getProvider()).isEqualTo(provider);
+
+ assertThat(infoFromNormalConstructor).isEqualTo(infoFromRawConstructor);
}
@Test
- public void schemafulProvider_Instantiation() throws Exception {
+ public void instantiationWithInit() throws Exception {
+ StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setInit(initBC).build();
+ StarlarkInfo infoFromNormalConstructor = instantiateWithA1(provider);
+ assertHasExactlyValuesA1B2C3(infoFromNormalConstructor);
+ assertThat(infoFromNormalConstructor.getProvider()).isEqualTo(provider);
+ }
+
+ @Test
+ public void instantiationWithInitSignatureMismatch_fails() throws Exception {
+ StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setInit(initBC).build();
+ EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+ assertThat(e).hasMessageThat().contains("expected a single `a` argument");
+ }
+
+ @Test
+ public void instantiationWithInitReturnTypeMismatch_fails() throws Exception {
+ StarlarkCallable initWithInvalidReturnType =
+ new StarlarkCallable() {
+ @Override
+ public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwargs) {
+ return "invalid";
+ }
+
+ @Override
+ public String getName() {
+ return "initWithInvalidReturnType";
+ }
+
+ @Override
+ public Location getLocation() {
+ return Location.BUILTIN;
+ }
+ };
+
+ StarlarkProvider provider =
+ StarlarkProvider.builder(Location.BUILTIN).setInit(initWithInvalidReturnType).build();
+ EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+ assertThat(e)
+ .hasMessageThat()
+ .contains("got string for 'return value of provider init()', want dict");
+ }
+
+ @Test
+ public void instantiationWithFailingInit_fails() throws Exception {
+ StarlarkCallable failingInit =
+ new StarlarkCallable() {
+ @Override
+ public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwargs)
+ throws EvalException {
+ throw Starlark.errorf("failingInit fails");
+ }
+
+ @Override
+ public String getName() {
+ return "failingInit";
+ }
+
+ @Override
+ public Location getLocation() {
+ return Location.BUILTIN;
+ }
+ };
+
+ StarlarkProvider provider =
+ StarlarkProvider.builder(Location.BUILTIN).setInit(failingInit).build();
+ EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+ assertThat(e).hasMessageThat().contains("failingInit fails");
+ }
+
+ @Test
+ public void rawConstructorBypassesInit() throws Exception {
+ StarlarkCallable init = mock(StarlarkCallable.class, "init");
+ StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setInit(init).build();
+ StarlarkInfo infoFromRawConstructor = instantiateWithA1B2C3(provider.createRawConstructor());
+ assertHasExactlyValuesA1B2C3(infoFromRawConstructor);
+ assertThat(infoFromRawConstructor.getProvider()).isEqualTo(provider);
+ verifyNoInteractions(init);
+ }
+
+ @Test
+ public void basicInstantiationWithSchemaWithSomeFieldsUnset() throws Exception {
StarlarkProvider provider =
StarlarkProvider.builder(Location.BUILTIN)
.setSchema(ImmutableList.of("a", "b", "c"))
.build();
- StarlarkInfo info = instantiateWithA1B2C3(provider);
- assertHasExactlyValuesA1B2C3(info);
+ StarlarkInfo infoFromNormalConstructor = instantiateWithA1(provider);
+ assertHasExactlyValuesA1(infoFromNormalConstructor);
+ StarlarkInfo infoFromRawConstructor = instantiateWithA1(provider.createRawConstructor());
+ assertHasExactlyValuesA1(infoFromRawConstructor);
}
@Test
- public void schemalessProvider_GetFields() throws Exception {
+ public void basicInstantiationWithSchemaWithAllFieldsSet() throws Exception {
+ StarlarkProvider provider =
+ StarlarkProvider.builder(Location.BUILTIN)
+ .setSchema(ImmutableList.of("a", "b", "c"))
+ .build();
+ StarlarkInfo infoFromNormalConstructor = instantiateWithA1B2C3(provider);
+ assertHasExactlyValuesA1B2C3(infoFromNormalConstructor);
+ StarlarkInfo infoFromRawConstructor = instantiateWithA1B2C3(provider.createRawConstructor());
+ assertHasExactlyValuesA1B2C3(infoFromRawConstructor);
+ }
+
+ @Test
+ public void schemaDisallowsUnexpectedFields() throws Exception {
+ StarlarkProvider provider =
+ StarlarkProvider.builder(Location.BUILTIN).setSchema(ImmutableList.of("a", "b")).build();
+ EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+ assertThat(e)
+ .hasMessageThat()
+ .contains("got unexpected field 'c' in call to instantiate provider");
+ }
+
+ @Test
+ public void schemaEnforcedOnRawConstructor() throws Exception {
+ StarlarkProvider provider =
+ StarlarkProvider.builder(Location.BUILTIN).setSchema(ImmutableList.of("a", "b")).build();
+ EvalException e =
+ assertThrows(
+ EvalException.class, () -> instantiateWithA1B2C3(provider.createRawConstructor()));
+ assertThat(e)
+ .hasMessageThat()
+ .contains("got unexpected field 'c' in call to instantiate provider");
+ }
+
+ @Test
+ public void schemaEnforcedOnInit() throws Exception {
+ StarlarkProvider provider =
+ StarlarkProvider.builder(Location.BUILTIN)
+ .setSchema(ImmutableList.of("a", "b"))
+ .setInit(initBC)
+ .build();
+ EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1(provider));
+ assertThat(e)
+ .hasMessageThat()
+ .contains("got unexpected field 'c' in call to instantiate provider");
+ }
+
+ @Test
+ public void schemalessProvider_getFields() throws Exception {
StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).build();
assertThat(provider.getFields()).isNull();
}
@Test
- public void schemafulProvider_GetFields() throws Exception {
+ public void schemafulProvider_getFields() throws Exception {
StarlarkProvider provider =
StarlarkProvider.builder(Location.BUILTIN)
.setSchema(ImmutableList.of("a", "b", "c"))
@@ -128,15 +270,61 @@
new EqualsTester()
.addEqualityGroup(provFooA1, provFooA2)
.addEqualityGroup(provFooB)
- .addEqualityGroup(provBarA, provBarA) // reflexive equality (exported)
+ .addEqualityGroup(provBarA, provBarA) // reflexive equality (exported)
.addEqualityGroup(provBarB)
- .addEqualityGroup(provUnexported1, provUnexported1) // reflexive equality (unexported)
+ .addEqualityGroup(provUnexported1, provUnexported1) // reflexive equality (unexported)
.addEqualityGroup(provUnexported2)
.testEquals();
}
+ /** Custom init equivalent to `def initBC(a): return {a:a, b:a*2, c:a*3}` */
+ private static final StarlarkCallable initBC =
+ new StarlarkCallable() {
+ @Override
+ public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwargs)
+ throws EvalException {
+ if (!args.isEmpty()) {
+ throw Starlark.errorf("unexpected positional arguments");
+ }
+ if (kwargs.size() != 1 || !kwargs.containsKey("a")) {
+ throw Starlark.errorf("expected a single `a` argument");
+ }
+ StarlarkInt a = (StarlarkInt) kwargs.get("a");
+ return Dict.builder()
+ .put("a", a)
+ .put("b", StarlarkInt.of(a.toIntUnchecked() * 2))
+ .put("c", StarlarkInt.of(a.toIntUnchecked() * 3))
+ .build(Mutability.IMMUTABLE);
+ }
+
+ @Override
+ public String getName() {
+ return "initBC";
+ }
+
+ @Override
+ public Location getLocation() {
+ return Location.BUILTIN;
+ }
+ };
+
+ /** Instantiates a {@link StarlarkInfo} with fields a=1 (and nothing else). */
+ private static StarlarkInfo instantiateWithA1(StarlarkCallable provider) throws Exception {
+ try (Mutability mu = Mutability.create()) {
+ StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+ Object result =
+ Starlark.call(
+ thread,
+ provider,
+ /*args=*/ ImmutableList.of(),
+ /*kwargs=*/ ImmutableMap.of("a", StarlarkInt.of(1)));
+ assertThat(result).isInstanceOf(StarlarkInfo.class);
+ return (StarlarkInfo) result;
+ }
+ }
+
/** Instantiates a {@link StarlarkInfo} with fields a=1, b=2, c=3 (and nothing else). */
- private static StarlarkInfo instantiateWithA1B2C3(StarlarkProvider provider) throws Exception {
+ private static StarlarkInfo instantiateWithA1B2C3(StarlarkCallable provider) throws Exception {
try (Mutability mu = Mutability.create()) {
StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
Object result =
@@ -151,6 +339,12 @@
}
}
+ /** Asserts that a {@link StarlarkInfo} has field a=1 (and nothing else). */
+ private static void assertHasExactlyValuesA1(StarlarkInfo info) throws Exception {
+ assertThat(info.getFieldNames()).containsExactly("a");
+ assertThat(info.getValue("a")).isEqualTo(StarlarkInt.of(1));
+ }
+
/** Asserts that a {@link StarlarkInfo} has fields a=1, b=2, c=3 (and nothing else). */
private static void assertHasExactlyValuesA1B2C3(StarlarkInfo info) throws Exception {
assertThat(info.getFieldNames()).containsExactly("a", "b", "c");
diff --git a/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java b/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
index 5336000..5553db7 100644
--- a/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
@@ -43,6 +43,7 @@
import com.google.devtools.build.lib.packages.ExecGroup;
import com.google.devtools.build.lib.packages.ImplicitOutputsFunction;
import com.google.devtools.build.lib.packages.PredicateWithMessage;
+import com.google.devtools.build.lib.packages.Provider;
import com.google.devtools.build.lib.packages.RequiredProviders;
import com.google.devtools.build.lib.packages.Rule;
import com.google.devtools.build.lib.packages.RuleClass;
@@ -68,6 +69,7 @@
import net.starlark.java.eval.Module;
import net.starlark.java.eval.Mutability;
import net.starlark.java.eval.Starlark;
+import net.starlark.java.eval.StarlarkCallable;
import net.starlark.java.eval.StarlarkInt;
import net.starlark.java.eval.StarlarkList;
import net.starlark.java.eval.Structure;
@@ -1684,6 +1686,117 @@
}
@Test
+ public void declaredProvidersWithInit() throws Exception {
+ evalAndExport(
+ ev,
+ "def _data_init(x, y = 'abc'):", //
+ " return {'x': x, 'y': y}",
+ "data, _new_data = provider(init = _data_init)",
+ "d1 = data(x = 1) # normal provider constructor",
+ "d1_x = d1.x",
+ "d1_y = d1.y",
+ "d2 = data(1, 'def') # normal provider constructor invoked with positional arguments",
+ "d2_x = d2.x",
+ "d2_y = d2.y",
+ "d3 = _new_data(x = 2, y = 'xyz') # raw constructor",
+ "d3_x = d3.x",
+ "d3_y = d3.y");
+
+ assertThat(ev.lookup("d1_x")).isEqualTo(StarlarkInt.of(1));
+ assertThat(ev.lookup("d1_y")).isEqualTo("abc");
+ assertThat(ev.lookup("d2_x")).isEqualTo(StarlarkInt.of(1));
+ assertThat(ev.lookup("d2_y")).isEqualTo("def");
+ assertThat(ev.lookup("d3_x")).isEqualTo(StarlarkInt.of(2));
+ assertThat(ev.lookup("d3_y")).isEqualTo("xyz");
+ StarlarkProvider dataConstructor = (StarlarkProvider) ev.lookup("data");
+ StarlarkCallable rawConstructor = (StarlarkCallable) ev.lookup("_new_data");
+ assertThat(rawConstructor).isNotInstanceOf(Provider.class);
+ assertThat(dataConstructor.getInit().getName()).isEqualTo("_data_init");
+
+ StructImpl data1 = (StructImpl) ev.lookup("d1");
+ StructImpl data2 = (StructImpl) ev.lookup("d2");
+ StructImpl data3 = (StructImpl) ev.lookup("d3");
+ assertThat(data1.getProvider()).isEqualTo(dataConstructor);
+ assertThat(data2.getProvider()).isEqualTo(dataConstructor);
+ assertThat(data3.getProvider()).isEqualTo(dataConstructor);
+ assertThat(dataConstructor.isExported()).isTrue();
+ assertThat(dataConstructor.getPrintableName()).isEqualTo("data");
+ assertThat(dataConstructor.getKey()).isEqualTo(new StarlarkProvider.Key(FAKE_LABEL, "data"));
+ }
+
+ @Test
+ public void declaredProvidersWithFailingInit_rawConstructorSucceeds() throws Exception {
+ evalAndExport(
+ ev,
+ "def _data_failing_init(x):", //
+ " fail('_data_failing_init fails')",
+ "data, _new_data = provider(init = _data_failing_init)");
+
+ StarlarkProvider dataConstructor = (StarlarkProvider) ev.lookup("data");
+
+ evalAndExport(ev, "d = _new_data(x = 1) # raw constructor");
+ StructImpl data = (StructImpl) ev.lookup("d");
+ assertThat(data.getProvider()).isEqualTo(dataConstructor);
+ }
+
+ @Test
+ public void declaredProvidersWithFailingInit_normalConstructorFails() throws Exception {
+ evalAndExport(
+ ev,
+ "def _data_failing_init(x):", //
+ " fail('_data_failing_init fails')",
+ "data, _new_data = provider(init = _data_failing_init)");
+
+ ev.checkEvalErrorContains("_data_failing_init fails", "d = data(x = 1) # normal constructor");
+ assertThat(ev.lookup("d")).isNull();
+ }
+
+ @Test
+ public void declaredProvidersWithInitReturningInvalidType_normalConstructorFails()
+ throws Exception {
+ evalAndExport(
+ ev,
+ "def _data_invalid_init(x):", //
+ " return 'INVALID'",
+ "data, _new_data = provider(init = _data_invalid_init)");
+
+ ev.checkEvalErrorContains(
+ "got string for 'return value of provider init()', want dict",
+ "d = data(x = 1) # normal constructor");
+ assertThat(ev.lookup("d")).isNull();
+ }
+
+ @Test
+ public void declaredProvidersWithInitReturningInvalidDict_normalConstructorFails()
+ throws Exception {
+ evalAndExport(
+ ev,
+ "def _data_invalid_init(x):", //
+ " return {('x', 'x', 'x'): x}",
+ "data, _new_data = provider(init = _data_invalid_init)");
+
+ ev.checkEvalErrorContains(
+ "got dict<tuple, int> for 'return value of provider init()'",
+ "d = data(x = 1) # normal constructor");
+ assertThat(ev.lookup("d")).isNull();
+ }
+
+ @Test
+ public void declaredProvidersWithInitReturningUnexpectedFields_normalConstructorFails()
+ throws Exception {
+ evalAndExport(
+ ev,
+ "def _data_unexpected_fields_init(x):", //
+ " return {'x': x, 'y': x * 2}",
+ "data, _new_data = provider(fields = ['x'], init = _data_unexpected_fields_init)");
+
+ ev.checkEvalErrorContains(
+ "got unexpected field 'y' in call to instantiate provider data",
+ "d = data(x = 1) # normal constructor");
+ assertThat(ev.lookup("d")).isNull();
+ }
+
+ @Test
public void declaredProvidersConcatSuccess() throws Exception {
evalAndExport(
ev,
@@ -1703,6 +1816,27 @@
}
@Test
+ public void declaredProvidersWithInitConcatSuccess() throws Exception {
+ evalAndExport(
+ ev,
+ "def _data_init(x):",
+ " return {'x': x}",
+ "data, _new_data = provider(init = _data_init)",
+ "dx = data(x = 1) # normal constructor",
+ "dy = _new_data(y = 'abc') # raw constructor",
+ "dxy = dx + dy",
+ "x = dxy.x",
+ "y = dxy.y");
+ assertThat(ev.lookup("x")).isEqualTo(StarlarkInt.of(1));
+ assertThat(ev.lookup("y")).isEqualTo("abc");
+ StarlarkProvider dataConstructor = (StarlarkProvider) ev.lookup("data");
+ StructImpl dx = (StructImpl) ev.lookup("dx");
+ assertThat(dx.getProvider()).isEqualTo(dataConstructor);
+ StructImpl dy = (StructImpl) ev.lookup("dy");
+ assertThat(dy.getProvider()).isEqualTo(dataConstructor);
+ }
+
+ @Test
public void declaredProvidersConcatError() throws Exception {
evalAndExport(ev, "data1 = provider()", "data2 = provider()");
@@ -2098,7 +2232,7 @@
ev.setFailFast(false);
evalAndExport(ev, "p = provider(fields = ['x', 'y'])", "p1 = p(x = 1, y = 2, z = 3)");
MoreAsserts.assertContainsEvent(
- ev.getEventCollector(), "unexpected keyword z in call to instantiate provider p");
+ ev.getEventCollector(), "got unexpected field 'z' in call to instantiate provider p");
}
@Test
@@ -2109,7 +2243,8 @@
"p = provider(fields = [])", //
"p1 = p(x = 1, y = 2, z = 3)");
MoreAsserts.assertContainsEvent(
- ev.getEventCollector(), "unexpected keywords x, y, z in call to instantiate provider p");
+ ev.getEventCollector(),
+ "got unexpected fields 'x', 'y', 'z' in call to instantiate provider p");
}
@Test