Add attribute validation of mandatory native providers

and use it to validate that :java_toolchain has a JavaToolchainProvider.

--
MOS_MIGRATED_REVID=121396726
diff --git a/src/main/java/com/google/devtools/build/lib/BUILD b/src/main/java/com/google/devtools/build/lib/BUILD
index e1b65493..71ca85b 100644
--- a/src/main/java/com/google/devtools/build/lib/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/BUILD
@@ -405,6 +405,7 @@
         ":events",
         ":skylarkinterface",
         ":syntax",
+        ":transitive-info-provider",
         ":util",
         ":vfs",
         "//src/main/java/com/google/devtools/common/options",
@@ -449,6 +450,11 @@
 )
 
 java_library(
+    name = "transitive-info-provider",
+    srcs = ["analysis/TransitiveInfoProvider.java"],
+)
+
+java_library(
     name = "build-base",
     srcs = glob(
         [
@@ -466,10 +472,16 @@
             "rules/repository/*.java",
             "skyframe/*.java",
         ],
-        exclude = ["analysis/BuildInfo.java"],
+        exclude = [
+            "analysis/BuildInfo.java",
+            "analysis/TransitiveInfoProvider.java",
+        ],
     ) + [
         "runtime/BlazeServerStartupOptions.java",
     ],
+    exports = [
+        ":transitive-info-provider",
+    ],
     deps = [
         ":base-util",
         ":cmdline",
@@ -485,6 +497,7 @@
         ":packages-internal",
         ":shell",
         ":skylarkinterface",
+        ":transitive-info-provider",
         ":util",
         ":vfs",
         "//src/main/java/com/google/devtools/build/lib/actions",
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/RuleContext.java b/src/main/java/com/google/devtools/build/lib/analysis/RuleContext.java
index faa4f2c..961158c 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/RuleContext.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/RuleContext.java
@@ -1710,6 +1710,32 @@
       return missingProviders.toString();
     }
 
+    private String getMissingMandatoryNativeProviders(
+        ConfiguredTarget prerequisite, Attribute attribute) {
+      List<Class<? extends TransitiveInfoProvider>> mandatoryProvidersList =
+          attribute.getMandatoryNativeProviders();
+      if (mandatoryProvidersList.isEmpty()) {
+        return null;
+      }
+      List<Class<? extends TransitiveInfoProvider>> missing = new ArrayList<>();
+      for (Class<? extends TransitiveInfoProvider> provider : mandatoryProvidersList) {
+        if (prerequisite.getProvider(provider) == null) {
+          missing.add(provider);
+        }
+      }
+      if (missing.isEmpty()) {
+        return null;
+      }
+      StringBuilder sb = new StringBuilder();
+      for (Class<? extends TransitiveInfoProvider> provider : missing) {
+        if (sb.length() > 0) {
+          sb.append(", ");
+        }
+        sb.append(provider.getSimpleName());
+      }
+      return sb.toString();
+    }
+
     /**
      * Because some rules still have to use allowedRuleClasses to do rule dependency validation.
      * We implemented the allowedRuleClasses OR mandatoryProvidersList mechanism. Either condition
@@ -1739,6 +1765,16 @@
         }
       }
 
+      if (!attribute.getMandatoryNativeProviders().isEmpty()) {
+        String missing = getMissingMandatoryNativeProviders(prerequisite, attribute);
+        if (missing != null) {
+          attributeError(
+              attribute.getName(),
+              "'" + prerequisite.getLabel() + "' does not have mandatory providers: " + missing);
+          return;
+        }
+      }
+
       if (!attribute.getMandatoryProvidersList().isEmpty()) {
         String missingMandatoryProviders = getMissingMandatoryProviders(prerequisite, attribute);
         if (missingMandatoryProviders != null) {
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java
index 390f830..ec80f65 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java
@@ -40,6 +40,7 @@
 import com.google.devtools.build.lib.packages.RuleClass.PackageNameConstraint;
 import com.google.devtools.build.lib.packages.TriState;
 import com.google.devtools.build.lib.rules.java.JavaSemantics;
+import com.google.devtools.build.lib.rules.java.JavaToolchainProvider;
 import com.google.devtools.build.lib.syntax.Type;
 import com.google.devtools.build.lib.util.FileTypeSet;
 
@@ -76,7 +77,7 @@
       return builder
           .add(
               attr(":java_toolchain", LABEL)
-                  .allowedRuleClasses("java_toolchain")
+                  .mandatoryNativeProviders(JavaToolchainProvider.class)
                   .value(JavaSemantics.JAVA_TOOLCHAIN))
           .setPreferredDependencyPredicate(JavaSemantics.JAVA_SOURCE)
           .build();
diff --git a/src/main/java/com/google/devtools/build/lib/packages/Attribute.java b/src/main/java/com/google/devtools/build/lib/packages/Attribute.java
index 4b9f7fe..501b46c 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/Attribute.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/Attribute.java
@@ -22,6 +22,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
+import com.google.devtools.build.lib.analysis.TransitiveInfoProvider;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.skylarkinterface.SkylarkModule;
 import com.google.devtools.build.lib.syntax.ClassObject;
@@ -396,6 +397,8 @@
     private PredicateWithMessage<Object> allowedValues = null;
     private ImmutableList<ImmutableSet<String>> mandatoryProvidersList =
         ImmutableList.<ImmutableSet<String>>of();
+    private ImmutableList<Class<? extends TransitiveInfoProvider>> mandatoryNativeProviders =
+        ImmutableList.of();
     private Set<RuleAspect> aspects = new LinkedHashSet<>();
 
     /**
@@ -803,6 +806,22 @@
     }
 
     /**
+     * Sets a list of mandatory native providers. Every configured target occurring in this label
+     * type attribute has to provide all the providers, otherwise an error is produced during the
+     * analysis phase.
+     */
+    @SafeVarargs
+    public final Builder<TYPE> mandatoryNativeProviders(
+        Class<? extends TransitiveInfoProvider>... providers) {
+      Preconditions.checkState(
+          (type == BuildType.LABEL) || (type == BuildType.LABEL_LIST),
+          "must be a label-valued type");
+      this.mandatoryNativeProviders =
+          ImmutableList.<Class<? extends TransitiveInfoProvider>>copyOf(providers);
+      return this;
+    }
+
+    /**
      * Sets a list of sets of mandatory Skylark providers. Every configured target occurring in
      * this label type attribute has to provide all the providers from one of those sets,
      * otherwise an error is produced during the analysis phase.
@@ -939,6 +958,7 @@
           condition,
           allowedValues,
           mandatoryProvidersList,
+          mandatoryNativeProviders,
           ImmutableSet.copyOf(aspects));
     }
   }
@@ -1224,6 +1244,8 @@
 
   private final ImmutableList<ImmutableSet<String>> mandatoryProvidersList;
 
+  private final ImmutableList<Class<? extends TransitiveInfoProvider>> mandatoryNativeProviders;
+
   private final ImmutableSet<RuleAspect> aspects;
 
   /**
@@ -1256,6 +1278,7 @@
       Predicate<AttributeMap> condition,
       PredicateWithMessage<Object> allowedValues,
       ImmutableList<ImmutableSet<String>> mandatoryProvidersList,
+      ImmutableList<Class<? extends TransitiveInfoProvider>> mandatoryNativeProviders,
       ImmutableSet<RuleAspect> aspects) {
     Preconditions.checkNotNull(configTransition);
     Preconditions.checkArgument(
@@ -1290,6 +1313,7 @@
     this.condition = condition;
     this.allowedValues = allowedValues;
     this.mandatoryProvidersList = mandatoryProvidersList;
+    this.mandatoryNativeProviders = mandatoryNativeProviders;
     this.aspects = aspects;
   }
 
@@ -1488,6 +1512,11 @@
     return mandatoryProvidersList;
   }
 
+  /** Returns the list of mandatory native providers. */
+  public ImmutableList<Class<? extends TransitiveInfoProvider>> getMandatoryNativeProviders() {
+    return mandatoryNativeProviders;
+  }
+
   public FileTypeSet getAllowedFileTypesPredicate() {
     return allowedFileTypesForLabels;
   }