Make execGroups available to the RuleClass (like requiredToolchains)

PiperOrigin-RevId: 304174397
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleClassFunctions.java b/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleClassFunctions.java
index 2e27ccd..2bc063b 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleClassFunctions.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/skylark/SkylarkRuleClassFunctions.java
@@ -368,6 +368,12 @@
 
     builder.addRequiredToolchains(parseToolchains(toolchains, thread));
 
+    // exec_groups may be None; otherwise it must be a dict from group name to exec_group object.
+    // The description handed to castMap is the Starlark parameter name, exec_groups.
+    if (execGroups != Starlark.NONE) {
+      builder.addExecGroups(castMap(execGroups, String.class, ExecGroup.class, "exec_groups"));
+    }
+
     if (!buildSetting.equals(Starlark.NONE) && !cfg.equals(Starlark.NONE)) {
       throw Starlark.errorf(
           "Build setting rules cannot use the `cfg` param to apply transitions to themselves.");
diff --git a/src/main/java/com/google/devtools/build/lib/packages/RuleClass.java b/src/main/java/com/google/devtools/build/lib/packages/RuleClass.java
index 666451c..ab180ad 100644
--- a/src/main/java/com/google/devtools/build/lib/packages/RuleClass.java
+++ b/src/main/java/com/google/devtools/build/lib/packages/RuleClass.java
@@ -710,6 +710,7 @@
     private boolean useToolchainResolution = true;
     private Set<Label> executionPlatformConstraints = new HashSet<>();
     private OutputFile.Kind outputFileKind = OutputFile.Kind.FILE;
+    private final Map<String, ExecGroup> execGroups = new HashMap<>();
 
     /**
      * Constructs a new {@code RuleClassBuilder} using all attributes from all
@@ -743,6 +744,17 @@
         addRequiredToolchains(parent.getRequiredToolchains());
         useToolchainResolution = parent.useToolchainResolution;
         addExecutionPlatformConstraints(parent.getExecutionPlatformConstraints());
+        // A group name inherited from more than one parent is rejected. Keep the
+        // DuplicateExecGroupError as the cause so its context is not lost.
+        try {
+          addExecGroups(parent.getExecGroups());
+        } catch (DuplicateExecGroupError e) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "An execution group named '%s' is inherited multiple times in %s ruleclass",
+                  e.getDuplicateGroup(), name),
+              e);
+        }
 
         for (Attribute attribute : parent.getAttributes()) {
           String attrName = attribute.getName();
@@ -862,6 +871,7 @@
           requiredToolchains,
           useToolchainResolution,
           executionPlatformConstraints,
+          execGroups,
           outputFileKind,
           attributes.values(),
           buildSetting);
@@ -1365,6 +1375,45 @@
     }
 
     /**
+     * Adds execution groups to this rule class.
+     *
+     * <p>All-or-nothing: if any name in {@code execGroups} is already registered, nothing is
+     * added and the group registered earlier under that name is left unchanged.
+     *
+     * @param execGroups groups to add, keyed by group name
+     * @return this builder
+     * @throws DuplicateExecGroupError if a name in {@code execGroups} is already registered
+     */
+    public Builder addExecGroups(Map<String, ExecGroup> execGroups) throws DuplicateExecGroupError {
+      // Check every name before mutating, so a rejected call can neither overwrite an existing
+      // group nor leave a partially applied batch behind.
+      for (String name : execGroups.keySet()) {
+        if (this.execGroups.containsKey(name)) {
+          throw new DuplicateExecGroupError(name);
+        }
+      }
+      this.execGroups.putAll(execGroups);
+      return this;
+    }
+
+    /** An error to help report {@link ExecGroup}s with the same name. */
+    static final class DuplicateExecGroupError extends EvalException {
+      private final String duplicateGroup;
+
+      DuplicateExecGroupError(String duplicateGroup) {
+        super(
+            null,
+            String.format("Multiple execution groups with the same name: '%s'.", duplicateGroup));
+        this.duplicateGroup = duplicateGroup;
+      }
+
+      /** Returns the name of the execution group that was added more than once. */
+      String getDuplicateGroup() {
+        return duplicateGroup;
+      }
+    }
+
+    /**
       * Causes rules to use toolchain resolution to determine the execution platform and toolchains.
       * Rules that are part of configuring toolchains and platforms should set this to {@code false}.
       */
@@ -1524,6 +1564,7 @@
   private final ImmutableSet<Label> requiredToolchains;
   private final boolean useToolchainResolution;
   private final ImmutableSet<Label> executionPlatformConstraints;
+  private final ImmutableMap<String, ExecGroup> execGroups;
 
   /**
    * Constructs an instance of RuleClass whose name is 'name', attributes are 'attributes'. The
@@ -1579,6 +1620,7 @@
       Set<Label> requiredToolchains,
       boolean useToolchainResolution,
       Set<Label> executionPlatformConstraints,
+      Map<String, ExecGroup> execGroups,
       OutputFile.Kind outputFileKind,
       Collection<Attribute> attributes,
       @Nullable BuildSetting buildSetting) {
@@ -1618,6 +1660,7 @@
     this.requiredToolchains = ImmutableSet.copyOf(requiredToolchains);
     this.useToolchainResolution = useToolchainResolution;
     this.executionPlatformConstraints = ImmutableSet.copyOf(executionPlatformConstraints);
+    this.execGroups = ImmutableMap.copyOf(execGroups);
     this.buildSetting = buildSetting;
 
     // Create the index and collect non-configurable attributes.
@@ -2536,6 +2579,11 @@
     return executionPlatformConstraints;
   }
 
+  /** Returns the execution groups of this rule class, keyed by group name. */
+  public ImmutableMap<String, ExecGroup> getExecGroups() {
+    return execGroups;
+  }
+
   public OutputFile.Kind  getOutputFileKind() {
     return outputFileKind;
   }
diff --git a/src/test/java/com/google/devtools/build/lib/packages/RuleClassBuilderTest.java b/src/test/java/com/google/devtools/build/lib/packages/RuleClassBuilderTest.java
index 49e1d23..9cd63ce 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/RuleClassBuilderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/RuleClassBuilderTest.java
@@ -23,6 +23,8 @@
 
 import com.google.common.base.Predicate;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.devtools.build.lib.actions.MutableActionGraph.ActionConflictException;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.packages.RuleClass.Builder.RuleClassNamePredicate;
@@ -191,6 +193,63 @@
   }
 
   @Test
