Add an attribute to java_toolchain pointing to java_runtime target.

PiperOrigin-RevId: 337811929
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/JavaRules.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/JavaRules.java
index a181f0f..f7b68f5 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/rules/JavaRules.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/JavaRules.java
@@ -35,7 +35,6 @@
 import com.google.devtools.build.lib.rules.java.JavaOptions;
 import com.google.devtools.build.lib.rules.java.JavaPackageConfigurationRule;
 import com.google.devtools.build.lib.rules.java.JavaRuleClasses.IjarBaseRule;
-import com.google.devtools.build.lib.rules.java.JavaRuleClasses.JavaHostRuntimeBaseRule;
 import com.google.devtools.build.lib.rules.java.JavaRuleClasses.JavaRuntimeBaseRule;
 import com.google.devtools.build.lib.rules.java.JavaRuleClasses.JavaToolchainBaseRule;
 import com.google.devtools.build.lib.rules.java.JavaRuntimeAliasRule;
@@ -71,7 +70,6 @@
     builder.addRuleDefinition(new IjarBaseRule());
     builder.addRuleDefinition(new JavaToolchainBaseRule());
     builder.addRuleDefinition(new JavaRuntimeBaseRule());
-    builder.addRuleDefinition(new JavaHostRuntimeBaseRule());
     builder.addRuleDefinition(new BazelJavaRuleClasses.JavaBaseRule());
     builder.addRuleDefinition(new ProguardLibraryRule());
     builder.addRuleDefinition(new JavaImportBaseRule());
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/proto/BazelJavaLiteProtoLibraryRule.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/proto/BazelJavaLiteProtoLibraryRule.java
index 603c544..7342224 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/proto/BazelJavaLiteProtoLibraryRule.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/proto/BazelJavaLiteProtoLibraryRule.java
@@ -30,7 +30,6 @@
 import com.google.devtools.build.lib.packages.StarlarkProviderIdentifier;
 import com.google.devtools.build.lib.rules.java.JavaConfiguration;
 import com.google.devtools.build.lib.rules.java.JavaInfo;
-import com.google.devtools.build.lib.rules.java.JavaRuleClasses.JavaHostRuntimeBaseRule;
 import com.google.devtools.build.lib.rules.java.JavaRuleClasses.JavaToolchainBaseRule;
 import com.google.devtools.build.lib.rules.java.proto.JavaLiteProtoLibrary;
 import com.google.devtools.build.lib.rules.java.proto.JavaProtoAspectCommon;
@@ -74,10 +73,7 @@
     return RuleDefinition.Metadata.builder()
         .name("java_lite_proto_library")
         .factoryClass(JavaLiteProtoLibrary.class)
-        .ancestors(
-            BaseRuleClasses.RuleBase.class,
-            JavaToolchainBaseRule.class,
-            JavaHostRuntimeBaseRule.class)
+        .ancestors(BaseRuleClasses.RuleBase.class, JavaToolchainBaseRule.class)
         .build();
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaImportBaseRule.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaImportBaseRule.java
index 6f804ef..91c7102 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaImportBaseRule.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaImportBaseRule.java
@@ -27,7 +27,6 @@
 import com.google.devtools.build.lib.packages.RuleClass.Builder.RuleClassType;
 import com.google.devtools.build.lib.packages.StarlarkProviderIdentifier;
 import com.google.devtools.build.lib.rules.cpp.CppConfiguration;
