Custom init callback for providers

Implements https://github.com/bazelbuild/bazel/issues/14392

RELNOTES: provider() has a new parameter: init, a callback for performing
pre-processing and validation of field values. Iff this parameter is set,
provider() returns a tuple of 2 elements: the usual provider symbol (which,
when called, invokes init) and a raw constructor (which bypasses init).
PiperOrigin-RevId: 422702617
diff --git a/site/docs/skylark/rules.md b/site/docs/skylark/rules.md
index 86040e9..daedd16 100644
--- a/site/docs/skylark/rules.md
+++ b/site/docs/skylark/rules.md
@@ -570,6 +570,89 @@
   ]
 ```
 
+##### Custom initialization of providers
+
+It's possible to guard the instantiation of a provider with custom
+preprocessing and validation logic. This can be used to ensure that all
+provider instances obey certain invariants, or to give users a cleaner API for
+obtaining an instance.
+
+This is done by passing an `init` callback to the
+[`provider`](lib/globals.html#provider) function. If this callback is given, the
+return type of `provider()` changes to be a tuple of two values: the provider
+symbol that is the ordinary return value when `init` is not used, and a "raw
+constructor".
+
+In this case, when the provider symbol is called, instead of directly returning
+a new instance, it will forward the arguments along to the `init` callback. The
+callback's return value must be a dict mapping field names (strings) to values;
+this is used to initialize the fields of the new instance. Note that the
+callback may have any signature, and if the arguments do not match the signature
+an error is reported as if the callback were invoked directly.
+
+The raw constructor, by contrast, will bypass the `init` callback.
+
+The following example uses `init` to preprocess and validate its arguments:
+
+```python
+# //pkg:exampleinfo.bzl
+
+_core_headers = [...]  # private constant representing standard library files
+
+# It's possible to define an init accepting positional arguments, but
+# keyword-only arguments are preferred.
+def _exampleinfo_init(*, files_to_link, headers = None, allow_empty_files_to_link = False):
+    if not files_to_link and not allow_empty_files_to_link:
+        fail("files_to_link may not be empty")
+    all_headers = depset(_core_headers, transitive = headers)
+    return {'files_to_link': files_to_link, 'headers': all_headers}
+
+ExampleInfo, _new_exampleinfo = provider(
+    ...
+    init = _exampleinfo_init)
+
+export ExampleInfo
+```
+
+A rule implementation may then instantiate the provider as follows:
+
+```python
+    ExampleInfo(
+        files_to_link=my_files_to_link,  # may not be empty
+        headers = my_headers,  # will automatically include the core headers
+    )
+```
+
+The raw constructor can be used to define alternative public factory functions
+that do not go through the `init` logic. For example, in exampleinfo.bzl we
+could define:
+
+```python
+def make_barebones_exampleinfo(headers):
+    """Returns an ExampleInfo with no files_to_link and only the specified headers."""
+    return _new_exampleinfo(files_to_link = depset(), headers = all_headers)
+```
+
+Typically, the raw constructor is bound to a variable whose name begins with an
+underscore (`_new_exampleinfo` above), so that user code cannot load it and
+generate arbitrary provider instances.
+
+Another use for `init` is to simply prevent the user from calling the provider
+symbol altogether, and force them to use a factory function instead:
+
+```python
+def _exampleinfo_init_banned(*args, **kwargs):
+    fail("Do not call ExampleInfo(). Use make_exampleinfo() instead.")
+
+ExampleInfo, _new_exampleinfo = provider(
+    ...
+    init = _exampleinfo_init_banned)
+
+def make_exampleinfo(...):
+    ...
+    return _new_exampleinfo(...)
+```
+
 <a name="executable-rules"></a>
 
 ## Executable rules and test rules
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java b/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
index 9c38cce..07a3a45 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/starlark/StarlarkRuleClassFunctions.java
@@ -77,7 +77,6 @@
 import com.google.devtools.build.lib.packages.Package.NameConflictException;
 import com.google.devtools.build.lib.packages.PackageFactory.PackageContext;
 import com.google.devtools.build.lib.packages.PredicateWithMessage;
-import com.google.devtools.build.lib.packages.Provider;
 import com.google.devtools.build.lib.packages.RuleClass;
 import com.google.devtools.build.lib.packages.RuleClass.Builder.RuleClassType;
 import com.google.devtools.build.lib.packages.RuleClass.ToolchainTransitionMode;
@@ -272,14 +271,25 @@
   }
 
   @Override
-  public Provider provider(String doc, Object fields, StarlarkThread thread) throws EvalException {
+  public Object provider(String doc, Object fields, Object init, StarlarkThread thread)
+      throws EvalException {
     StarlarkProvider.Builder builder = StarlarkProvider.builder(thread.getCallerLocation());
     if (fields instanceof Sequence) {
       builder.setSchema(Sequence.cast(fields, String.class, "fields"));
     } else if (fields instanceof Dict) {
       builder.setSchema(Dict.cast(fields, String.class, String.class, "fields").keySet());
     }
-    return builder.build();
+    if (init == Starlark.NONE) {
+      return builder.build();
+    } else {
+      if (init instanceof StarlarkCallable) {
+        builder.setInit((StarlarkCallable) init);
+      } else {
+        throw Starlark.errorf("got %s for init, want callable value", Starlark.type(init));
+      }
+      StarlarkProvider provider = builder.build();
+      return Tuple.of(provider, provider.createRawConstructor());
+    }
   }
 
   // TODO(bazel-team): implement attribute copy and other rule properties
diff --git a/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java b/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java
index 7a973de..70041f2 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/StarlarkInfo.java
@@ -126,9 +126,9 @@
       List<String> unexpected = unexpectedKeys(schema, table, n);
       if (unexpected != null) {
         throw Starlark.errorf(
-            "unexpected keyword%s %s in call to instantiate provider %s",
+            "got unexpected field%s '%s' in call to instantiate provider %s",
             unexpected.size() > 1 ? "s" : "",
-            Joiner.on(", ").join(unexpected),
+            Joiner.on("', '").join(unexpected),
             provider.getPrintableName());
       }
     }
diff --git a/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java b/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
index e0ad043..1dbe2d7 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/StarlarkProvider.java
@@ -14,14 +14,17 @@
 
 package com.google.devtools.build.lib.packages;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.util.Fingerprint;
 import java.util.Collection;
+import java.util.Map;
 import java.util.Objects;
 import javax.annotation.Nullable;
+import net.starlark.java.eval.Dict;
 import net.starlark.java.eval.EvalException;
 import net.starlark.java.eval.Printer;
 import net.starlark.java.eval.Starlark;
@@ -39,6 +42,11 @@
  * providers can have any set of fields on them, whereas instances of schemaful providers may have
  * only the fields that are named in the schema.
  *
+ * <p>{@code StarlarkProvider} may have a custom initializer callback, which might perform
+ * preprocessing or validation of field values. This callback (if defined) is automatically invoked
+ * when the provider is called. To create instances of the provider without calling the initializer
+ * callback, use the callable returned by {@code StarlarkProvider#createRawConstructor}.
+ *
  * <p>Exporting a {@code StarlarkProvider} creates a key that is used to uniquely identify it.
  * Usually a provider is exported by calling {@link #export}, but a test may wish to just create a
  * pre-exported provider directly. Exported providers use only their key for {@link #equals} and
@@ -53,6 +61,11 @@
   // as it lets us verify table ⊆ schema in O(n) time without temporaries.
   @Nullable private final ImmutableList<String> schema;
 
+  // Optional custom initializer callback. If present, it is invoked with the same positional and
+  // keyword arguments as were passed to the provider constructor. The return value must be a
+  // Starlark dict mapping field names (string keys) to their values.
+  @Nullable private final StarlarkCallable init;
+
   /** Null iff this provider has not yet been exported. Mutated by {@link export}. */
   @Nullable private Key key;
 
@@ -78,6 +91,8 @@
 
     @Nullable private ImmutableList<String> schema;
 
+    @Nullable private StarlarkCallable init;
+
     @Nullable private Key key;
 
     private Builder(Location location) {
@@ -93,6 +108,24 @@
       return this;
     }
 
+    /**
+     * Sets the custom initializer callback for instances of the provider built by this builder.
+     *
+     * <p>The initializer callback will be automatically invoked when the provider is called. To
+     * bypass the custom initializer callback, use the callable returned by {@link
+     * StarlarkProvider#createRawConstructor}.
+     *
+     * @param init A callback that accepts the arguments passed to the provider constructor, and
+     *     which returns a dict mapping field names to their values. The resulting provider instance
+     *     is created as though the dict were passed as **kwargs to the raw constructor. In
+     *     particular, for a schemaful provider, the dict may not contain keys not listed in the
+     *     schema.
+     */
+    public Builder setInit(StarlarkCallable init) {
+      this.init = init;
+      return this;
+    }
+
     /** Sets the provider built by this builder to be exported with the given key. */
     public Builder setExported(Key key) {
       this.key = key;
@@ -101,25 +134,56 @@
 
     /** Builds a StarlarkProvider. */
     public StarlarkProvider build() {
-      return new StarlarkProvider(location, schema, key);
+      return new StarlarkProvider(location, schema, init, key);
     }
   }
 
   /**
    * Constructs the provider.
    *
-   * <p>If {@code key} is null, the provider is unexported. If {@code schema} is null, the provider
-   * is schemaless.
+   * <p>If {@code schema} is null, the provider is schemaless. If {@code init} is null, no custom
+   * initializer callback will be used (i.e., calling the provider is the same as simply calling the
+   * raw constructor). If {@code key} is null, the provider is unexported.
    */
   private StarlarkProvider(
-      Location location, @Nullable ImmutableList<String> schema, @Nullable Key key) {
+      Location location,
+      @Nullable ImmutableList<String> schema,
+      @Nullable StarlarkCallable init,
+      @Nullable Key key) {
     this.location = location;
     this.schema = schema;
+    this.init = init;
     this.key = key;
   }
 
+  private static Object[] toNamedArgs(Object value, String descriptionForError)
+      throws EvalException {
+    Dict<String, Object> kwargs = Dict.cast(value, String.class, Object.class, descriptionForError);
+    Object[] named = new Object[2 * kwargs.size()];
+    int i = 0;
+    for (Map.Entry<String, Object> e : kwargs.entrySet()) {
+      named[i++] = e.getKey();
+      named[i++] = e.getValue();
+    }
+    return named;
+  }
+
   @Override
   public Object fastcall(StarlarkThread thread, Object[] positional, Object[] named)
+      throws InterruptedException, EvalException {
+    if (init == null) {
+      return fastcallRawConstructor(thread, positional, named);
+    }
+
+    Object initResult = Starlark.fastcall(thread, init, positional, named);
+    return StarlarkInfo.createFromNamedArgs(
+        this,
+        toNamedArgs(initResult, "return value of provider init()"),
+        schema,
+        thread.getCallerLocation());
+  }
+
+  private Object fastcallRawConstructor(StarlarkThread thread, Object[] positional, Object[] named)
       throws EvalException {
     if (positional.length > 0) {
       throw Starlark.errorf("%s: unexpected positional arguments", getName());
@@ -127,6 +191,45 @@
     return StarlarkInfo.createFromNamedArgs(this, named, schema, thread.getCallerLocation());
   }
 
+  private static final class RawConstructor implements StarlarkCallable {
+    private final StarlarkProvider provider;
+
+    private RawConstructor(StarlarkProvider provider) {
+      this.provider = provider;
+    }
+
+    @Override
+    public Object fastcall(StarlarkThread thread, Object[] positional, Object[] named)
+        throws EvalException {
+      return provider.fastcallRawConstructor(thread, positional, named);
+    }
+
+    @Override
+    public String getName() {
+      StringBuilder name = new StringBuilder("<raw constructor");
+      if (provider.isExported()) {
+        name.append(" for ").append(provider.getName());
+      }
+      name.append(">");
+      return name.toString();
+    }
+
+    @Override
+    public Location getLocation() {
+      return provider.location;
+    }
+  }
+
+  public StarlarkCallable createRawConstructor() {
+    return new RawConstructor(this);
+  }
+
+  @Nullable
+  @VisibleForTesting
+  public StarlarkCallable getInit() {
+    return init;
+  }
+
   @Override
   public Location getLocation() {
     return location;
diff --git a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
index 7365caa..408892b 100644
--- a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
+++ b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/StarlarkRuleFunctionsApi.java
@@ -19,7 +19,6 @@
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.packages.semantics.BuildLanguageOptions;
 import com.google.devtools.build.lib.starlarkbuildapi.StarlarkConfigApi.BuildSettingApi;
-import com.google.devtools.build.lib.starlarkbuildapi.core.ProviderApi;
 import net.starlark.java.annot.Param;
 import net.starlark.java.annot.ParamType;
 import net.starlark.java.annot.StarlarkMethod;
@@ -55,13 +54,27 @@
   @StarlarkMethod(
       name = "provider",
       doc =
-          "Creates a declared provider 'constructor'. The return value of this "
-              + "function can be used to create \"struct-like\" values. Example:<br>"
-              + "<pre class=\"language-python\">data = provider()\n"
-              + "d = data(x = 2, y = 3)\n"
-              + "print(d.x + d.y) # prints 5</pre>"
-              + "<p>See <a href='../rules.$DOC_EXT#providers'>Rules (Providers)</a> for a "
-              + "comprehensive guide on how to use providers.",
+          "Defines a provider symbol. The provider may be instantiated by calling it, or used"
+              + " directly as a key for retrieving an instance of that provider from a target."
+              + " Example:<br><pre class=\"language-python\">" //
+              + "MyInfo = provider()\n"
+              + "...\n"
+              + "def _my_library_impl(ctx):\n"
+              + "    ...\n"
+              + "    my_info = MyInfo(x = 2, y = 3)\n"
+              + "    # my_info.x == 2\n"
+              + "    # my_info.y == 3\n"
+              + "    ..." //
+              + "</pre><p>See <a href='../rules.$DOC_EXT#providers'>Rules (Providers)</a> for a "
+              + "comprehensive guide on how to use providers." //
+              + "<p>Returns a <a href='Provider.html#Provider'><code>Provider</code></a> callable "
+              + "value if <code>init</code> is not specified." //
+              + "<p>If <code>init</code> is specified, returns a tuple of 2 elements: a <a"
+              + " href='Provider.html#Provider'><code>Provider</code></a> callable value and a"
+              + " <em>raw constructor</em> callable value. See <a"
+              + " href='../rules.html#custom-initialization-of-providers'>Rules (Custom"
+              + " initialization of custom providers)</a> and the discussion of the"
+              + " <code>init</code> parameter below for details.",
       parameters = {
         @Param(
             name = "doc",
@@ -87,10 +100,85 @@
             },
             named = true,
             positional = false,
-            defaultValue = "None")
+            defaultValue = "None"),
+        @Param(
+            name = "init",
+            doc =
+                "An optional callback for preprocessing and validating the provider's field values"
+                    + " during instantiation. If <code>init</code> is specified,"
+                    + " <code>provider()</code> returns a tuple of 2 elements: the normal provider"
+                    + " symbol and a <em>raw constructor</em>." //
+                    + "<p>A precise description follows; see <a"
+                    + " href='../rules.html#custom-initialization-of-providers'>Rules (Custom"
+                    + " initialization of providers)</a> for an intuitive discussion and use"
+                    + " cases." //
+                    + "<p>Let <code>P</code> be the provider symbol created by calling"
+                    + " <code>provider()</code>. Conceptually, an instance of <code>P</code> is"
+                    + " generated by calling a default constructor function <code>c(*args,"
+                    + " **kwargs)</code>, which does the following:" //
+                    + "<ul>" //
+                    + "<li>If <code>args</code> is non-empty, an error occurs.</li>" //
+                    + "<li>If the <code>fields</code> parameter was specified when"
+                    + " <code>provider()</code> was called, and if <code>kwargs</code> contains any"
+                    + " key that was not listed in <code>fields</code>, an error occurs.</li>" //
+                    + "<li>Otherwise, <code>c</code> returns a new instance that has, for each"
+                    + " <code>k: v</code> entry in <code>kwargs</code>, a field named"
+                    + " <code>k</code> with value <code>v</code>." //
+                    + "</ul>" //
+                    + "In the case where an <code>init</code> callback is <em>not</em> given, a"
+                    + " call to the symbol <code>P</code> itself acts as a call to the default"
+                    + " constructor function <code>c</code>; in other words, <code>P(*args,"
+                    + " **kwargs)</code> returns <code>c(*args, **kwargs)</code>. For example," //
+                    + "<pre class=\"language-python\">" //
+                    + "MyInfo = provider()\n" //
+                    + "m = MyInfo(foo = 1)" //
+                    + "</pre>" //
+                    + "will straightforwardly make it so that <code>m</code> is a"
+                    + " <code>MyInfo</code> instance with <code>m.foo == 1</code>." //
+                    + "<p>But in the case where <code>init</code> is specified, the call"
+                    + " <code>P(*args, **kwargs)</code> will perform the following steps"
+                    + " instead:" //
+                    + "<ol>" //
+                    + "<li>The callback is invoked as <code>init(*args, **kwargs)</code>, that is,"
+                    + " with the exact same positional and keyword arguments as were passed to"
+                    + " <code>P</code>.</li>" //
+                    + "<li>The return value of <code>init</code> is expected to be a dictionary,"
+                    + " <code>d</code>, whose keys are field name strings. If it is not, an error"
+                    + " occurs.</li>" //
+                    + "<li>A new instance of <code>P</code> is generated as if by calling the"
+                    + " default constructor with <code>d</code>'s entries as keyword arguments, as"
+                    + " in <code>c(**d)</code>.</li>" //
+                    + "</ol>" //
+                    + "<p>NB: the above steps imply that an error occurs if <code>*args</code> or"
+                    + " <code>**kwargs</code> does not match <code>init</code>'s signature, or the"
+                    + " evaluation of <code>init</code>'s body fails (perhaps intentionally via a"
+                    + " call to <a href=\"#fail\"><code>fail()</code></a>), or if the return value"
+                    + " of <code>init</code> is not a dictionary with the expected schema." //
+                    + "<p>In this way, the <code>init</code> callback generalizes normal provider"
+                    + " construction by allowing positional arguments and arbitrary logic for"
+                    + " preprocessing and validation. It does <em>not</em> enable circumventing the"
+                    + " list of allowed <code>fields</code>." //
+                    + "<p>When <code>init</code> is specified, the return value of"
+                    + " <code>provider()</code> becomes a tuple <code>(P, r)</code>, where"
+                    + " <code>r</code> is the <em>raw constructor</em>. In fact, the behavior of"
+                    + " <code>r</code> is exactly that of the default constructor function"
+                    + " <code>c</code> discussed above. Typically, <code>r</code> is bound to a"
+                    + " variable whose name is prefixed with an underscore, so that only the"
+                    + " current .bzl file has direct access to it:" //
+                    + "<pre class=\"language-python\">" //
+                    + "MyInfo, _new_myinfo = provider(init = ...)" //
+                    + "</pre>",
+            named = true,
+            allowedTypes = {
+              @ParamType(type = StarlarkCallable.class),
+              @ParamType(type = NoneType.class),
+            },
+            positional = false,
+            defaultValue = "None"),
       },
       useStarlarkThread = true)
-  ProviderApi provider(String doc, Object fields, StarlarkThread thread) throws EvalException;
+  Object provider(String doc, Object fields, Object init, StarlarkThread thread)
+      throws EvalException;
 
   @StarlarkMethod(
       name = "rule",
diff --git a/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java b/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java
index 4a95d2d..f6e97f3 100644
--- a/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java
+++ b/src/main/java/com/google/devtools/build/skydoc/fakebuildapi/FakeStarlarkRuleFunctionsApi.java
@@ -22,7 +22,6 @@
 import com.google.devtools.build.lib.starlarkbuildapi.FileApi;
 import com.google.devtools.build.lib.starlarkbuildapi.StarlarkAspectApi;
 import com.google.devtools.build.lib.starlarkbuildapi.StarlarkRuleFunctionsApi;
-import com.google.devtools.build.lib.starlarkbuildapi.core.ProviderApi;
 import com.google.devtools.build.skydoc.rendering.AspectInfoWrapper;
 import com.google.devtools.build.skydoc.rendering.ProviderInfoWrapper;
 import com.google.devtools.build.skydoc.rendering.RuleInfoWrapper;
@@ -45,6 +44,7 @@
 import net.starlark.java.eval.StarlarkCallable;
 import net.starlark.java.eval.StarlarkFunction;
 import net.starlark.java.eval.StarlarkThread;
+import net.starlark.java.eval.Tuple;
 import net.starlark.java.syntax.Location;
 
 /**
@@ -84,7 +84,7 @@
   }
 
   @Override
-  public ProviderApi provider(String doc, Object fields, StarlarkThread thread)
+  public Object provider(String doc, Object fields, Object init, StarlarkThread thread)
       throws EvalException {
     FakeProviderApi fakeProvider = new FakeProviderApi(null);
     // Field documentation will be output preserving the order in which the fields are listed.
@@ -102,7 +102,11 @@
       // fields is NONE, so there is no field information to add.
     }
     providerInfoList.add(forProviderInfo(fakeProvider, doc, providerFieldInfos.build()));
-    return fakeProvider;
+    if (init == Starlark.NONE) {
+      return fakeProvider;
+    } else {
+      return Tuple.of(fakeProvider, FakeDeepStructure.create("<raw constructor>"));
+    }
   }
 
   /** Constructor for ProviderFieldInfo. */
diff --git a/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java b/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
index f176528..b88c0ee 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/StarlarkProviderTest.java
@@ -16,16 +16,22 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static org.junit.Assert.assertThrows;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyNoInteractions;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.testing.EqualsTester;
 import com.google.devtools.build.lib.cmdline.Label;
+import net.starlark.java.eval.Dict;
+import net.starlark.java.eval.EvalException;
 import net.starlark.java.eval.Mutability;
 import net.starlark.java.eval.Starlark;
+import net.starlark.java.eval.StarlarkCallable;
 import net.starlark.java.eval.StarlarkInt;
 import net.starlark.java.eval.StarlarkSemantics;
 import net.starlark.java.eval.StarlarkThread;
+import net.starlark.java.eval.Tuple;
 import net.starlark.java.syntax.Location;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -36,28 +42,28 @@
 public final class StarlarkProviderTest {
 
   @Test
-  public void unexportedProvider_Accessors() {
+  public void unexportedProvider_accessors() {
     StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).build();
     assertThat(provider.isExported()).isFalse();
     assertThat(provider.getName()).isEqualTo("<no name>");
     assertThat(provider.getPrintableName()).isEqualTo("<no name>");
+    assertThat(provider.createRawConstructor().getName()).isEqualTo("<raw constructor>");
     assertThat(provider.getErrorMessageForUnknownField("foo"))
         .isEqualTo("'struct' value has no field or method 'foo'");
     assertThat(provider.isImmutable()).isFalse();
     assertThat(Starlark.repr(provider)).isEqualTo("<provider>");
-    assertThrows(
-        IllegalStateException.class,
-        () -> provider.getKey());
+    assertThrows(IllegalStateException.class, provider::getKey);
   }
 
   @Test
-  public void exportedProvider_Accessors() throws Exception {
+  public void exportedProvider_accessors() throws Exception {
     StarlarkProvider.Key key =
         new StarlarkProvider.Key(Label.parseAbsolute("//foo:bar.bzl", ImmutableMap.of()), "prov");
     StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setExported(key).build();
     assertThat(provider.isExported()).isTrue();
     assertThat(provider.getName()).isEqualTo("prov");
     assertThat(provider.getPrintableName()).isEqualTo("prov");
+    assertThat(provider.createRawConstructor().getName()).isEqualTo("<raw constructor for prov>");
     assertThat(provider.getErrorMessageForUnknownField("foo"))
         .isEqualTo("'prov' value has no field or method 'foo'");
     assertThat(provider.isImmutable()).isTrue();
@@ -66,30 +72,166 @@
   }
 
   @Test
-  public void schemalessProvider_Instantiation() throws Exception {
+  public void basicInstantiation() throws Exception {
     StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).build();
-    StarlarkInfo info = instantiateWithA1B2C3(provider);
-    assertHasExactlyValuesA1B2C3(info);
+    StarlarkInfo infoFromNormalConstructor = instantiateWithA1B2C3(provider);
+    assertHasExactlyValuesA1B2C3(infoFromNormalConstructor);
+    assertThat(infoFromNormalConstructor.getProvider()).isEqualTo(provider);
+
+    StarlarkInfo infoFromRawConstructor = instantiateWithA1B2C3(provider.createRawConstructor());
+    assertHasExactlyValuesA1B2C3(infoFromRawConstructor);
+    assertThat(infoFromRawConstructor.getProvider()).isEqualTo(provider);
+
+    assertThat(infoFromNormalConstructor).isEqualTo(infoFromRawConstructor);
   }
 
   @Test
-  public void schemafulProvider_Instantiation() throws Exception {
+  public void instantiationWithInit() throws Exception {
+    StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setInit(initBC).build();
+    StarlarkInfo infoFromNormalConstructor = instantiateWithA1(provider);
+    assertHasExactlyValuesA1B2C3(infoFromNormalConstructor);
+    assertThat(infoFromNormalConstructor.getProvider()).isEqualTo(provider);
+  }
+
+  @Test
+  public void instantiationWithInitSignatureMismatch_fails() throws Exception {
+    StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setInit(initBC).build();
+    EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+    assertThat(e).hasMessageThat().contains("expected a single `a` argument");
+  }
+
+  @Test
+  public void instantiationWithInitReturnTypeMismatch_fails() throws Exception {
+    StarlarkCallable initWithInvalidReturnType =
+        new StarlarkCallable() {
+          @Override
+          public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwargs) {
+            return "invalid";
+          }
+
+          @Override
+          public String getName() {
+            return "initWithInvalidReturnType";
+          }
+
+          @Override
+          public Location getLocation() {
+            return Location.BUILTIN;
+          }
+        };
+
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN).setInit(initWithInvalidReturnType).build();
+    EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+    assertThat(e)
+        .hasMessageThat()
+        .contains("got string for 'return value of provider init()', want dict");
+  }
+
+  @Test
+  public void instantiationWithFailingInit_fails() throws Exception {
+    StarlarkCallable failingInit =
+        new StarlarkCallable() {
+          @Override
+          public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwargs)
+              throws EvalException {
+            throw Starlark.errorf("failingInit fails");
+          }
+
+          @Override
+          public String getName() {
+            return "failingInit";
+          }
+
+          @Override
+          public Location getLocation() {
+            return Location.BUILTIN;
+          }
+        };
+
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN).setInit(failingInit).build();
+    EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+    assertThat(e).hasMessageThat().contains("failingInit fails");
+  }
+
+  @Test
+  public void rawConstructorBypassesInit() throws Exception {
+    StarlarkCallable init = mock(StarlarkCallable.class, "init");
+    StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).setInit(init).build();
+    StarlarkInfo infoFromRawConstructor = instantiateWithA1B2C3(provider.createRawConstructor());
+    assertHasExactlyValuesA1B2C3(infoFromRawConstructor);
+    assertThat(infoFromRawConstructor.getProvider()).isEqualTo(provider);
+    verifyNoInteractions(init);
+  }
+
+  @Test
+  public void basicInstantiationWithSchemaWithSomeFieldsUnset() throws Exception {
     StarlarkProvider provider =
         StarlarkProvider.builder(Location.BUILTIN)
             .setSchema(ImmutableList.of("a", "b", "c"))
             .build();
-    StarlarkInfo info = instantiateWithA1B2C3(provider);
-    assertHasExactlyValuesA1B2C3(info);
+    StarlarkInfo infoFromNormalConstructor = instantiateWithA1(provider);
+    assertHasExactlyValuesA1(infoFromNormalConstructor);
+    StarlarkInfo infoFromRawConstructor = instantiateWithA1(provider.createRawConstructor());
+    assertHasExactlyValuesA1(infoFromRawConstructor);
   }
 
   @Test
-  public void schemalessProvider_GetFields() throws Exception {
+  public void basicInstantiationWithSchemaWithAllFieldsSet() throws Exception {
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN)
+            .setSchema(ImmutableList.of("a", "b", "c"))
+            .build();
+    StarlarkInfo infoFromNormalConstructor = instantiateWithA1B2C3(provider);
+    assertHasExactlyValuesA1B2C3(infoFromNormalConstructor);
+    StarlarkInfo infoFromRawConstructor = instantiateWithA1B2C3(provider.createRawConstructor());
+    assertHasExactlyValuesA1B2C3(infoFromRawConstructor);
+  }
+
+  @Test
+  public void schemaDisallowsUnexpectedFields() throws Exception {
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN).setSchema(ImmutableList.of("a", "b")).build();
+    EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1B2C3(provider));
+    assertThat(e)
+        .hasMessageThat()
+        .contains("got unexpected field 'c' in call to instantiate provider");
+  }
+
+  @Test
+  public void schemaEnforcedOnRawConstructor() throws Exception {
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN).setSchema(ImmutableList.of("a", "b")).build();
+    EvalException e =
+        assertThrows(
+            EvalException.class, () -> instantiateWithA1B2C3(provider.createRawConstructor()));
+    assertThat(e)
+        .hasMessageThat()
+        .contains("got unexpected field 'c' in call to instantiate provider");
+  }
+
+  @Test
+  public void schemaEnforcedOnInit() throws Exception {
+    StarlarkProvider provider =
+        StarlarkProvider.builder(Location.BUILTIN)
+            .setSchema(ImmutableList.of("a", "b"))
+            .setInit(initBC)
+            .build();
+    EvalException e = assertThrows(EvalException.class, () -> instantiateWithA1(provider));
+    assertThat(e)
+        .hasMessageThat()
+        .contains("got unexpected field 'c' in call to instantiate provider");
+  }
+
+  @Test
+  public void schemalessProvider_getFields() throws Exception {
     StarlarkProvider provider = StarlarkProvider.builder(Location.BUILTIN).build();
     assertThat(provider.getFields()).isNull();
   }
 
   @Test
-  public void schemafulProvider_GetFields() throws Exception {
+  public void schemafulProvider_getFields() throws Exception {
     StarlarkProvider provider =
         StarlarkProvider.builder(Location.BUILTIN)
             .setSchema(ImmutableList.of("a", "b", "c"))
@@ -128,15 +270,61 @@
     new EqualsTester()
         .addEqualityGroup(provFooA1, provFooA2)
         .addEqualityGroup(provFooB)
-        .addEqualityGroup(provBarA, provBarA)  // reflexive equality (exported)
+        .addEqualityGroup(provBarA, provBarA) // reflexive equality (exported)
         .addEqualityGroup(provBarB)
-        .addEqualityGroup(provUnexported1, provUnexported1)  // reflexive equality (unexported)
+        .addEqualityGroup(provUnexported1, provUnexported1) // reflexive equality (unexported)
         .addEqualityGroup(provUnexported2)
         .testEquals();
   }
 
+  /** Custom init equivalent to `def initBC(a): return {a:a, b:a*2, c:a*3}` */
+  private static final StarlarkCallable initBC =
+      new StarlarkCallable() {
+        @Override
+        public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwargs)
+            throws EvalException {
+          if (!args.isEmpty()) {
+            throw Starlark.errorf("unexpected positional arguments");
+          }
+          if (kwargs.size() != 1 || !kwargs.containsKey("a")) {
+            throw Starlark.errorf("expected a single `a` argument");
+          }
+          StarlarkInt a = (StarlarkInt) kwargs.get("a");
+          return Dict.builder()
+              .put("a", a)
+              .put("b", StarlarkInt.of(a.toIntUnchecked() * 2))
+              .put("c", StarlarkInt.of(a.toIntUnchecked() * 3))
+              .build(Mutability.IMMUTABLE);
+        }
+
+        @Override
+        public String getName() {
+          return "initBC";
+        }
+
+        @Override
+        public Location getLocation() {
+          return Location.BUILTIN;
+        }
+      };
+
+  /** Instantiates a {@link StarlarkInfo} with fields a=1 (and nothing else). */
+  private static StarlarkInfo instantiateWithA1(StarlarkCallable provider) throws Exception {
+    try (Mutability mu = Mutability.create()) {
+      StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
+      Object result =
+          Starlark.call(
+              thread,
+              provider,
+              /*args=*/ ImmutableList.of(),
+              /*kwargs=*/ ImmutableMap.of("a", StarlarkInt.of(1)));
+      assertThat(result).isInstanceOf(StarlarkInfo.class);
+      return (StarlarkInfo) result;
+    }
+  }
+
   /** Instantiates a {@link StarlarkInfo} with fields a=1, b=2, c=3 (and nothing else). */
-  private static StarlarkInfo instantiateWithA1B2C3(StarlarkProvider provider) throws Exception {
+  private static StarlarkInfo instantiateWithA1B2C3(StarlarkCallable provider) throws Exception {
     try (Mutability mu = Mutability.create()) {
       StarlarkThread thread = new StarlarkThread(mu, StarlarkSemantics.DEFAULT);
       Object result =
@@ -151,6 +339,12 @@
     }
   }
 
+  /** Asserts that a {@link StarlarkInfo} has field a=1 (and nothing else). */
+  private static void assertHasExactlyValuesA1(StarlarkInfo info) throws Exception {
+    assertThat(info.getFieldNames()).containsExactly("a");
+    assertThat(info.getValue("a")).isEqualTo(StarlarkInt.of(1));
+  }
+
   /** Asserts that a {@link StarlarkInfo} has fields a=1, b=2, c=3 (and nothing else). */
   private static void assertHasExactlyValuesA1B2C3(StarlarkInfo info) throws Exception {
     assertThat(info.getFieldNames()).containsExactly("a", "b", "c");
diff --git a/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java b/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
index 5336000..5553db7 100644
--- a/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/starlark/StarlarkRuleClassFunctionsTest.java
@@ -43,6 +43,7 @@
 import com.google.devtools.build.lib.packages.ExecGroup;
 import com.google.devtools.build.lib.packages.ImplicitOutputsFunction;
 import com.google.devtools.build.lib.packages.PredicateWithMessage;
+import com.google.devtools.build.lib.packages.Provider;
 import com.google.devtools.build.lib.packages.RequiredProviders;
 import com.google.devtools.build.lib.packages.Rule;
 import com.google.devtools.build.lib.packages.RuleClass;
@@ -68,6 +69,7 @@
 import net.starlark.java.eval.Module;
 import net.starlark.java.eval.Mutability;
 import net.starlark.java.eval.Starlark;
+import net.starlark.java.eval.StarlarkCallable;
 import net.starlark.java.eval.StarlarkInt;
 import net.starlark.java.eval.StarlarkList;
 import net.starlark.java.eval.Structure;
@@ -1684,6 +1686,117 @@
   }
 
   @Test
+  public void declaredProvidersWithInit() throws Exception {
+    evalAndExport(
+        ev,
+        "def _data_init(x, y = 'abc'):", //
+        "    return {'x': x, 'y': y}",
+        "data, _new_data = provider(init = _data_init)",
+        "d1 = data(x = 1)  # normal provider constructor",
+        "d1_x = d1.x",
+        "d1_y = d1.y",
+        "d2 = data(1, 'def')  # normal provider constructor invoked with positional arguments",
+        "d2_x = d2.x",
+        "d2_y = d2.y",
+        "d3 = _new_data(x = 2, y = 'xyz')  # raw constructor",
+        "d3_x = d3.x",
+        "d3_y = d3.y");
+
+    assertThat(ev.lookup("d1_x")).isEqualTo(StarlarkInt.of(1));
+    assertThat(ev.lookup("d1_y")).isEqualTo("abc");
+    assertThat(ev.lookup("d2_x")).isEqualTo(StarlarkInt.of(1));
+    assertThat(ev.lookup("d2_y")).isEqualTo("def");
+    assertThat(ev.lookup("d3_x")).isEqualTo(StarlarkInt.of(2));
+    assertThat(ev.lookup("d3_y")).isEqualTo("xyz");
+    StarlarkProvider dataConstructor = (StarlarkProvider) ev.lookup("data");
+    StarlarkCallable rawConstructor = (StarlarkCallable) ev.lookup("_new_data");
+    assertThat(rawConstructor).isNotInstanceOf(Provider.class);
+    assertThat(dataConstructor.getInit().getName()).isEqualTo("_data_init");
+
+    StructImpl data1 = (StructImpl) ev.lookup("d1");
+    StructImpl data2 = (StructImpl) ev.lookup("d2");
+    StructImpl data3 = (StructImpl) ev.lookup("d3");
+    assertThat(data1.getProvider()).isEqualTo(dataConstructor);
+    assertThat(data2.getProvider()).isEqualTo(dataConstructor);
+    assertThat(data3.getProvider()).isEqualTo(dataConstructor);
+    assertThat(dataConstructor.isExported()).isTrue();
+    assertThat(dataConstructor.getPrintableName()).isEqualTo("data");
+    assertThat(dataConstructor.getKey()).isEqualTo(new StarlarkProvider.Key(FAKE_LABEL, "data"));
+  }
+
+  @Test
+  public void declaredProvidersWithFailingInit_rawConstructorSucceeds() throws Exception {
+    evalAndExport(
+        ev,
+        "def _data_failing_init(x):", //
+        "    fail('_data_failing_init fails')",
+        "data, _new_data = provider(init = _data_failing_init)");
+
+    StarlarkProvider dataConstructor = (StarlarkProvider) ev.lookup("data");
+
+    evalAndExport(ev, "d = _new_data(x = 1)  # raw constructor");
+    StructImpl data = (StructImpl) ev.lookup("d");
+    assertThat(data.getProvider()).isEqualTo(dataConstructor);
+  }
+
+  @Test
+  public void declaredProvidersWithFailingInit_normalConstructorFails() throws Exception {
+    evalAndExport(
+        ev,
+        "def _data_failing_init(x):", //
+        "    fail('_data_failing_init fails')",
+        "data, _new_data = provider(init = _data_failing_init)");
+
+    ev.checkEvalErrorContains("_data_failing_init fails", "d = data(x = 1)  # normal constructor");
+    assertThat(ev.lookup("d")).isNull();
+  }
+
+  @Test
+  public void declaredProvidersWithInitReturningInvalidType_normalConstructorFails()
+      throws Exception {
+    evalAndExport(
+        ev,
+        "def _data_invalid_init(x):", //
+        "    return 'INVALID'",
+        "data, _new_data = provider(init = _data_invalid_init)");
+
+    ev.checkEvalErrorContains(
+        "got string for 'return value of provider init()', want dict",
+        "d = data(x = 1)  # normal constructor");
+    assertThat(ev.lookup("d")).isNull();
+  }
+
+  @Test
+  public void declaredProvidersWithInitReturningInvalidDict_normalConstructorFails()
+      throws Exception {
+    evalAndExport(
+        ev,
+        "def _data_invalid_init(x):", //
+        "    return {('x', 'x', 'x'): x}",
+        "data, _new_data = provider(init = _data_invalid_init)");
+
+    ev.checkEvalErrorContains(
+        "got dict<tuple, int> for 'return value of provider init()'",
+        "d = data(x = 1)  # normal constructor");
+    assertThat(ev.lookup("d")).isNull();
+  }
+
+  @Test
+  public void declaredProvidersWithInitReturningUnexpectedFields_normalConstructorFails()
+      throws Exception {
+    evalAndExport(
+        ev,
+        "def _data_unexpected_fields_init(x):", //
+        "    return {'x': x, 'y': x * 2}",
+        "data, _new_data = provider(fields = ['x'], init = _data_unexpected_fields_init)");
+
+    ev.checkEvalErrorContains(
+        "got unexpected field 'y' in call to instantiate provider data",
+        "d = data(x = 1)  # normal constructor");
+    assertThat(ev.lookup("d")).isNull();
+  }
+
+  @Test
   public void declaredProvidersConcatSuccess() throws Exception {
     evalAndExport(
         ev,
@@ -1703,6 +1816,27 @@
   }
 
   @Test
+  public void declaredProvidersWithInitConcatSuccess() throws Exception {
+    evalAndExport(
+        ev,
+        "def _data_init(x):",
+        "    return {'x': x}",
+        "data, _new_data = provider(init = _data_init)",
+        "dx = data(x = 1)  # normal constructor",
+        "dy = _new_data(y = 'abc')  # raw constructor",
+        "dxy = dx + dy",
+        "x = dxy.x",
+        "y = dxy.y");
+    assertThat(ev.lookup("x")).isEqualTo(StarlarkInt.of(1));
+    assertThat(ev.lookup("y")).isEqualTo("abc");
+    StarlarkProvider dataConstructor = (StarlarkProvider) ev.lookup("data");
+    StructImpl dx = (StructImpl) ev.lookup("dx");
+    assertThat(dx.getProvider()).isEqualTo(dataConstructor);
+    StructImpl dy = (StructImpl) ev.lookup("dy");
+    assertThat(dy.getProvider()).isEqualTo(dataConstructor);
+  }
+
+  @Test
   public void declaredProvidersConcatError() throws Exception {
     evalAndExport(ev, "data1 = provider()", "data2 = provider()");
 
@@ -2098,7 +2232,7 @@
     ev.setFailFast(false);
     evalAndExport(ev, "p = provider(fields = ['x', 'y'])", "p1 = p(x = 1, y = 2, z = 3)");
     MoreAsserts.assertContainsEvent(
-        ev.getEventCollector(), "unexpected keyword z in call to instantiate provider p");
+        ev.getEventCollector(), "got unexpected field 'z' in call to instantiate provider p");
   }
 
   @Test
@@ -2109,7 +2243,8 @@
         "p = provider(fields = [])", //
         "p1 = p(x = 1, y = 2, z = 3)");
     MoreAsserts.assertContainsEvent(
-        ev.getEventCollector(), "unexpected keywords x, y, z in call to instantiate provider p");
+        ev.getEventCollector(),
+        "got unexpected fields 'x', 'y', 'z' in call to instantiate provider p");
   }
 
   @Test