Prepare tests for moving the Java providers & toolchain rules out of `@_builtins`

Mostly just adds `load` statements everywhere.

PiperOrigin-RevId: 696496082
Change-Id: I5dfc4a58c2c9dffd144b2488fd33a8e665b88db1
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaPackageConfigurationProvider.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaPackageConfigurationProvider.java
index 8bfde1f..c34fab3 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaPackageConfigurationProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaPackageConfigurationProvider.java
@@ -14,10 +14,13 @@
 
 package com.google.devtools.build.lib.rules.java;
 
+import static com.google.devtools.build.lib.skyframe.BzlLoadValue.keyForBuild;
 import static com.google.devtools.build.lib.skyframe.BzlLoadValue.keyForBuiltins;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.actions.Artifact;
+import com.google.devtools.build.lib.analysis.ConfiguredTarget;
 import com.google.devtools.build.lib.analysis.PackageSpecificationProvider;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.collect.nestedset.Depset;
@@ -28,6 +31,7 @@
 import com.google.devtools.build.lib.packages.RuleClass.ConfiguredTargetFactory.RuleErrorException;
 import com.google.devtools.build.lib.packages.StarlarkProviderWrapper;
 import com.google.devtools.build.lib.packages.StructImpl;
+import com.google.devtools.build.lib.skyframe.BzlLoadValue;
 import net.starlark.java.eval.EvalException;
 import net.starlark.java.eval.Sequence;
 import net.starlark.java.eval.Starlark;
@@ -37,8 +41,10 @@
 @Immutable
 public final class JavaPackageConfigurationProvider implements StarlarkValue {
 
-  public static final StarlarkProviderWrapper<JavaPackageConfigurationProvider> PROVIDER =
+  private static final StarlarkProviderWrapper<JavaPackageConfigurationProvider> PROVIDER =
       new Provider();
+  private static final StarlarkProviderWrapper<JavaPackageConfigurationProvider> BUILTINS_PROVIDER =
+      new BuiltinsProvider();
 
   private final StructImpl underlying;
 
@@ -46,6 +52,16 @@
     this.underlying = underlying;
   }
 
+  @VisibleForTesting
+  public static JavaPackageConfigurationProvider get(ConfiguredTarget target)
+      throws RuleErrorException {
+    JavaPackageConfigurationProvider info = target.get(PROVIDER);
+    if (info == null) {
+      info = target.get(BUILTINS_PROVIDER);
+    }
+    return info;
+  }
+
   /** Package specifications for which the configuration should be applied. */
   private ImmutableList<PackageSpecificationProvider> packageSpecifications()
       throws RuleErrorException {
@@ -96,11 +112,15 @@
   private static class Provider extends StarlarkProviderWrapper<JavaPackageConfigurationProvider> {
 
     private Provider() {
-      super(
-          keyForBuiltins(
+      this(
+          keyForBuild(
               Label.parseCanonicalUnchecked(
-                  "@_builtins//:common/java/java_package_configuration.bzl")),
-          "JavaPackageConfigurationInfo");
+                  JavaSemantics.RULES_JAVA_PROVIDER_LABELS_PREFIX
+                      + "java/common/rules:java_package_configuration.bzl")));
+    }
+
+    private Provider(BzlLoadValue.Key key) {
+      super(key, "JavaPackageConfigurationInfo");
     }
 
     @Override
@@ -115,10 +135,21 @@
     }
   }
 