-import com.google.devtools.build.lib.rules.java.JavaRuleClasses.JavaHostRuntimeBaseRule;
 
 /** A base rule for building the java_import rule. */
 public class JavaImportBaseRule implements RuleDefinition {
@@ -74,10 +73,7 @@
     return RuleDefinition.Metadata.builder()
         .name("$java_import_base")
         .type(RuleClassType.ABSTRACT)
-        .ancestors(
-            BaseRuleClasses.RuleBase.class,
-            ProguardLibraryRule.class,
-            JavaHostRuntimeBaseRule.class)
+        .ancestors(BaseRuleClasses.RuleBase.class, ProguardLibraryRule.class)
         .build();
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuleClasses.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuleClasses.java
index 02103dc..b452bda 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuleClasses.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuleClasses.java
@@ -20,7 +20,6 @@
 import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.analysis.RuleDefinition;
 import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
-import com.google.devtools.build.lib.analysis.config.HostTransition;
 import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
 import com.google.devtools.build.lib.packages.RuleClass;
 import com.google.devtools.build.lib.packages.RuleClass.Builder.RuleClassType;
@@ -31,7 +30,6 @@
 
   public static final String JAVA_TOOLCHAIN_ATTRIBUTE_NAME = "$java_toolchain";
   public static final String JAVA_RUNTIME_ATTRIBUTE_NAME = "$jvm";
-  public static final String HOST_JAVA_RUNTIME_ATTRIBUTE_NAME = "$host_jdk";
 
   /** Common attributes for rules that depend on ijar. */
   public static final class IjarBaseRule implements RuleDefinition {
@@ -90,29 +88,6 @@
       return RuleDefinition.Metadata.builder()
           .name("$java_runtime_toolchain_base_rule")
           .type(RuleClassType.ABSTRACT)
-          .ancestors(JavaHostRuntimeBaseRule.class)
-          .build();
-    }
-  }
-
-  /** Common attributes for rules that use the host Java runtime. */
-  public static final class JavaHostRuntimeBaseRule implements RuleDefinition {
-    @Override
-    public RuleClass build(RuleClass.Builder builder, RuleDefinitionEnvironment env) {
-      return builder
-          .add(
-              attr(HOST_JAVA_RUNTIME_ATTRIBUTE_NAME, LABEL)
-                  .cfg(HostTransition.createFactory())
-                  .value(JavaSemantics.hostJdkAttribute(env))
-                  .mandatoryProviders(ToolchainInfo.PROVIDER.id()))
-          .build();
-    }
-
-    @Override
-    public Metadata getMetadata() {
-      return RuleDefinition.Metadata.builder()
-          .name("$java_host_runtime_toolchain_base_rule")
-          .type(RuleClassType.ABSTRACT)
           .build();
     }
   }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuntimeInfo.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuntimeInfo.java
index 22552e6..d57e291 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuntimeInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaRuntimeInfo.java
@@ -14,13 +14,11 @@
 
 package com.google.devtools.build.lib.rules.java;
 
-import static com.google.devtools.build.lib.rules.java.JavaRuleClasses.HOST_JAVA_RUNTIME_ATTRIBUTE_NAME;
 import static com.google.devtools.build.lib.rules.java.JavaRuleClasses.JAVA_RUNTIME_ATTRIBUTE_NAME;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.analysis.RuleContext;
-import com.google.devtools.build.lib.analysis.RuleErrorConsumer;
 import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
 import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
 import com.google.devtools.build.lib.collect.nestedset.Depset;
@@ -63,7 +61,7 @@
   // Helper methods to access an instance of JavaRuntimeInfo.
 
   public static JavaRuntimeInfo forHost(RuleContext ruleContext) {
-    return from(ruleContext, HOST_JAVA_RUNTIME_ATTRIBUTE_NAME);
+    return JavaToolchainProvider.from(ruleContext).getJavaRuntime();
   }
 
   public static JavaRuntimeInfo from(RuleContext ruleContext) {
@@ -80,15 +78,7 @@
       return null;
     }
 
-    return from(prerequisite, ruleContext);
-  }
-
-  // TODO(katre): When all external callers are converted to use toolchain resolution, make this
-  // method private.
-  @Nullable
-  static JavaRuntimeInfo from(
-      TransitiveInfoCollection collection, RuleErrorConsumer errorConsumer) {
-    return (JavaRuntimeInfo) collection.get(ToolchainInfo.PROVIDER);
+    return (JavaRuntimeInfo) prerequisite.get(ToolchainInfo.PROVIDER);
   }
 
   private final NestedSet<Artifact> javaBaseInputs;
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchain.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchain.java
index befbad2..9b068ff 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchain.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchain.java
@@ -34,6 +34,7 @@
 import com.google.devtools.build.lib.analysis.Runfiles;
 import com.google.devtools.build.lib.analysis.RunfilesProvider;
 import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
+import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.collect.nestedset.NestedSet;
 import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
@@ -122,6 +123,9 @@
 
     FilesToRunProvider jacocoRunner = ruleContext.getExecutablePrerequisite("jacocorunner");
 
