Allow tree artifacts to be source or header inputs to cc_common.compile()

This is already supported in cc_library, and would make the behavior
more consistent.

PiperOrigin-RevId: 362295588
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationHelper.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationHelper.java
index 025a20b..a6b17ae 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationHelper.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCompilationHelper.java
@@ -463,7 +463,9 @@
   }
 
   private CcCompilationHelper addPrivateHeader(Artifact privateHeader, Label label) {
-    boolean isHeader = CppFileTypes.CPP_HEADER.matches(privateHeader.getExecPath());
+    boolean isHeader =
+        CppFileTypes.CPP_HEADER.matches(privateHeader.getExecPath())
+            || privateHeader.isTreeArtifact();
     boolean isTextualInclude =
         CppFileTypes.CPP_TEXTUAL_INCLUDE.matches(privateHeader.getExecPath());
     Preconditions.checkState(isHeader || isTextualInclude);
diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcModule.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcModule.java
index a4243f0..f39891a 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcModule.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcModule.java
@@ -1985,17 +1985,20 @@
             CppFileTypes.CPP_SOURCE,
             CppFileTypes.C_SOURCE,
             CppFileTypes.ASSEMBLER_WITH_C_PREPROCESSOR,
-            CppFileTypes.ASSEMBLER));
+            CppFileTypes.ASSEMBLER),
+        /* allowAnyTreeArtifacts= */ true);
     validateExtensions(
         "public_hdrs",
         publicHeaders,
         FileTypeSet.of(CppFileTypes.CPP_HEADER),
-        FileTypeSet.of(CppFileTypes.CPP_HEADER));
+        FileTypeSet.of(CppFileTypes.CPP_HEADER),
+        /* allowAnyTreeArtifacts= */ true);
     validateExtensions(
         "private_hdrs",
         privateHeaders,
         FileTypeSet.of(CppFileTypes.CPP_HEADER),
-        FileTypeSet.of(CppFileTypes.CPP_HEADER));
+        FileTypeSet.of(CppFileTypes.CPP_HEADER),
+        /* allowAnyTreeArtifacts= */ true);
 
     if (disallowNopicOutputs && disallowPicOutputs) {
       throw Starlark.errorf("Either PIC or no PIC actions have to be created.");
@@ -2244,11 +2247,20 @@
       Object objectsObject, Object picObjectsObject) throws EvalException {
     CcCompilationOutputs.Builder ccCompilationOutputsBuilder = CcCompilationOutputs.builder();
     NestedSet<Artifact> objects = convertToNestedSet(objectsObject, Artifact.class, "objects");
-    validateExtensions("objects", objects.toList(), Link.OBJECT_FILETYPES, Link.OBJECT_FILETYPES);
+    validateExtensions(
+        "objects",
+        objects.toList(),
+        Link.OBJECT_FILETYPES,
+        Link.OBJECT_FILETYPES,
+        /* allowAnyTreeArtifacts= */ false);
     NestedSet<Artifact> picObjects =
         convertToNestedSet(picObjectsObject, Artifact.class, "pic_objects");
     validateExtensions(
-        "pic_objects", picObjects.toList(), Link.OBJECT_FILETYPES, Link.OBJECT_FILETYPES);
+        "pic_objects",
+        picObjects.toList(),
+        Link.OBJECT_FILETYPES,
+        Link.OBJECT_FILETYPES,
+        /* allowAnyTreeArtifacts= */ false);
     ccCompilationOutputsBuilder.addObjectFiles(objects.toList());
     ccCompilationOutputsBuilder.addPicObjectFiles(picObjects.toList());
     return ccCompilationOutputsBuilder.build();
@@ -2258,9 +2270,13 @@
       String paramName,
       List<Artifact> files,
       FileTypeSet validFileTypeSet,
-      FileTypeSet fileTypeForErrorMessage)
+      FileTypeSet fileTypeForErrorMessage,
+      boolean allowAnyTreeArtifacts)
       throws EvalException {
     for (Artifact file : files) {
+      if (allowAnyTreeArtifacts && file.isTreeArtifact()) {
+        continue;
+      }
       if (!validFileTypeSet.matches(file.getFilename())) {
         throw Starlark.errorf(
             "'%s' has wrong extension. The list of possible extensions for '%s' is: %s",
diff --git a/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java b/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java
index e1718e1..5db4c53 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/cpp/StarlarkCcCommonTest.java
@@ -5587,6 +5587,21 @@
   }
 
   @Test
+  public void testTreeArtifactSrcs() throws Exception {
+    doTestTreeAtrifactInSrcsAndHdrs("srcs");
+  }
+
+  @Test
+  public void testTreeArtifactPrivateHdrs() throws Exception {
+    doTestTreeAtrifactInSrcsAndHdrs("private_hdrs");
+  }
+
+  @Test
+  public void testTreeArtifactPublicHdrs() throws Exception {
+    doTestTreeAtrifactInSrcsAndHdrs("public_hdrs");
+  }
+
+  @Test
   public void testWrongSrcsExtensionGivesError() throws Exception {
     doTestWrongExtensionOfSrcsAndHdrs("srcs");
   }
@@ -6120,6 +6135,35 @@
     }
   }
 
+  private void doTestTreeAtrifactInSrcsAndHdrs(String attrName) throws Exception {
+    createFiles(scratch, "tools/build_defs/foo");
+    reporter.removeHandler(failFastHandler);
+
+    scratch.file(
+        "bar/create_tree_artifact.bzl",
+        "def _impl(ctx):",
+        "    tree = ctx.actions.declare_directory('dir')",
+        "    ctx.actions.run_shell(",
+        "        outputs = [tree],",
+        "        inputs = [],",
+        "        arguments = [tree.path],",
+        "        command = 'mkdir $1',",
+        "    )",
+        "    return [DefaultInfo(files = depset([tree]))]",
+        "create_tree_artifact = rule(implementation = _impl)");
+    scratch.file(
+        "bar/BUILD",
+        "load('//tools/build_defs/foo:extension.bzl', 'cc_starlark_library')",
+        "load(':create_tree_artifact.bzl', 'create_tree_artifact')",
+        "create_tree_artifact(name = 'tree_artifact')",
+        "cc_starlark_library(",
+        "    name = 'starlark_lib',",
+        "    " + attrName + " = [':tree_artifact'],",
+        ")");
+    getConfiguredTarget("//bar:starlark_lib");
+    assertNoEvents();
+  }
+
   private void doTestCcOutputsWrongExtension(String attrName, String paramName) throws Exception {
     setupCcOutputsTest();
     scratch.file(