+  public void testExecGroupsAreInherited() throws Exception {
+    ExecGroup parentGroup =
+        new ExecGroup(
+            ImmutableSet.of(Label.parseAbsoluteUnchecked("//mock_toolchain_type")),
+            ImmutableSet.of(Label.parseAbsoluteUnchecked("//mock_constraint")));
+    ExecGroup childGroup = new ExecGroup(ImmutableSet.of(), ImmutableSet.of());
+
+    // The abstract parent declares "group"; the child adds "child-group" on top of it.
+    RuleClass.Builder parentBuilder =
+        new RuleClass.Builder("$parent", RuleClassType.ABSTRACT, false);
+    parentBuilder.add(attr("tags", STRING_LIST));
+    parentBuilder.addExecGroups(ImmutableMap.of("group", parentGroup));
+    RuleClass parent = parentBuilder.build();
+    RuleClass child =
+        new RuleClass.Builder("child", RuleClassType.NORMAL, false, parent)
+            .factory(DUMMY_CONFIGURED_TARGET_FACTORY)
+            .add(attr("attr", STRING))
+            .addExecGroups(ImmutableMap.of("child-group", childGroup))
+            .build();
+
+    ImmutableMap<String, ExecGroup> childGroups = child.getExecGroups();
+    assertThat(childGroups).containsEntry("group", parentGroup);
+    assertThat(childGroups).containsEntry("child-group", childGroup);
+  }
+
+  @Test
+  public void testDuplicateExecGroupNamesErrors() throws Exception {
+    // Two unrelated rule classes each declare an exec group called "blueberry".
+    RuleClass a = ruleClassWithEmptyExecGroup("ruleA", "blueberry");
+    RuleClass b = ruleClassWithEmptyExecGroup("ruleB", "blueberry");
+
+    // Inheriting from both must fail because the group name collides.
+    IllegalArgumentException thrown =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> new RuleClass.Builder("ruleC", RuleClassType.NORMAL, false, a, b).build());
+    assertThat(thrown)
+        .hasMessageThat()
+        .isEqualTo(
+            "An execution group named 'blueberry' is inherited multiple times in ruleC ruleclass");
+  }
+
+  /**
+   * Returns a NORMAL rule class named {@code ruleName} that has a {@code tags} attribute and
+   * one exec group, {@code groupName}, with no toolchains or constraints.
+   */
+  private RuleClass ruleClassWithEmptyExecGroup(String ruleName, String groupName)
+      throws Exception {
+    return new RuleClass.Builder(ruleName, RuleClassType.NORMAL, false)
+        .factory(DUMMY_CONFIGURED_TARGET_FACTORY)
+        .addExecGroups(
+            ImmutableMap.of(groupName, new ExecGroup(ImmutableSet.of(), ImmutableSet.of())))
+        .add(attr("tags", STRING_LIST))
+        .build();
+  }
+
+  @Test
   public void testBasicRuleNamePredicates() throws Exception {
     Predicate<String> abcdef = nothingBut("abc", "def").asPredicateOfRuleClassName();
     assertThat(abcdef.test("abc")).isTrue();
diff --git a/src/test/java/com/google/devtools/build/lib/packages/RuleClassTest.java b/src/test/java/com/google/devtools/build/lib/packages/RuleClassTest.java
index 3153f5b..51697d2 100644
--- a/src/test/java/com/google/devtools/build/lib/packages/RuleClassTest.java
+++ b/src/test/java/com/google/devtools/build/lib/packages/RuleClassTest.java
@@ -941,6 +941,7 @@
         /*requiredToolchains=*/ ImmutableSet.of(),
         /*useToolchainResolution=*/ true,
         /* executionPlatformConstraints= */ ImmutableSet.of(),
+        /* execGroups= */ ImmutableMap.of(),
         OutputFile.Kind.FILE,
         ImmutableList.copyOf(attributes),
         /* buildSetting= */ null);