+    JavaRuntimeInfo javaRuntime =
+        (JavaRuntimeInfo) ruleContext.getPrerequisite("java_runtime").get(ToolchainInfo.PROVIDER);
+
     JavaToolchainProvider provider =
         JavaToolchainProvider.create(
             ruleContext.getLabel(),
@@ -151,7 +155,8 @@
             packageConfiguration,
             jacocoRunner,
             proguardAllowlister,
-            semantics);
+            semantics,
+            javaRuntime);
     RuleConfiguredTargetBuilder builder =
         new RuleConfiguredTargetBuilder(ruleContext)
             .addStarlarkTransitiveInfo(JavaToolchainProvider.LEGACY_NAME, provider)
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainProvider.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainProvider.java
index da22ce0..b20e208 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainProvider.java
@@ -97,7 +97,8 @@
       ImmutableList<JavaPackageConfigurationProvider> packageConfiguration,
       FilesToRunProvider jacocoRunner,
       FilesToRunProvider proguardAllowlister,
-      JavaSemantics javaSemantics) {
+      JavaSemantics javaSemantics,
+      JavaRuntimeInfo javaRuntime) {
     return new JavaToolchainProvider(
         label,
         bootclasspath,
@@ -126,7 +127,8 @@
         packageConfiguration,
         jacocoRunner,
         proguardAllowlister,
-        javaSemantics);
+        javaSemantics,
+        javaRuntime);
   }
 
   private final Label label;
@@ -157,6 +159,7 @@
   private final FilesToRunProvider jacocoRunner;
   private final FilesToRunProvider proguardAllowlister;
   private final JavaSemantics javaSemantics;
+  private final JavaRuntimeInfo javaRuntime;
 
   @VisibleForSerialization
   JavaToolchainProvider(
@@ -187,7 +190,8 @@
       ImmutableList<JavaPackageConfigurationProvider> packageConfiguration,
       FilesToRunProvider jacocoRunner,
       FilesToRunProvider proguardAllowlister,
-      JavaSemantics javaSemantics) {
+      JavaSemantics javaSemantics,
+      JavaRuntimeInfo javaRuntime) {
     super(ImmutableMap.of(), Location.BUILTIN);
 
     this.label = label;
@@ -218,6 +222,7 @@
     this.jacocoRunner = jacocoRunner;
     this.proguardAllowlister = proguardAllowlister;
     this.javaSemantics = javaSemantics;
+    this.javaRuntime = javaRuntime;
   }
 
   /** Returns the label for this {@code java_toolchain}. */
@@ -393,6 +398,10 @@
     return javaSemantics;
   }
 
+  public JavaRuntimeInfo getJavaRuntime() {
+    return javaRuntime;
+  }
+
   /** Returns the input Java language level */
   // TODO(cushon): remove this API; it bakes a deprecated detail of the javac API into Bazel
   @Override
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainRule.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainRule.java
index 4f46dec..b76be38 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainRule.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaToolchainRule.java
@@ -29,6 +29,7 @@
 import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
 import com.google.devtools.build.lib.analysis.config.ExecutionTransitionFactory;
 import com.google.devtools.build.lib.analysis.config.transitions.NoTransition;
+import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
 import com.google.devtools.build.lib.packages.RuleClass;
 import com.google.devtools.build.lib.util.FileTypeSet;
 import java.util.List;
@@ -282,6 +283,19 @@
                 .cfg(ExecutionTransitionFactory.create())
                 .allowedFileTypes(FileTypeSet.ANY_FILE)
                 .exec())
+        /* <!-- #BLAZE_RULE(java_toolchain).ATTRIBUTE(java_runtime) -->
+        The java_runtime to use with this toolchain. It defaults to java_runtime
+        in execution configuration.
+        <!-- #END_BLAZE_RULE.ATTRIBUTE --> */
+        .add(
+            attr("java_runtime", LABEL)
+                .cfg(ExecutionTransitionFactory.create())
+                // TODO(b/171140578): remove default value and set to mandatory after it is set on
+                // all toolchains
+                .value(JavaSemantics.hostJdkAttribute(env))
+                .mandatoryProviders(ToolchainInfo.PROVIDER.id())
+                .allowedFileTypes(FileTypeSet.ANY_FILE)
+                .useOutputLicenses())
         .build();
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaLiteProtoAspect.java b/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaLiteProtoAspect.java
index 87d28e2..2ebd1c8 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaLiteProtoAspect.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaLiteProtoAspect.java
@@ -16,7 +16,6 @@
 
 import static com.google.devtools.build.lib.packages.Attribute.attr;
 import static com.google.devtools.build.lib.packages.BuildType.LABEL;
