Add automatic exec groups to cc_common.link

PiperOrigin-RevId: 514332514
Change-Id: I5ea3b2fa70a858c8d206a36ede1eec02fb344bf2
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java b/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java
index 9ed5d52..d8dc6f5 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/AutoExecGroupsTest.java
@@ -16,6 +16,7 @@
 
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.truth.Truth.assertThat;
+import static com.google.devtools.build.lib.rules.cpp.CppRuleClasses.CPP_LINK_EXEC_GROUP;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -23,13 +24,19 @@
 import com.google.common.collect.ObjectArrays;
 import com.google.devtools.build.lib.actions.Action;
 import com.google.devtools.build.lib.analysis.actions.LazyWritePathsFileAction;
+import com.google.devtools.build.lib.analysis.actions.ParameterFileWriteAction;
 import com.google.devtools.build.lib.analysis.actions.SpawnAction;
 import com.google.devtools.build.lib.analysis.config.ToolchainTypeRequirement;
 import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
 import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
+import com.google.devtools.build.lib.analysis.util.AnalysisMock;
 import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.packages.ExecGroup;
+import com.google.devtools.build.lib.packages.util.Crosstool.CcToolchainConfig;
+import com.google.devtools.build.lib.rules.cpp.CppCompileAction;
+import com.google.devtools.build.lib.rules.cpp.CppLinkAction;
+import com.google.devtools.build.lib.rules.cpp.CppRuleClasses;
 import com.google.devtools.build.lib.rules.java.JavaGenJarsProvider;
 import com.google.devtools.build.lib.rules.java.JavaInfo;
 import com.google.devtools.build.lib.testutil.TestConstants;
@@ -1390,4 +1397,278 @@
     assertThat(actions.get(0).getOwner().getExecutionPlatform().label())
         .isEqualTo(Label.parseCanonical("//platforms:platform_1"));
   }
