Add a module extension for registering local/remote jdks (#312)

Closes #312

COPYBARA_INTEGRATE_REVIEW=https://github.com/bazelbuild/rules_java/pull/312 from bazelbuild:hvd_java_repos_module_ext d11e12e2eda5b3cfb5e1075d03f095ef5ebe3871
PiperOrigin-RevId: 796756819
Change-Id: I7d672490bcb47e64c6dbb2f5138d261bee9fa52d
diff --git a/java/extensions.bzl b/java/extensions.bzl
index f456f3f..ad7978b 100644
--- a/java/extensions.bzl
+++ b/java/extensions.bzl
@@ -23,6 +23,7 @@
     "remote_jdk21_repos",
     "remote_jdk8_repos",
 )
+load("//toolchains:extensions.bzl", _java_repository = "java_repository")
 
 def _toolchains_impl(module_ctx):
     java_tools_repos()
@@ -38,3 +39,5 @@
         return None
 
 toolchains = module_extension(_toolchains_impl)
+
+java_repository = _java_repository
diff --git a/test/repo/BUILD.bazel b/test/repo/BUILD.bazel
index ca41439..b0fb612 100644
--- a/test/repo/BUILD.bazel
+++ b/test/repo/BUILD.bazel
@@ -53,4 +53,9 @@
     name = "my_funky_toolchain",
     bootclasspath = ["@bazel_tools//tools/jdk:platformclasspath"],
     configuration = NONPREBUILT_TOOLCHAIN_CONFIGURATION,
+    exec_compatible_with = [
+        "@platforms//os:linux",
+        "@platforms//cpu:x86_64",
+    ],
+    java_runtime = "@my_funky_jdk//:jdk",
 )
diff --git a/test/repo/MODULE.bazel b/test/repo/MODULE.bazel
index 1faec90..6210933 100644
--- a/test/repo/MODULE.bazel
+++ b/test/repo/MODULE.bazel
@@ -49,6 +49,25 @@
     "remotejdk21_win",
 )
 
+custom_jdk = use_extension("@rules_java//java:extensions.bzl", "java_repository")
+custom_jdk.remote(
+    name = "my_funky_jdk",
+    prefix = "funky",
+    strip_prefix = "zulu24.32.13-ca-jdk24.0.2-linux_x64",
+    target_compatible_with = [
+        "@platforms//os:linux",
+        "@platforms//cpu:x86_64",
+    ],
+    urls = [
+        "https://cdn.azul.com/zulu/bin/zulu24.32.13-ca-jdk24.0.2-linux_x64.tar.gz",
+    ],
+    version = "24",
+)
+use_repo(custom_jdk, "my_funky_jdk", "my_funky_jdk_toolchain_config_repo")
+
+register_toolchains("@my_funky_jdk_toolchain_config_repo//:all")
+
 register_toolchains("//:all")
 
 bazel_dep(name = "rules_shell", version = "0.4.0", dev_dependency = True)
+bazel_dep(name = "platforms", version = "0.0.11", dev_dependency = True)
diff --git a/test/repo/WORKSPACE b/test/repo/WORKSPACE
index c272975..a6055e2 100644
--- a/test/repo/WORKSPACE
+++ b/test/repo/WORKSPACE
@@ -47,3 +47,21 @@
 rules_shell_dependencies()
 
 rules_shell_toolchains()
+
+load("@rules_java//toolchains:remote_java_repository.bzl", "remote_java_repository")
+
+remote_java_repository(
+    name = "my_funky_jdk",
+    prefix = "funky",
+    strip_prefix = "zulu24.32.13-ca-jdk24.0.2-linux_x64",
+    target_compatible_with = [
+        "@platforms//os:linux",
+        "@platforms//cpu:x86_64",
+    ],
+    urls = [
+        "https://cdn.azul.com/zulu/bin/zulu24.32.13-ca-jdk24.0.2-linux_x64.tar.gz",
+    ],
+    version = "24",
+)
+
+register_toolchains("@my_funky_jdk_toolchain_config_repo//:all")
diff --git a/toolchains/extensions.bzl b/toolchains/extensions.bzl
new file mode 100644
index 0000000..330c707
--- /dev/null
+++ b/toolchains/extensions.bzl
@@ -0,0 +1,59 @@
+"""Module extensions for local and remote java repositories"""
+
+load(":local_java_repository.bzl", "local_java_repository")
+load(":remote_java_repository.bzl", "remote_java_repository")
+
+visibility(["//java"])
+
+def _java_repository_impl(mctx):
+    for mod in mctx.modules:
+        if not mod.is_root:
+            fail(
+                """This module extension may only be used in the root module. {name}
+                must set `dev_dependency` = True on it's usage of this extension,
+                if {name} can be dependency of other modules.""".format(name = mod.name),
+            )
+        for local in mod.tags.local:
+            local_java_repository(
+                local.name,
+                java_home = local.java_home,
+                version = local.version,
+                build_file = local.build_file,
+                build_file_content = local.build_file_content,
+            )
+        for remote in mod.tags.remote:
+            remote_java_repository(
+                remote.name,
+                remote.version,
+                target_compatible_with = remote.target_compatible_with,
+                prefix = remote.prefix,
+                sha256 = remote.sha256,
+                strip_prefix = remote.strip_prefix,
+                urls = remote.urls,
+            )
+
+_local = tag_class(attrs = {
+    "name": attr.string(mandatory = True),
+    "build_file": attr.label(default = None),
+    "build_file_content": attr.string(default = ""),
+    "java_home": attr.string(default = ""),
+    "version": attr.string(default = ""),
+})
+
+_remote = tag_class(attrs = {
+    "name": attr.string(mandatory = True),
+    "version": attr.string(mandatory = True),
+    "urls": attr.string_list(mandatory = True),
+    "prefix": attr.string(default = ""),
+    "sha256": attr.string(default = ""),
+    "strip_prefix": attr.string(default = ""),
+    "target_compatible_with": attr.string_list(default = []),
+})
+
+java_repository = module_extension(
+    _java_repository_impl,
+    tag_classes = {
+        "local": _local,
+        "remote": _remote,
+    },
+)