-import static com.google.devtools.build.lib.rules.java.JavaRuleClasses.HOST_JAVA_RUNTIME_ATTRIBUTE_NAME;
 import static com.google.devtools.build.lib.rules.java.proto.JplCcLinkParams.createCcLinkingInfo;
 import static com.google.devtools.build.lib.rules.java.proto.StrictDepsUtils.createNonStrictCompilationArgsProvider;
 
@@ -29,7 +28,6 @@
 import com.google.devtools.build.lib.analysis.RuleContext;
 import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
 import com.google.devtools.build.lib.analysis.TransitiveInfoProvider;
-import com.google.devtools.build.lib.analysis.config.HostTransition;
 import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
@@ -68,7 +66,6 @@
   private final JavaSemantics javaSemantics;
 
   private final String defaultProtoToolchainLabel;
-  private final Label hostJdkAttribute;
   private final Label javaToolchainAttribute;
 
   public JavaLiteProtoAspect(
@@ -77,7 +74,6 @@
       RuleDefinitionEnvironment env) {
     this.javaSemantics = javaSemantics;
     this.defaultProtoToolchainLabel = defaultProtoToolchainLabel;
-    this.hostJdkAttribute = JavaSemantics.hostJdkAttribute(env);
     this.javaToolchainAttribute = JavaSemantics.javaToolchainAttribute(env);
   }
 
@@ -120,11 +116,6 @@
                             ProtoLangToolchainProvider.class))
                     .value(getProtoToolchainLabel(defaultProtoToolchainLabel)))
             .add(
-                attr(HOST_JAVA_RUNTIME_ATTRIBUTE_NAME, LABEL)
-                    .cfg(HostTransition.createFactory())
-                    .value(hostJdkAttribute)
-                    .mandatoryProviders(ToolchainInfo.PROVIDER.id()))
-            .add(
                 attr(JavaRuleClasses.JAVA_TOOLCHAIN_ATTRIBUTE_NAME, LABEL)
                     .useOutputLicenses()
                     .value(javaToolchainAttribute)
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaProtoAspect.java b/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaProtoAspect.java
index 00b995a..3c3a8d9 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaProtoAspect.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/proto/JavaProtoAspect.java
@@ -16,7 +16,6 @@
 
 import static com.google.devtools.build.lib.packages.Attribute.attr;
 import static com.google.devtools.build.lib.packages.BuildType.LABEL;
-import static com.google.devtools.build.lib.rules.java.JavaRuleClasses.HOST_JAVA_RUNTIME_ATTRIBUTE_NAME;
 import static com.google.devtools.build.lib.rules.java.proto.JplCcLinkParams.createCcLinkingInfo;
 import static com.google.devtools.build.lib.rules.java.proto.StrictDepsUtils.createNonStrictCompilationArgsProvider;
 
@@ -29,7 +28,6 @@
 import com.google.devtools.build.lib.analysis.PlatformConfiguration;
 import com.google.devtools.build.lib.analysis.RuleContext;
 import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
-import com.google.devtools.build.lib.analysis.config.HostTransition;
 import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
@@ -59,7 +57,6 @@
 /** An Aspect which JavaProtoLibrary injects to build Java SPEED protos. */
 public class JavaProtoAspect extends NativeAspectClass implements ConfiguredAspectFactory {
 
-  private final Label hostJdkAttribute;
   private final Label javaToolchainAttribute;
 
   private static LabelLateBoundDefault<?> getSpeedProtoToolchainLabel(String defaultValue) {
@@ -83,7 +80,6 @@
     this.rpcSupport = Preconditions.checkNotNull(rpcSupport);
     this.defaultSpeedProtoToolchainLabel =
         Preconditions.checkNotNull(defaultSpeedProtoToolchainLabel);
-    this.hostJdkAttribute = JavaSemantics.hostJdkAttribute(env);
     this.javaToolchainAttribute = JavaSemantics.javaToolchainAttribute(env);
   }
 
@@ -140,11 +136,6 @@
                     .legacyAllowAnyFileType()
                     .value(getSpeedProtoToolchainLabel(defaultSpeedProtoToolchainLabel)))
             .add(
-                attr(HOST_JAVA_RUNTIME_ATTRIBUTE_NAME, LABEL)
-                    .cfg(HostTransition.createFactory())
-                    .value(hostJdkAttribute)
-                    .mandatoryProviders(ToolchainInfo.PROVIDER.id()))
-            .add(
                 attr(JavaRuleClasses.JAVA_TOOLCHAIN_ATTRIBUTE_NAME, LABEL)
                     .useOutputLicenses()
                     .value(javaToolchainAttribute)