+  private static class BuiltinsProvider extends JavaPackageConfigurationProvider.Provider {
+
+    private BuiltinsProvider() {
+      super(
+          keyForBuiltins(
+              Label.parseCanonicalUnchecked(
+                  "@_builtins//:common/java/java_package_configuration.bzl")));
+    }
+  }
+
   static ImmutableList<JavaPackageConfigurationProvider> wrapSequence(Sequence<StructImpl> sequence)
       throws RuleErrorException {
     ImmutableList.Builder<JavaPackageConfigurationProvider> builder = ImmutableList.builder();
     for (StructImpl struct : sequence) {
+      // this result isn't propagated back to Starlark so we just need any type
       builder.add(PROVIDER.wrap(struct));
     }
     return builder.build();
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java b/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java
index 9c19232..5ee06d3 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java
@@ -1262,7 +1262,7 @@
         """);
     scratch.file(
         "test/defs.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib_' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
@@ -1375,7 +1375,8 @@
       throws Exception {
     scratch.file(
         "test/defs.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl',"
+            + " 'java_common', 'JavaInfo', 'JavaPluginInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib_' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
@@ -1432,7 +1433,8 @@
       throws Exception {
     scratch.file(
         "test/defs.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl',"
+            + " 'java_common', 'JavaInfo', 'JavaPluginInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib_' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
@@ -1489,7 +1491,7 @@
       throws Exception {
     scratch.file(
         "test/defs.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib_' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
@@ -1533,7 +1535,7 @@
       throws Exception {
     scratch.file(
         "test/defs.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib_' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
@@ -1578,7 +1580,7 @@
           throws Exception {
     scratch.file(
         "bazel_internal/test_rules/defs.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib_' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
@@ -1627,7 +1629,7 @@
           throws Exception {
     scratch.file(
         "bazel_internal/test_rules/defs.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib_' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/android/AndroidStarlarkCommonTest.java b/src/test/java/com/google/devtools/build/lib/rules/android/AndroidStarlarkCommonTest.java
index ec02462..d339125 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/android/AndroidStarlarkCommonTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/android/AndroidStarlarkCommonTest.java
@@ -39,6 +39,8 @@
     scratch.file(
         "java/android/compatible.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
+
         def _impl(ctx):
             return [
                 android_common.enable_implicit_sourceless_deps_exports_compatibility(
diff --git a/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoRoundtripTest.java b/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoRoundtripTest.java
index ac2e4ac..402656b 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoRoundtripTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoRoundtripTest.java
@@ -53,6 +53,7 @@
     scratch.file(
         "javainfo/javainfo_to_dict.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         load("//tools/build_defs/inspect:struct_to_dict.bzl", "struct_to_dict")
         Info = provider()
         def _impl(ctx):
diff --git a/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoStarlarkApiTest.java b/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoStarlarkApiTest.java
index 59fbae4..6f3f05d 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoStarlarkApiTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/java/JavaInfoStarlarkApiTest.java
@@ -1236,6 +1236,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         MyInfo = provider()
 
         def _impl(ctx):
@@ -1536,7 +1537,8 @@
       assertThat(useIJar && stampJar).isFalse();
       ImmutableList.Builder<String> lines = ImmutableList.builder();
       lines.add(
-          "load('@rules_java//java:defs.bzl', 'java_common')",
+          "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo',"
+              + " 'JavaPluginInfo')",
           "result = provider()",
           "def _impl(ctx):",
           "  ctx.actions.write(ctx.outputs.output_jar, 'JavaInfo API Test', is_executable=False) ",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/java/JavaStarlarkApiTest.java b/src/test/java/com/google/devtools/build/lib/rules/java/JavaStarlarkApiTest.java
index 7d37df4..174ae0d 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/java/JavaStarlarkApiTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/java/JavaStarlarkApiTest.java
@@ -135,6 +135,7 @@
         "a/rule.bzl",
         """
         load("//myinfo:myinfo.bzl", "MyInfo")
+        load("@rules_java//java/common:java_common.bzl", "java_common")
 
         def _impl(ctx):
             provider = ctx.attr._java_runtime[java_common.JavaRuntimeInfo]
@@ -185,6 +186,7 @@
         "a/rule.bzl",
         """
         load("//myinfo:myinfo.bzl", "MyInfo")
+        load("@rules_java//java/common:java_common.bzl", "java_common")
 
         def _impl(ctx):
             provider = ctx.attr._java_runtime[java_common.JavaRuntimeInfo]
@@ -236,6 +238,7 @@
         "a/rule.bzl",
         """
         load("//myinfo:myinfo.bzl", "MyInfo")
+        load("@rules_java//java/common:java_common.bzl", "java_common")
 
         def _impl(ctx):
             provider = ctx.attr._java_runtime[java_common.JavaRuntimeInfo]
@@ -286,6 +289,7 @@
     scratch.file(
         "java/test/extension.bzl",
         """
+        load("@rules_java//java/common:java_common.bzl", "java_common")
         result = provider()
 
         def impl(ctx):
@@ -364,6 +368,8 @@
     scratch.file(
         "java/test/extension.bzl",
         """
+        load("@rules_java//java/common:java_plugin_info.bzl",
+         "JavaPluginInfo")
         result = provider()
 
         def impl(ctx):
@@ -409,6 +415,7 @@
     scratch.file(
         "java/test/custom_rule.bzl",
         """
+        load("@rules_java//java/common:java_common.bzl", "java_common")
         def _impl(ctx):
             jacoco = ctx.attr._java_toolchain[java_common.JavaToolchainInfo].jacocorunner
             return [
@@ -456,6 +463,7 @@
     scratch.file(
         "java/test/extension.bzl",
         """
+        load("@rules_java//java/common:java_common.bzl", "java_common")
         result = provider()
 
         def impl(ctx):
@@ -473,6 +481,7 @@
         """);
     scratch.file(
         "java/test/custom_rule.bzl",
+        "load('@rules_java//java/common:java_common.bzl', 'java_common')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib' + ctx.label.name + '.jar')",
         "  compilation_provider = java_common.compile(",
@@ -554,6 +563,8 @@
         """);
     scratch.file(
         "java/test/custom_rule.bzl",
+        "load('@rules_java//java/common:java_common.bzl', 'java_common')",
+        "load('@rules_java//java/common:java_info.bzl', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib' + ctx.label.name + '.jar')",
         "  deps = [dep[JavaInfo] for dep in ctx.attr.deps]",
@@ -628,7 +639,8 @@
     JavaTestUtil.writeBuildFileForJavaToolchain(scratch);
     scratch.file(
         "java/test/custom_rule.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo',"
+            + " 'JavaPluginInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib' + ctx.label.name + '.jar')",
         "  return java_common.compile(",
@@ -670,7 +682,7 @@
     JavaTestUtil.writeBuildFileForJavaToolchain(scratch);
     scratch.file(
         "java/test/custom_rule.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib' + ctx.label.name + '.jar')",
         "  return java_common.compile(",
@@ -1172,6 +1184,7 @@
     scratch.file(
         "java/test/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def impl(ctx):
@@ -1259,6 +1272,7 @@
     String pkg = apiGenerating ? "java/getplugininfo" : "java/getapiplugininfo";
     scratch.file(
         pkg + "/extension.bzl",
+        "load('@rules_java//java:defs.bzl', 'JavaInfo', 'JavaPluginInfo')",
         "result = provider()",
         "def impl(ctx):",
         "   depj = ctx.attr.dep[" + provider + "]",
@@ -1436,6 +1450,9 @@
     scratch.file(
         "java/test/myplugin.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
+        load("@rules_java//java/common:java_plugin_info.bzl",
+         "JavaPluginInfo")
         def _impl(ctx):
             output_jar = ctx.actions.declare_file("lib.jar")
             ctx.actions.write(output_jar, "")
@@ -1499,6 +1516,9 @@
     scratch.file(
         "java/test/myplugin.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
+        load("@rules_java//java/common:java_plugin_info.bzl",
+         "JavaPluginInfo")
         def _impl(ctx):
             output_jar = ctx.actions.declare_file("lib.jar")
             ctx.actions.write(output_jar, "")
@@ -1562,6 +1582,9 @@
     scratch.file(
         "java/test/myplugin.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
+        load("@rules_java//java/common:java_plugin_info.bzl",
+         "JavaPluginInfo")
         def _impl(ctx):
             output_jar = ctx.actions.declare_file("lib.jar")
             ctx.actions.write(output_jar, "")
@@ -1623,6 +1646,9 @@
     scratch.file(
         "java/test/myplugin.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
+        load("@rules_java//java/common:java_plugin_info.bzl",
+         "JavaPluginInfo")
         def _impl(ctx):
             output_jar = ctx.actions.declare_file("lib.jar")
             ctx.actions.write(output_jar, "")
@@ -1688,6 +1714,7 @@
     scratch.file(
         "java/test/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def impl(ctx):
@@ -1768,6 +1795,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         my_provider = provider()
 
         def _impl(ctx):
@@ -1812,6 +1840,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             dep_params = ctx.attr.dep[JavaInfo]
             return [dep_params]
@@ -1866,6 +1895,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             dep_params = ctx.attr.dep[JavaInfo]
             return [dep_params]
@@ -1953,6 +1983,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             dep_params = ctx.attr.dep[JavaInfo]
             return [dep_params]
@@ -2012,6 +2043,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             dep_params = ctx.attr.dep[JavaInfo]
             return [dep_params]
@@ -2066,6 +2098,7 @@
     scratch.file(
         "foo/bad_rules.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def make_file(ctx):
             f = ctx.actions.declare_file("out")
             ctx.actions.write(f, "out")
@@ -2121,6 +2154,7 @@
     scratch.file(
         "foo/javainfo_rules.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def make_file(ctx):
             f = ctx.actions.declare_file("out")
             ctx.actions.write(f, "out")
@@ -2149,6 +2183,7 @@
     scratch.file(
         "foo/javainfo_rules.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def make_file(ctx):
             f = ctx.actions.declare_file("out")
             ctx.actions.write(f, "out")
@@ -2176,6 +2211,7 @@
     scratch.file(
         "foo/javainfo_rules.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def make_file(ctx):
             f = ctx.actions.declare_file("out")
             ctx.actions.write(f, "out")
@@ -2205,6 +2241,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -2253,6 +2290,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -2311,6 +2349,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -2369,6 +2408,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -2426,6 +2466,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -2705,6 +2746,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         my_provider = provider()
 
         def _impl(ctx):
@@ -2811,6 +2853,7 @@
     scratch.file(
         "java/test/custom_rule.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             output_jar = ctx.actions.declare_file("lib" + ctx.label.name + ".jar")
             ctx.actions.write(output_jar, "")
@@ -2858,6 +2901,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -2904,6 +2948,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -2953,6 +2998,7 @@
     scratch.file(
         "foo/extension.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         result = provider()
 
         def _impl(ctx):
@@ -3157,7 +3203,7 @@
         """);
     scratch.file(
         "java/test/custom_rule.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib' + ctx.label.name + '.jar')",
         "  java_info = java_common.compile(",
@@ -3243,7 +3289,7 @@
         """);
     scratch.file(
         "java/test/custom_rule.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib' + ctx.label.name + '.jar')",
         "  compilation_provider = java_common.compile(",
@@ -3352,6 +3398,7 @@
         "foo/custom_library.bzl",
         """
         load("@rules_java//java:defs.bzl", "java_common")
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             java_provider = java_common.merge([dep[JavaInfo] for dep in ctx.attr.deps])
             if not ctx.attr.strict_deps:
@@ -3404,6 +3451,7 @@
         "foo/custom_library.bzl",
         """
         load("@rules_java//java:defs.bzl", "java_common")
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             java_provider = java_common.merge([dep[JavaInfo] for dep in ctx.attr.deps])
             if not ctx.attr.strict_deps:
@@ -3518,6 +3566,7 @@
         "foo/custom_library.bzl",
         """
         load("@rules_java//java:defs.bzl", "java_common")
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             java_provider = java_common.merge([dep[JavaInfo] for dep in ctx.attr.deps])
             return [java_provider]
@@ -3784,7 +3833,7 @@
     JavaTestUtil.writeBuildFileForJavaToolchain(scratch);
     scratch.file(
         "foo/java_custom_library.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib%s.jar' % ctx.label.name)",
         "  deps = [deps[JavaInfo] for deps in ctx.attr.deps]",
@@ -3840,7 +3889,7 @@
     JavaTestUtil.writeBuildFileForJavaToolchain(scratch);
     scratch.file(
         "foo/java_custom_library.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib%s.jar' % ctx.label.name)",
         "  exports = [export[JavaInfo] for export in ctx.attr.exports]",
@@ -4083,6 +4132,7 @@
     scratch.file(
         "a/rule.bzl",
         """
+        load("@rules_java//java/common:java_common.bzl", "java_common")
         def _impl(ctx):
             provider = ctx.attr._java_runtime[java_common.JavaRuntimeInfo]
             return DefaultInfo(
@@ -4130,7 +4180,8 @@
     JavaTestUtil.writeBuildFileForJavaToolchain(scratch);
     scratch.file(
         "foo/custom_rule.bzl",
-        "load('@rules_java//java:defs.bzl', 'java_common')",
+        "load('@rules_java//java:defs.bzl', 'java_common', 'JavaInfo',"
+            + " 'JavaPluginInfo')",
         "def _impl(ctx):",
         "  output_jar = ctx.actions.declare_file('lib' + ctx.label.name + '.jar')",
         "  compilation_provider = java_common.compile(",
@@ -4642,7 +4693,7 @@
     scratch.file(
         "foo/custom_library.bzl",
         """
-        load("@rules_java//java:defs.bzl", "java_common")
+        load("@rules_java//java:defs.bzl", "java_common", "JavaInfo")
         def _impl(ctx):
             java_provider = java_common.merge([dep[JavaInfo] for dep in ctx.attr.deps])
             return [java_provider]
@@ -4710,6 +4761,7 @@
         "a/rule.bzl",
         """
         load("//myinfo:myinfo.bzl", "MyInfo")
+        load("@rules_java//java/common:java_common.bzl", "java_common")
 
         def _impl(ctx):
             provider = ctx.attr._java_runtime[java_common.JavaRuntimeInfo]
@@ -4764,6 +4816,7 @@
         "a/rule.bzl",
         """
         load("//myinfo:myinfo.bzl", "MyInfo")
+        load("@rules_java//java/common:java_common.bzl", "java_common")
 
         def _impl(ctx):
             provider = ctx.attr._java_runtime[java_common.JavaRuntimeInfo]
@@ -4814,6 +4867,7 @@
         "a/rule.bzl",
         """
         load("//myinfo:myinfo.bzl", "MyInfo")
+        load("@rules_java//java/common:java_common.bzl", "java_common")
 
         def _impl(ctx):
             provider = ctx.attr._java_runtime[java_common.JavaRuntimeInfo]
@@ -4919,6 +4973,7 @@
     scratch.file(
         "foo/rule.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             cc_info = ctx.attr.dep[CcInfo]
             JavaInfo(output_jar = None, compile_jar = None, deps = [cc_info])
diff --git a/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java b/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java
index c001645..00153ca 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/java/proto/StarlarkJavaLiteProtoLibraryTest.java
@@ -391,6 +391,7 @@
     scratch.file(
         "proto/extensions.bzl",
         """
+        load("@rules_java//java/common:java_info.bzl", "JavaInfo")
         def _impl(ctx):
             print(ctx.attr.dep[JavaInfo])
 
diff --git a/src/test/shell/integration/java_integration_test.sh b/src/test/shell/integration/java_integration_test.sh
index f51325a..8a99009 100755
--- a/src/test/shell/integration/java_integration_test.sh
+++ b/src/test/shell/integration/java_integration_test.sh
@@ -254,8 +254,11 @@
     local -r javabase="${BAZEL_RUNFILES}/${runfiles_relative_javabase}"
   fi
 
+  add_rules_java "MODULE.bazel"
+
   mkdir -p "$pkg/jvm"
   cat > "$pkg/jvm/BUILD" <<EOF
+load("@rules_java//java/toolchains:java_runtime.bzl", "java_runtime")
 package(default_visibility=["//visibility:public"])
 java_runtime(
     name='runtime',