Check C++ toolchain type on cc_proto_aspect when using toolchain resolution

Fixes https://github.com/bazelbuild/bazel/issues/8844

https://github.com/bazelbuild/bazel/issues/7260

RELNOTES: None.
PiperOrigin-RevId: 257788042
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppHelper.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppHelper.java
index 4dd4aaa..868daee 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppHelper.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppHelper.java
@@ -302,6 +302,25 @@
     return getToolchain(ruleContext, dep);
   }
 
+  /**
+   * Makes sure that the given info collection has a {@link CcToolchainProvider} (gives an error
+   * otherwise), and returns a reference to that {@link CcToolchainProvider}. The method never
+   * returns {@code null}, even if there is no toolchain.
+   */
+  public static CcToolchainProvider getToolchain(
+      RuleContext ruleContext, TransitiveInfoCollection dep) {
+    Label toolchainType = getToolchainTypeFromRuleClass(ruleContext);
+    return getToolchain(ruleContext, dep, toolchainType);
+  }
+
+  public static CcToolchainProvider getToolchain(
+      RuleContext ruleContext, TransitiveInfoCollection dep, Label toolchainType) {
+    if (toolchainType != null && useToolchainResolution(ruleContext)) {
+      return getToolchainFromPlatformConstraints(ruleContext, toolchainType);
+    }
+    return getToolchainFromCrosstoolTop(ruleContext, dep);
+  }
+
   /** Returns the c++ toolchain type, or null if it is not specified on the rule class. */
   public static Label getToolchainTypeFromRuleClass(RuleContext ruleContext) {
     Label toolchainType;
@@ -315,21 +334,6 @@
     return toolchainType;
   }
 
-  /**
-   * Makes sure that the given info collection has a {@link CcToolchainProvider} (gives an error
-   * otherwise), and returns a reference to that {@link CcToolchainProvider}. The method never
-   * returns {@code null}, even if there is no toolchain.
-   */
-  public static CcToolchainProvider getToolchain(
-      RuleContext ruleContext, TransitiveInfoCollection dep) {
-
-    Label toolchainType = getToolchainTypeFromRuleClass(ruleContext);
-    if (toolchainType != null && useToolchainResolution(ruleContext)) {
-      return getToolchainFromPlatformConstraints(ruleContext, toolchainType);
-    }
-    return getToolchainFromCrosstoolTop(ruleContext, dep);
-  }
-
   private static CcToolchainProvider getToolchainFromPlatformConstraints(
       RuleContext ruleContext, Label toolchainType) {
     return (CcToolchainProvider) ruleContext.getToolchainContext().forToolchainType(toolchainType);
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoAspect.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoAspect.java
index 9e86685..6b9b310 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoAspect.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/proto/CcProtoAspect.java
@@ -107,7 +107,7 @@
 
     try {
       ConfiguredAspect.Builder result = new ConfiguredAspect.Builder(this, parameters, ruleContext);
-      new Impl(ruleContext, protoInfo, cppSemantics).addProviders(result);
+      new Impl(ruleContext, protoInfo, cppSemantics, ccToolchainType).addProviders(result);
       return result.build();
     } catch (RuleErrorException e) {
       ruleContext.ruleError(e.getMessage());
@@ -145,12 +145,18 @@
     private final ProtoInfo protoInfo;
     private final CppSemantics cppSemantics;
     private final NestedSetBuilder<Artifact> filesBuilder;
+    private final Label ccToolchainType;
 
-    Impl(RuleContext ruleContext, ProtoInfo protoInfo, CppSemantics cppSemantics)
+    Impl(
+        RuleContext ruleContext,
+        ProtoInfo protoInfo,
+        CppSemantics cppSemantics,
+        Label ccToolchainType)
         throws RuleErrorException, InterruptedException {
       this.ruleContext = ruleContext;
       this.protoInfo = protoInfo;
       this.cppSemantics = cppSemantics;
+      this.ccToolchainType = ccToolchainType;
       FeatureConfiguration featureConfiguration = getFeatureConfiguration();
       ProtoConfiguration protoConfiguration = ruleContext.getFragment(ProtoConfiguration.class);
 
@@ -364,10 +370,11 @@
       return helper;
     }
 
-    private static CcToolchainProvider ccToolchain(RuleContext ruleContext) {
+    private CcToolchainProvider ccToolchain(RuleContext ruleContext) {
       return CppHelper.getToolchain(
           ruleContext,
-          ruleContext.getPrerequisite(CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME, TARGET));
+          ruleContext.getPrerequisite(CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME, TARGET),
+          ccToolchainType);
     }
 
     private ImmutableSet<Artifact> getOutputFiles(Iterable<String> suffixes) {