+
+  @Test
+  public void ccCommonLink_fileWriteActionExecutesOnFirstPlatform() throws Exception {
+    scratch.file(
+        "test/defs.bzl",
+        "def _use_cpp_toolchain():",
+        "   return [",
+        "      config_common.toolchain_type('"
+            + TestConstants.CPP_TOOLCHAIN_TYPE
+            + "', mandatory = False),",
+        "   ]",
+        "def _impl(ctx):",
+        "  cc_toolchain = ctx.attr._cc_toolchain[cc_common.CcToolchainInfo]",
+        "  feature_configuration = cc_common.configure_features(",
+        "      ctx = ctx,",
+        "      cc_toolchain = cc_toolchain,",
+        "      requested_features = ctx.features,",
+        "     unsupported_features = ctx.disabled_features,",
+        "  )",
+        "  linking_outputs = cc_common.link(",
+        "    name = ctx.label.name,",
+        "    actions = ctx.actions,",
+        "    feature_configuration = feature_configuration,",
+        "    cc_toolchain = cc_toolchain,",
+        "  )",
+        "  return []",
+        "custom_rule = rule(",
+        "  implementation = _impl,",
+        "  attrs = { '_cc_toolchain': attr.label(default=Label('//test:alias')) },",
+        "  toolchains = ['//rule:toolchain_type_2'] + _use_cpp_toolchain(),",
+        "  fragments = ['cpp']",
+        ")");
+    scratch.file(
+        "test/BUILD",
+        "load('//test:defs.bzl', 'custom_rule')",
+        "cc_toolchain_alias(name='alias')",
+        "custom_rule(name = 'custom_rule_name')");
+    useConfiguration("--incompatible_auto_exec_groups");
+    // getActions is defined outside this hunk; here it is asked for ParameterFileWriteActions.
+    ImmutableList<Action> actions =
+        getActions("//test:custom_rule_name", ParameterFileWriteAction.class);
+    // Expect exactly one, owned by the execution platform labelled //platforms:platform_1.
+    assertThat(actions).hasSize(1);
+    assertThat(actions.get(0).getOwner().getExecutionPlatform().label())
+        .isEqualTo(Label.parseCanonical("//platforms:platform_1"));
+  }
+
+  // Unlike the ...cppLinkExecGroupDefined... test, this rule declares no exec group named
+  // CPP_LINK_EXEC_GROUP; the CppLink action is still expected to be owned by the execution
+  // platform labelled //platforms:platform_1.
+  @Test
+  public void ccCommonLink_cppLinkExecGroupNotDefined_cppLinkActionExecutesOnFirstPlatform()
+      throws Exception {
+    scratch.file(
+        "test/defs.bzl",
+        "def _use_cpp_toolchain():",
+        "   return [",
+        "      config_common.toolchain_type('"
+            + TestConstants.CPP_TOOLCHAIN_TYPE
+            + "', mandatory = False),",
+        "   ]",
+        "def _impl(ctx):",
+        "  cc_toolchain = ctx.attr._cc_toolchain[cc_common.CcToolchainInfo]",
+        "  feature_configuration = cc_common.configure_features(",
+        "      ctx = ctx,",
+        "      cc_toolchain = cc_toolchain,",
+        "      requested_features = ctx.features,",
+        "     unsupported_features = ctx.disabled_features,",
+        "  )",
+        "  linking_outputs = cc_common.link(",
+        "    name = ctx.label.name,",
+        "    actions = ctx.actions,",
+        "    feature_configuration = feature_configuration,",
+        "    cc_toolchain = cc_toolchain,",
+        "  )",
+        "  return []",
+        "custom_rule = rule(",
+        "  implementation = _impl,",
+        "  attrs = { '_cc_toolchain': attr.label(default=Label('//test:alias')) },",
+        "  toolchains = ['//rule:toolchain_type_2'] + _use_cpp_toolchain(),",
+        "  fragments = ['cpp']",
+        ")");
+    scratch.file(
+        "test/BUILD",
+        "load('//test:defs.bzl', 'custom_rule')",
+        "cc_toolchain_alias(name='alias')",
+        "custom_rule(name = 'custom_rule_name')");
+    useConfiguration("--incompatible_auto_exec_groups");
+
+    ImmutableList<Action> actions = getActions("//test:custom_rule_name", CppLinkAction.class);
+
+    assertThat(actions).hasSize(1);
+    assertThat(actions.get(0).getMnemonic()).isEqualTo("CppLink");
+    assertThat(actions.get(0).getOwner().getExecutionPlatform().label())
+        .isEqualTo(Label.parseCanonical("//platforms:platform_1"));
+  }
+
+  @Test
+  public void ccCommonLink_cppLinkExecGroupDefined_cppLinkActionExecutesOnFirstPlatform()
+      throws Exception {
+    scratch.file(
+        "test/defs.bzl",
+        "def _use_cpp_toolchain():",
+        "   return [",
+        "      config_common.toolchain_type('"
+            + TestConstants.CPP_TOOLCHAIN_TYPE
+            + "', mandatory = False),",
+        "   ]",
+        "def _impl(ctx):",
+        "  cc_toolchain = ctx.attr._cc_toolchain[cc_common.CcToolchainInfo]",
+        "  feature_configuration = cc_common.configure_features(",
+        "      ctx = ctx,",
+        "      cc_toolchain = cc_toolchain,",
+        "      requested_features = ctx.features,",
+        "     unsupported_features = ctx.disabled_features,",
+        "  )",
+        "  linking_outputs = cc_common.link(",
+        "    name = ctx.label.name,",
+        "    actions = ctx.actions,",
+        "    feature_configuration = feature_configuration,",
+        "    cc_toolchain = cc_toolchain,",
+        "  )",
+        "  return []",
+        "custom_rule = rule(",
+        "  implementation = _impl,",
+        "  attrs = { '_cc_toolchain': attr.label(default=Label('//test:alias')) },",
+        "  exec_groups = { ",
+        "    '" + CPP_LINK_EXEC_GROUP + "': exec_group(toolchains = _use_cpp_toolchain()),",
+        "  },",
+        "  toolchains = ['//rule:toolchain_type_2'] + _use_cpp_toolchain(),",
+        "  fragments = ['cpp']",
+        ")");
+    scratch.file(
+        "test/BUILD",
+        "load('//test:defs.bzl', 'custom_rule')",
+        "cc_toolchain_alias(name='alias')",
+        "custom_rule(name = 'custom_rule_name')");
+    useConfiguration("--incompatible_auto_exec_groups");
+    // The rule declares an exec group named CPP_LINK_EXEC_GROUP using _use_cpp_toolchain().
+    ImmutableList<Action> actions = getActions("//test:custom_rule_name", CppLinkAction.class);
+    // Expect a single CppLink action owned by the execution platform //platforms:platform_1.
+    assertThat(actions).hasSize(1);
+    assertThat(actions.get(0).getMnemonic()).isEqualTo("CppLink");
+    assertThat(actions.get(0).getOwner().getExecutionPlatform().label())
+        .isEqualTo(Label.parseCanonical("//platforms:platform_1"));
+  }
+
+  @Test
+  public void ccCommonLink_cppLTOActionExecutesOnFirstPlatform() throws Exception {
+    scratch.file(
+        "test/defs.bzl",
+        "def _use_cpp_toolchain():",
+        "   return [",
+        "      config_common.toolchain_type('"
+            + TestConstants.CPP_TOOLCHAIN_TYPE
+            + "', mandatory = False),",
+        "   ]",
+        "def _impl(ctx):",
+        "  cc_toolchain = ctx.attr._cc_toolchain[cc_common.CcToolchainInfo]",
+        "  feature_configuration = cc_common.configure_features(",
+        "      ctx = ctx,",
+        "      cc_toolchain = cc_toolchain,",
+        "      requested_features = ctx.features,",
+        "      unsupported_features = ctx.disabled_features,",
+        "  )",
+        "  linking_outputs = cc_common.link(",
+        "    name = ctx.label.name,",
+        "    actions = ctx.actions,",
+        "    feature_configuration = feature_configuration,",
+        "    cc_toolchain = cc_toolchain,",
+        "    linking_contexts = [dep[CcInfo].linking_context for dep in ctx.attr.deps if"
+            + " CcInfo in dep]",
+        "  )",
+        "  return []",
+        "custom_rule = rule(",
+        "  implementation = _impl,",
+        "  attrs = {",
+        "    'deps': attr.label_list(),",
+        "    'srcs': attr.label_list(allow_files = ['.cc']),",
+        "    '_cc_toolchain': attr.label(default=Label('//test:alias')),",
+        "  },",
+        "  toolchains = ['//rule:toolchain_type_2'] + _use_cpp_toolchain(),",
+        "  fragments = ['cpp']",
+        ")");
+    scratch.file(
+        "test/BUILD",
+        "load('//test:defs.bzl', 'custom_rule')",
+        "cc_toolchain_alias(name='alias')",
+        "cc_library(",
+        "  name = 'dep',",
+        "  srcs = ['dep.cc'],",
+        ")",
+        "custom_rule(",
+        "  name = 'custom_rule_name',",
+        "  srcs = ['custom.cc'],",
+        "  deps = ['dep'],",
+        ")");
+    useConfiguration("--incompatible_auto_exec_groups", "--features=thin_lto");
+    AnalysisMock.get()
+        .ccSupport()
+        .setupCcToolchainConfig(
+            mockToolsConfig,
+            CcToolchainConfig.builder()
+                .withFeatures(CppRuleClasses.THIN_LTO, CppRuleClasses.SUPPORTS_START_END_LIB));
+    // Keep only the CppLinkActions whose mnemonic is CppLTOIndexing.
+    ImmutableList<Action> actions = getActions("//test:custom_rule_name", CppLinkAction.class);
+    ImmutableList<Action> cppLTOActions =
+        actions.stream()
+            .filter(action -> action.getMnemonic().equals("CppLTOIndexing"))
+            .collect(toImmutableList());
+    // Expect one indexing action whose owner's execution platform is //platforms:platform_1.
+    assertThat(cppLTOActions).hasSize(1);
+    assertThat(cppLTOActions.get(0).getOwner().getExecutionPlatform().label())
+        .isEqualTo(Label.parseCanonical("//platforms:platform_1"));
+  }
+
+  @Test
+  public void ccCommonLink_linkstampCompileActionExecutesOnFirstPlatform() throws Exception {
+    scratch.file(
+        "bazel_internal/test_rules/cc/defs.bzl",
+        "def _use_cpp_toolchain():",
+        "   return [",
+        "      config_common.toolchain_type('"
+            + TestConstants.CPP_TOOLCHAIN_TYPE
+            + "', mandatory = False),",
+        "   ]",
+        "def _impl(ctx):",
+        "  cc_toolchain = ctx.attr._cc_toolchain[cc_common.CcToolchainInfo]",
+        "  feature_configuration = cc_common.configure_features(",
+        "      ctx = ctx,",
+        "      cc_toolchain = cc_toolchain,",
+        "      requested_features = ctx.features,",
+        "      unsupported_features = ctx.disabled_features,",
+        "  )",
+        "  linking_outputs = cc_common.link(",
+        "    name = ctx.label.name,",
+        "    actions = ctx.actions,",
+        "    feature_configuration = feature_configuration,",
+        "    cc_toolchain = cc_toolchain,",
+        "    linking_contexts = [dep[CcInfo].linking_context for dep in ctx.attr.deps if"
+            + " CcInfo in dep]",
+        "  )",
+        "  return []",
+        "custom_rule = rule(",
+        "  implementation = _impl,",
+        "  attrs = {",
+        "    'deps': attr.label_list(),",
+        "    '_cc_toolchain': attr.label(default=Label('//bazel_internal/test_rules/cc:alias')),",
+        "  },",
+        "  toolchains = ['//rule:toolchain_type_2'] + _use_cpp_toolchain(),",
+        "  fragments = ['cpp']",
+        ")");
+    scratch.file(
+        "bazel_internal/test_rules/cc/BUILD",
+        "load('//bazel_internal/test_rules/cc:defs.bzl', 'custom_rule')",
+        "cc_toolchain_alias(name='alias')",
+        "cc_library(",
+        "  name = 'dep',",
+        "  linkstamp = 'stamp.cc',",
+        ")",
+        "custom_rule(",
+        "  name = 'custom_rule_name',",
+        "  deps = ['dep'],",
+        ")");
+    useConfiguration("--incompatible_auto_exec_groups");
+    // The 'dep' cc_library sets linkstamp = 'stamp.cc'; ask for the target's CppCompileActions.
+    ImmutableList<Action> cppCompileActions =
+        getActions("//bazel_internal/test_rules/cc:custom_rule_name", CppCompileAction.class);
+    // Expect a single CppLinkstampCompile action owned by //platforms:platform_1.
+    assertThat(cppCompileActions).hasSize(1);
+    assertThat(cppCompileActions.get(0).getMnemonic()).isEqualTo("CppLinkstampCompile");
+    assertThat(cppCompileActions.get(0).getOwner().getExecutionPlatform().label())
+        .isEqualTo(Label.parseCanonical("//platforms:platform_1"));
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/rules/cpp/MockCppSemantics.java b/src/test/java/com/google/devtools/build/lib/rules/cpp/MockCppSemantics.java
index 8190bcd..0e00c45 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/cpp/MockCppSemantics.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/cpp/MockCppSemantics.java
@@ -45,6 +45,13 @@
     return Language.CPP;
   }
 
+  private static final String CPP_TOOLCHAIN_TYPE = "@bazel_tools//tools/cpp:toolchain_type";
+  /** Returns {@link #CPP_TOOLCHAIN_TYPE}, the label string of the C++ toolchain type. */
+  @Override
+  public String getCppToolchainType() {
+    return CPP_TOOLCHAIN_TYPE;
+  }
+
   @Override
   public void finalizeCompileActionBuilder(
       BuildConfigurationValue configuration,
diff --git a/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java b/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java
index a4967e3..aa1bf27 100644
--- a/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java
+++ b/src/test/java/com/google/devtools/build/lib/testutil/TestConstants.java
@@ -167,6 +167,9 @@
   /** The java toolchain type. */
   public static final String JAVA_TOOLCHAIN_TYPE = "@bazel_tools//tools/jdk:toolchain_type";
 
+  /** The cpp (C++) toolchain type, as a label string. */
+  public static final String CPP_TOOLCHAIN_TYPE = "@bazel_tools//tools/cpp:toolchain_type";
+
   /** A choice of test execution mode, only varies internally. */
   public enum InternalTestExecutionMode {
     NORMAL