Create a `java_common.compile_header` API that just does the header compilation part of `java_common.compile`

PiperOrigin-RevId: 921992069
Change-Id: I4dfc3671c04175e2a803427dab64a7db2a7d180c
diff --git a/java/private/java_common.bzl b/java/private/java_common.bzl
index 174f4e6..59ba7df 100644
--- a/java/private/java_common.bzl
+++ b/java/private/java_common.bzl
@@ -22,6 +22,7 @@
 load(":boot_class_path_info.bzl", "BootClassPathInfo")
 load(
     ":java_common_internal.bzl",
+    _compile_header_internal = "compile_header",
     _compile_internal = "compile",
     _run_ijar_internal = "run_ijar",
 )
@@ -89,6 +90,32 @@
         enable_annotation_processing = enable_annotation_processing,
     )
 
+def _compile_header(
+        ctx,
+        output,
+        java_toolchain,
+        source_jars = [],
+        source_files = [],
+        javac_opts = [],
+        deps = [],
+        plugins = [],
+        strict_deps = "ERROR",
+        bootclasspath = None,
+        enable_annotation_processing = True):
+    return _compile_header_internal(
+        ctx,
+        output = output,
+        java_toolchain = java_toolchain,
+        source_jars = source_jars,
+        source_files = source_files,
+        javac_opts = javac_opts,
+        deps = deps,
+        plugins = plugins,
+        strict_deps = strict_deps,
+        bootclasspath = bootclasspath,
+        enable_annotation_processing = enable_annotation_processing,
+    )
+
 def _run_ijar(actions, jar, java_toolchain, target_label = None):
     get_internal_java_common().check_java_toolchain_is_declared_on_rule(actions)
     return _run_ijar_internal(
@@ -277,6 +304,7 @@
     methods = {
         "provider": JavaInfo,
         "compile": _compile,
+        "compile_header": _compile_header,
         "run_ijar": _run_ijar,
         "stamp_jar": _stamp_jar,
         "pack_sources": _pack_sources,
diff --git a/java/private/java_common_internal.bzl b/java/private/java_common_internal.bzl
index 249b925..30d9533 100644
--- a/java/private/java_common_internal.bzl
+++ b/java/private/java_common_internal.bzl
@@ -37,6 +37,103 @@
     "DEFAULT",  # When no flag value is specified on the command line.
 ]
 
+def _construct_javac_opts(ctx, java_toolchain, plugin_info, javac_opts, bootclasspath, add_exports = []):
+    all_javac_opts = []  # [depset[str]]
+    all_javac_opts.append(java_toolchain._javacopts)
+    all_javac_opts.append(ctx.fragments.java.default_javac_flags_depset)
+    all_javac_opts.append(semantics.compatible_javac_options(ctx, java_toolchain))
+
+    if ("com.google.devtools.build.runfiles.AutoBazelRepositoryProcessor" in
+        plugin_info.plugins.processor_classes.to_list()):
+        all_javac_opts.append(depset(
+            ["-Abazel.repository=" + ctx.label.repo_name],
+            order = "preorder",
+        ))
+    system_bootclasspath = None
+    for package_config in java_toolchain._package_configuration:
+        if package_config.matches(package_config.package_specs, ctx.label):
+            all_javac_opts.append(package_config.javac_opts)
+            if package_config.system:
+                if system_bootclasspath:
+                    fail("Multiple system package configurations found for %s" % ctx.label)
+                system_bootclasspath = package_config.system
+    if not bootclasspath:
+        bootclasspath = system_bootclasspath
+
+    all_javac_opts.append(depset(
+        ["--add-exports=%s=ALL-UNNAMED" % x for x in add_exports],
+        order = "preorder",
+    ))
+
+    if type(javac_opts) == type([]):
+        # detokenize target's javacopts, it will be tokenized before compilation
+        all_javac_opts.append(helper.detokenize_javacopts(helper.tokenize_javacopts(ctx, javac_opts)))
+    elif type(javac_opts) == type(depset()):
+        all_javac_opts.append(javac_opts)
+    else:
+        fail("Expected javac_opts to be a list or depset, got:", type(javac_opts))
+
+    # we reverse the list of javacopts depsets, so that we keep the right-most set
+    # in case it's deduped. When this depset is flattened, we will reverse again,
+    # and then tokenize before passing to javac. This way, right-most javacopts will
+    # be retained and "win out".
+    return depset(order = "preorder", transitive = reversed(all_javac_opts)), bootclasspath
+
+def _construct_classpaths(deps, strict_deps, classpath_mode):
+    is_strict_mode = strict_deps != "OFF"
+
+    direct_jars = depset()
+    if is_strict_mode:
+        direct_jars = depset(order = "preorder", transitive = [dep.compile_jars for dep in deps])
+
+    header_compilation_direct_deps = depset()
+    if is_strict_mode:
+        header_compilation_direct_deps = depset(
+            order = "preorder",
+            transitive = [dep.header_compilation_direct_deps for dep in deps],
+        )
+
+    compilation_classpath = depset(
+        order = "preorder",
+        transitive = [direct_jars] + [dep.transitive_compile_time_jars for dep in deps],
+    )
+    compile_time_java_deps = depset()
+    if is_strict_mode and classpath_mode != "OFF":
+        compile_time_java_deps = depset(transitive = [dep._compile_time_java_dependencies for dep in deps])
+
+    return struct(
+        direct_jars = direct_jars,
+        header_compilation_direct_deps = header_compilation_direct_deps,
+        compilation_classpath = compilation_classpath,
+        compile_time_java_deps = compile_time_java_deps,
+    )
+
+def _derive_header_compilation_outputs(ctx, base_output, suffix = ""):
+    if suffix:
+        compile_jar = _derive_output_file(ctx, base_output, name_suffix = suffix, extension = "jar")
+        compile_deps_proto = _derive_output_file(ctx, base_output, name_suffix = suffix, extension = "jdeps")
+    else:
+        compile_jar = base_output
+        compile_deps_proto = _derive_output_file(ctx, base_output, extension = "jdeps")
+
+    # TODO: b/417791104 - remove check after a Bazel release
+    if ctx.fragments.java.use_header_compilation_direct_deps():
+        header_compilation_jar = _derive_output_file(ctx, base_output, name_suffix = "-tjar", extension = "jar")
+    else:
+        header_compilation_jar = None
+
+    return struct(
+        compile_jar = compile_jar,
+        header_compilation_jar = header_compilation_jar,
+        compile_deps_proto = compile_deps_proto,
+    )
+
+def _validate_strict_deps(strict_deps):
+    strict_deps = (strict_deps or "default").upper()
+    if strict_deps not in _STRICT_DEPS_VALUES:
+        fail("Got an invalid value for strict_deps:", strict_deps, "must be one of:", _STRICT_DEPS_VALUES)
+    return strict_deps
+
 def compile(
         ctx,
         output,
@@ -125,54 +222,18 @@
     get_internal_java_common().check_provider_instances([java_toolchain], "java_toolchain", JavaToolchainInfo)
     get_internal_java_common().check_provider_instances(plugins, "plugins", JavaPluginInfo)
 
-    # normalize and validate strict_deps
-    strict_deps = (strict_deps or "default").upper()
-    if strict_deps not in _STRICT_DEPS_VALUES:
-        fail("Got an invalid value for strict_deps:", strict_deps, "must be one of:", _STRICT_DEPS_VALUES)
+    strict_deps = _validate_strict_deps(strict_deps)
 
     plugin_info = merge_plugin_info_without_outputs(plugins + deps)
 
-    all_javac_opts = []  # [depset[str]]
-    all_javac_opts.append(java_toolchain._javacopts)
-
-    all_javac_opts.append(ctx.fragments.java.default_javac_flags_depset)
-    all_javac_opts.append(semantics.compatible_javac_options(ctx, java_toolchain))
-
-    if ("com.google.devtools.build.runfiles.AutoBazelRepositoryProcessor" in
-        plugin_info.plugins.processor_classes.to_list()):
-        all_javac_opts.append(depset(
-            ["-Abazel.repository=" + ctx.label.repo_name],
-            order = "preorder",
-        ))
-    system_bootclasspath = None
-    for package_config in java_toolchain._package_configuration:
-        if package_config.matches(package_config.package_specs, ctx.label):
-            all_javac_opts.append(package_config.javac_opts)
-            if package_config.system:
-                if system_bootclasspath:
-                    fail("Multiple system package configurations found for %s" % ctx.label)
-                system_bootclasspath = package_config.system
-    if not bootclasspath:
-        bootclasspath = system_bootclasspath
-
-    all_javac_opts.append(depset(
-        ["--add-exports=%s=ALL-UNNAMED" % x for x in add_exports],
-        order = "preorder",
-    ))
-
-    if type(javac_opts) == type([]):
-        # detokenize target's javacopts, it will be tokenized before compilation
-        all_javac_opts.append(helper.detokenize_javacopts(helper.tokenize_javacopts(ctx, javac_opts)))
-    elif type(javac_opts) == type(depset()):
-        all_javac_opts.append(javac_opts)
-    else:
-        fail("Expected javac_opts to be a list or depset, got:", type(javac_opts))
-
-    # we reverse the list of javacopts depsets, so that we keep the right-most set
-    # in case it's deduped. When this depset is flattened, we will reverse again,
-    # and then tokenize before passing to javac. This way, right-most javacopts will
-    # be retained and "win out".
-    all_javac_opts = depset(order = "preorder", transitive = reversed(all_javac_opts))
+    all_javac_opts, bootclasspath = _construct_javac_opts(
+        ctx,
+        java_toolchain,
+        plugin_info,
+        javac_opts,
+        bootclasspath,
+        add_exports,
+    )
 
     # Optimization: skip this if there are no annotation processors, to avoid unnecessarily
     # disabling the direct classpath optimization if `enable_annotation_processor = False`
@@ -190,27 +251,8 @@
     has_sources = source_files or source_jars
     has_resources = resources or resource_jars
 
-    is_strict_mode = strict_deps != "OFF"
     classpath_mode = ctx.fragments.java.reduce_java_classpath()
-
-    direct_jars = depset()
-    if is_strict_mode:
-        direct_jars = depset(order = "preorder", transitive = [dep.compile_jars for dep in deps])
-
-    header_compilation_direct_deps = depset()
-    if is_strict_mode:
-        header_compilation_direct_deps = depset(
-            order = "preorder",
-            transitive = [dep.header_compilation_direct_deps for dep in deps],
-        )
-
-    compilation_classpath = depset(
-        order = "preorder",
-        transitive = [direct_jars] + [dep.transitive_compile_time_jars for dep in deps],
-    )
-    compile_time_java_deps = depset()
-    if is_strict_mode and classpath_mode != "OFF":
-        compile_time_java_deps = depset(transitive = [dep._compile_time_java_dependencies for dep in deps])
+    classpaths = _construct_classpaths(deps, strict_deps, classpath_mode)
 
     # create compile time jar action
     if not has_sources:
@@ -222,14 +264,10 @@
         header_compilation_jar = compile_jar
         compile_deps_proto = None
     elif _should_use_header_compilation(ctx, java_toolchain):
-        compile_jar = _derive_output_file(ctx, output, name_suffix = "-hjar", extension = "jar")
-
-        # TODO: b/417791104 - remove check after a Bazel release
-        if ctx.fragments.java.use_header_compilation_direct_deps():
-            header_compilation_jar = _derive_output_file(ctx, output, name_suffix = "-tjar", extension = "jar")
-        else:
-            header_compilation_jar = None
-        compile_deps_proto = _derive_output_file(ctx, output, name_suffix = "-hjar", extension = "jdeps")
+        hdr_outputs = _derive_header_compilation_outputs(ctx, output, suffix = "-hjar")
+        compile_jar = hdr_outputs.compile_jar
+        header_compilation_jar = hdr_outputs.header_compilation_jar
+        compile_deps_proto = hdr_outputs.compile_deps_proto
         get_internal_java_common().create_header_compilation_action(
             ctx,
             java_toolchain,
@@ -238,10 +276,10 @@
             plugin_info,
             depset(source_files),
             source_jars,
-            compilation_classpath,
-            direct_jars,
+            classpaths.compilation_classpath,
+            classpaths.direct_jars,
             bootclasspath,
-            compile_time_java_deps,
+            classpaths.compile_time_java_deps,
             all_javac_opts,
             strict_deps,
             ctx.label,
@@ -249,7 +287,7 @@
             enable_direct_classpath,
             annotation_processor_additional_inputs,
             header_compilation_jar,
-            header_compilation_direct_deps,
+            classpaths.header_compilation_direct_deps,
         )
     elif ctx.fragments.java.use_ijars():
         compile_jar = run_ijar(
@@ -282,11 +320,11 @@
         output,
         manifest_proto,
         plugin_info,
-        compilation_classpath,
-        direct_jars,
+        classpaths.compilation_classpath,
+        classpaths.direct_jars,
         bootclasspath,
         depset(javabuilder_jvm_flags),
-        compile_time_java_deps,
+        classpaths.compile_time_java_deps,
         all_javac_opts,
         strict_deps,
         ctx.label,
@@ -331,7 +369,7 @@
         # needs to be flattened because the public API is a list
         boot_classpath = (bootclasspath.bootclasspath if bootclasspath else java_toolchain.bootclasspath).to_list(),
         # we only add compile time jars from deps, and not exports
-        compilation_classpath = compilation_classpath,
+        compilation_classpath = classpaths.compilation_classpath,
         runtime_classpath = depset(
             order = "preorder",
             direct = direct_runtime_jars,
@@ -494,3 +532,108 @@
         jars,
         excluded_jars,
     )
+
+def compile_header(
+        ctx,
+        output,
+        java_toolchain,
+        source_jars = [],
+        source_files = [],
+        javac_opts = [],
+        deps = [],
+        plugins = [],
+        strict_deps = "ERROR",
+        bootclasspath = None,
+        injecting_rule_kind = None,
+        enable_annotation_processing = True):
+    """Compiles Java header jars from the implementation of a Starlark rule.
+
+    Args:
+        ctx: (RuleContext) The rule context
+        output: (File) The output header jar (hjar)
+        java_toolchain: (JavaToolchainInfo) Toolchain to be used. Mandatory.
+        source_jars: ([File]) A list of the jars to be compiled.
+        source_files: ([File]) A list of the Java source files to be compiled.
+        javac_opts: ([str]|depset[str]) A list of the desired javac options. Optional.
+        deps: ([JavaInfo]) A list of dependencies. Optional.
+        plugins: ([JavaPluginInfo|JavaInfo]) A list of plugins. Optional.
+        strict_deps: (str) A string that specifies how to handle strict deps. Possible values:
+            'OFF', 'ERROR', 'WARN' and 'DEFAULT'.
+        bootclasspath: (BootClassPathInfo) If present, overrides the bootclasspath associated with
+            the provided java_toolchain. Optional.
+        injecting_rule_kind: (str|None)
+        enable_annotation_processing: (bool)
+
+    Returns:
+        (JavaInfo)
+    """
+    get_internal_java_common().check_provider_instances([java_toolchain], "java_toolchain", JavaToolchainInfo)
+    get_internal_java_common().check_provider_instances(plugins, "plugins", JavaPluginInfo)
+
+    strict_deps = _validate_strict_deps(strict_deps)
+
+    plugin_info = merge_plugin_info_without_outputs(plugins + deps)
+
+    all_javac_opts, bootclasspath = _construct_javac_opts(
+        ctx,
+        java_toolchain,
+        plugin_info,
+        javac_opts,
+        bootclasspath,
+        add_exports = [],
+    )
+
+    enable_direct_classpath = True
+    if not enable_annotation_processing and plugin_info.plugins.processor_classes:
+        plugin_info = disable_plugin_info_annotation_processing(plugin_info)
+        enable_direct_classpath = False
+
+    classpaths = _construct_classpaths(deps, strict_deps, ctx.fragments.java.reduce_java_classpath())
+
+    hdr_outputs = _derive_header_compilation_outputs(ctx, output)
+
+    get_internal_java_common().create_header_compilation_action(
+        ctx,
+        java_toolchain,
+        hdr_outputs.compile_jar,
+        hdr_outputs.compile_deps_proto,
+        plugin_info,
+        depset(source_files),
+        source_jars,
+        classpaths.compilation_classpath,
+        classpaths.direct_jars,
+        bootclasspath,
+        classpaths.compile_time_java_deps,
+        all_javac_opts,
+        strict_deps,
+        ctx.label,
+        injecting_rule_kind,
+        enable_direct_classpath,
+        [],  # additional_inputs
+        hdr_outputs.header_compilation_jar,
+        classpaths.header_compilation_direct_deps,
+    )
+
+    return java_info_for_compilation(
+        output_jar = hdr_outputs.compile_jar,
+        compile_jar = hdr_outputs.compile_jar,
+        header_compilation_jar = hdr_outputs.header_compilation_jar,
+        source_jar = None,
+        generated_class_jar = None,
+        generated_source_jar = None,
+        plugin_info = plugin_info,
+        deps = deps,
+        runtime_deps = [],
+        exports = [],
+        exported_plugins = [],
+        compile_jdeps = hdr_outputs.compile_deps_proto,
+        jdeps = None,
+        native_headers_jar = None,
+        manifest_proto = None,
+        native_libraries = [],
+        neverlink = True,
+        add_exports = [],
+        add_opens = [],
+        direct_runtime_jars = [],
+        compilation_info = None,
+    )
diff --git a/java/private/java_info.bzl b/java/private/java_info.bzl
index 19c435a..e873ee1 100644
--- a/java/private/java_info.bzl
+++ b/java/private/java_info.bzl
@@ -491,7 +491,7 @@
         runtime_output_jars = direct_runtime_jars,
         transitive_runtime_jars = transitive_runtime_jars,
         transitive_source_jars = depset(
-            direct = [source_jar],
+            direct = [source_jar] if source_jar else [],
             # only differs from the usual java_info.transitive_source_jars in the order of deps
             transitive = [dep.transitive_source_jars for dep in concatenated_deps.runtimedeps_exports_deps],
         ),