@@ -1151,6 +1152,27 @@
   }
 
   @Test
+  public void testExecGroups() throws Exception {
+    Label toolchain = Label.parseAbsoluteUnchecked("//toolchain");
+    Label constraint = Label.parseAbsoluteUnchecked("//constraint");
+    ExecGroup cherry = new ExecGroup(ImmutableSet.of(toolchain), ImmutableSet.of(constraint));
+
+    // Declare a single exec group directly on the builder and check it survives build().
+    RuleClass ruleClass =
+        new RuleClass.Builder("ruleClass", RuleClassType.NORMAL, false)
+            .factory(DUMMY_CONFIGURED_TARGET_FACTORY)
+            .add(attr("tags", STRING_LIST))
+            .addExecGroups(ImmutableMap.of("cherry", cherry))
+            .build();
+
+    ImmutableMap<String, ExecGroup> execGroups = ruleClass.getExecGroups();
+    assertThat(execGroups).hasSize(1);
+    ExecGroup built = execGroups.get("cherry");
+    assertThat(built.getRequiredToolchains()).containsExactly(toolchain);
+    assertThat(built.getExecutionPlatformConstraints()).containsExactly(constraint);
+  }
+
+  @Test
   public void testBuildSetting_createsDefaultAttribute() {
     RuleClass labelFlag =
         new RuleClass.Builder("label_flag", RuleClassType.NORMAL, false)
diff --git a/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleClassFunctionsTest.java b/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleClassFunctionsTest.java
index b15fc6d..62cf58e 100644
--- a/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleClassFunctionsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleClassFunctionsTest.java
@@ -1760,6 +1760,38 @@
   }
 
   @Test
+  public void testRuleAddExecGroup() throws Exception {
+    // Enable the exec groups experiment; reset() presumably re-initializes the test environment
+    // so the option takes effect.
+    setSkylarkSemanticsOptions("--experimental_exec_groups=true");
+    reset();
+
+    registerDummyStarlarkFunction();
+    scratch.file("test/BUILD", "toolchain_type(name = 'my_toolchain_type')");
+    String[] plumDefinition = {
+      "plum = rule(",
+      "  implementation = impl,",
+      "  exec_groups = {",
+      "    'group': exec_group(",
+      "      toolchains=['//test:my_toolchain_type'],",
+      "      exec_compatible_with=['//constraint:cv1', '//constraint:cv2'],",
+      "    ),",
+      "  },",
+      ")",
+    };
+    evalAndExport(plumDefinition);
+
+    SkylarkRuleFunction plumFunction = (SkylarkRuleFunction) lookup("plum");
+    RuleClass plum = plumFunction.getRuleClass();
+    // Toolchains and constraints passed to exec_group() belong to the group, not to the rule.
+    assertThat(plum.getRequiredToolchains()).isEmpty();
+    assertThat(plum.getExecGroups().get("group").getRequiredToolchains())
+        .containsExactly(makeLabel("//test:my_toolchain_type"));
+    assertThat(plum.getExecutionPlatformConstraints()).isEmpty();
+    assertThat(plum.getExecGroups().get("group").getExecutionPlatformConstraints())
+        .containsExactly(makeLabel("//constraint:cv1"), makeLabel("//constraint:cv2"));
+  }
+
+  @Test
   public void testRuleFunctionReturnsNone() throws Exception {
     scratch.file("test/rule.bzl",
         "def _impl(ctx):",