Implement flag_group in the new rule-based toolchain.

BEGIN_PUBLIC
Implement flag_group in the new rule-based toolchain.
END_PUBLIC

PiperOrigin-RevId: 622107179
Change-Id: I9e1971e279f313ce85537c899bcf80860616f8b7
diff --git a/cc/toolchains/args.bzl b/cc/toolchains/args.bzl
index 29e3a1b..1df3333 100644
--- a/cc/toolchains/args.bzl
+++ b/cc/toolchains/args.bzl
@@ -13,7 +13,7 @@
 # limitations under the License.
 """All providers for rule-based bazel toolchain config."""
 
-load("//cc:cc_toolchain_config_lib.bzl", "flag_group")
+load("//cc/toolchains/impl:args_utils.bzl", "validate_nested_args")
 load(
     "//cc/toolchains/impl:collect.bzl",
     "collect_action_types",
@@ -21,35 +21,42 @@
     "collect_provider",
 )
 load(
+    "//cc/toolchains/impl:nested_args.bzl",
+    "NESTED_ARGS_ATTRS",
+    "args_wrapper_macro",
+    "nested_args_provider_from_ctx",
+)
+load(
     ":cc_toolchain_info.bzl",
     "ActionTypeSetInfo",
     "ArgsInfo",
     "ArgsListInfo",
+    "BuiltinVariablesInfo",
     "FeatureConstraintInfo",
-    "NestedArgsInfo",
 )
 
 visibility("public")
 
 def _cc_args_impl(ctx):
-    if not ctx.attr.args and not ctx.attr.env:
-        fail("cc_args requires at least one of args and env")
-
     actions = collect_action_types(ctx.attr.actions)
-    files = collect_files(ctx.attr.data)
-    requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)
+
+    if not ctx.attr.args and not ctx.attr.nested and not ctx.attr.env:
+        fail("cc_args requires at least one of args, nested, and env")
 
     nested = None
-    if ctx.attr.args:
-        # TODO: This is temporary until cc_nested_args is implemented.
-        nested = NestedArgsInfo(
+    if ctx.attr.args or ctx.attr.nested:
+        nested = nested_args_provider_from_ctx(ctx)
+        validate_nested_args(
+            variables = ctx.attr._variables[BuiltinVariablesInfo].variables,
+            nested_args = nested,
+            actions = actions.to_list(),
             label = ctx.label,
-            nested = tuple(),
-            iterate_over = None,
-            files = files,
-            requires_types = {},
-            legacy_flag_group = flag_group(flags = ctx.attr.args),
         )
+        files = nested.files
+    else:
+        files = collect_files(ctx.attr.data)
+
+    requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)
 
     args = ArgsInfo(
         label = ctx.label,
@@ -72,7 +79,7 @@
         ),
     ]
 
-cc_args = rule(
+_cc_args = rule(
     implementation = _cc_args_impl,
     attrs = {
         "actions": attr.label_list(
@@ -83,21 +90,6 @@
 See @rules_cc//cc/toolchains/actions:all for valid options.
 """,
         ),
-        "args": attr.string_list(
-            doc = """Arguments that should be added to the command-line.
-
-These are evaluated in order, with earlier args appearing earlier in the
-invocation of the underlying tool.
-""",
-        ),
-        "data": attr.label_list(
-            allow_files = True,
-            doc = """Files required to add this argument to the command-line.
-
-For example, a flag that sets the header directory might add the headers in that
-directory as additional files.
-        """,
-        ),
         "env": attr.string_dict(
             doc = "Environment variables to be added to the command-line.",
         ),
@@ -108,7 +100,10 @@
 If omitted, this flag set will be enabled unconditionally.
 """,
         ),
-    },
+        "_variables": attr.label(
+            default = "//cc/toolchains/variables:variables",
+        ),
+    } | NESTED_ARGS_ATTRS,
     provides = [ArgsInfo],
     doc = """Declares a list of arguments bound to a set of actions.
 
@@ -121,3 +116,5 @@
     )
 """,
 )
+
+cc_args = lambda **kwargs: args_wrapper_macro(rule = _cc_args, **kwargs)
diff --git a/cc/toolchains/impl/args_utils.bzl b/cc/toolchains/impl/args_utils.bzl
index 2ace6aa..55b4841 100644
--- a/cc/toolchains/impl/args_utils.bzl
+++ b/cc/toolchains/impl/args_utils.bzl
@@ -11,7 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""."""
+"""Helper functions for working with args."""
+
+load(":variables.bzl", "get_type")
+
+visibility([
+    "//cc/toolchains",
+    "//tests/rule_based_toolchain/...",
+])
 
 def get_action_type(args_list, action_type):
     """Returns the corresponding entry in ArgsListInfo.by_action.
@@ -28,3 +35,87 @@
             return args
 
     return struct(action = action_type, args = tuple(), files = depset([]))
+
+def validate_nested_args(*, nested_args, variables, actions, label, fail = fail):
+    """Validates the typing for an nested_args invocation.
+
+    Args:
+        nested_args: (NestedArgsInfo) The nested_args to validate
+        variables: (Dict[str, VariableInfo]) A mapping from variable name to
+          the metadata (variable type and valid actions).
+        actions: (List[ActionTypeInfo]) The actions we require these variables
+          to be valid for.
+        label: (Label) The label of the rule we're currently validating.
+          Used for error messages.
+        fail: The fail function. Use for testing only.
+    """
+    stack = [(nested_args, {})]
+
+    for _ in range(9999999):
+        if not stack:
+            break
+        nested_args, overrides = stack.pop()
+        if nested_args.iterate_over != None or nested_args.unwrap_options:
+            # Make sure we don't keep using the same object.
+            overrides = dict(**overrides)
+
+        if nested_args.iterate_over != None:
+            type = get_type(
+                name = nested_args.iterate_over,
+                variables = variables,
+                overrides = overrides,
+                actions = actions,
+                args_label = label,
+                nested_label = nested_args.label,
+                fail = fail,
+            )
+            if type["name"] == "list":
+                # Rewrite the type of the thing we iterate over from a List[T]
+                # to a T.
+                overrides[nested_args.iterate_over] = type["elements"]
+            elif type["name"] == "option" and type["elements"]["name"] == "list":
+                # Rewrite Option[List[T]] to T.
+                overrides[nested_args.iterate_over] = type["elements"]["elements"]
+            else:
+                fail("Attempting to iterate over %s, but it was not a list - it was a %s" % (nested_args.iterate_over, type["repr"]))
+
+        # 1) Validate variables marked with after_option_unwrap = False.
+        # 2) Unwrap Option[T] to T as required.
+        # 3) Validate variables marked with after_option_unwrap = True.
+        for after_option_unwrap in [False, True]:
+            for var_name, requirements in nested_args.requires_types.items():
+                for requirement in requirements:
+                    if requirement.after_option_unwrap == after_option_unwrap:
+                        type = get_type(
+                            name = var_name,
+                            variables = variables,
+                            overrides = overrides,
+                            actions = actions,
+                            args_label = label,
+                            nested_label = nested_args.label,
+                            fail = fail,
+                        )
+                        if type["name"] not in requirement.valid_types:
+                            fail("{msg}, but {var_name} has type {type}".format(
+                                var_name = var_name,
+                                msg = requirement.msg,
+                                type = type["repr"],
+                            ))
+
+            # Only unwrap the options after the first iteration of this loop.
+            if not after_option_unwrap:
+                for var in nested_args.unwrap_options:
+                    type = get_type(
+                        name = var,
+                        variables = variables,
+                        overrides = overrides,
+                        actions = actions,
+                        args_label = label,
+                        nested_label = nested_args.label,
+                        fail = fail,
+                    )
+                    if type["name"] == "option":
+                        overrides[var] = type["elements"]
+
+        for child in nested_args.nested:
+            stack.append((child, overrides))
diff --git a/cc/toolchains/impl/nested_args.bzl b/cc/toolchains/impl/nested_args.bzl
index dda7498..ed83cf1 100644
--- a/cc/toolchains/impl/nested_args.bzl
+++ b/cc/toolchains/impl/nested_args.bzl
@@ -13,8 +13,10 @@
 # limitations under the License.
 """Helper functions for working with args."""
 
+load("@bazel_skylib//lib:structs.bzl", "structs")
 load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
-load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo")
+load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo", "VariableInfo")
+load(":collect.bzl", "collect_files", "collect_provider")
 
 visibility([
     "//cc/toolchains",
@@ -48,6 +50,126 @@
     iterate_over = "//toolchains/variables:foo_list",
 """
 
+# @unsorted-dict-items.
+NESTED_ARGS_ATTRS = {
+    "args": attr.string_list(
+        doc = """json-encoded arguments to be added to the command-line.
+
+Usage:
+cc_args(
+    ...,
+    args = ["--foo", format_arg("%s", "//cc/toolchains/variables:foo")]
+)
+
+This is equivalent to flag_group(flags = ["--foo", "%{foo}"])
+
+Mutually exclusive with nested.
+""",
+    ),
+    "nested": attr.label_list(
+        providers = [NestedArgsInfo],
+        doc = """nested_args that should be added on the command-line.
+
+Mutually exclusive with args.""",
+    ),
+    "data": attr.label_list(
+        allow_files = True,
+        doc = """Files required to add this argument to the command-line.
+
+For example, a flag that sets the header directory might add the headers in that
+directory as additional files.
+""",
+    ),
+    "variables": attr.label_list(
+        providers = [VariableInfo],
+        doc = "Variables to be used in substitutions",
+    ),
+    "iterate_over": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.iterate_over"),
+    "requires_not_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_available"),
+    "requires_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_not_available"),
+    "requires_true": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_true"),
+    "requires_false": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_false"),
+    "requires_equal": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_equal"),
+    "requires_equal_value": attr.string(),
+}
+
+def args_wrapper_macro(*, name, rule, args = [], **kwargs):
+    """Invokes a rule by converting args to attributes.
+
+    Args:
+        name: (str) The name of the target.
+        rule: (rule) The rule to invoke. Either cc_args or cc_nested_args.
+        args: (List[str|Formatted]) A list of either strings, or function calls
+          from format.bzl. For example:
+            ["--foo", format_arg("--sysroot=%s", "//cc/toolchains/variables:sysroot")]
+        **kwargs: kwargs to pass through into the rule invocation.
+    """
+    out_args = []
+    vars = []
+    if type(args) != "list":
+        fail("Args must be a list in %s" % native.package_relative_label(name))
+    for arg in args:
+        if type(arg) == "string":
+            out_args.append(raw_string(arg))
+        elif getattr(arg, "format_type") == "format_arg":
+            arg = structs.to_dict(arg)
+            if arg["value"] == None:
+                out_args.append(arg)
+            else:
+                var = arg.pop("value")
+
+                # Swap the variable from a label to an index. This allows us to
+                # actually get the providers in a rule.
+                out_args.append(struct(value = len(vars), **arg))
+                vars.append(var)
+        else:
+            fail("Invalid type of args in %s. Expected either a string or format_args(format_string, variable_label), got value %r" % (native.package_relative_label(name), arg))
+
+    rule(
+        name = name,
+        args = [json.encode(arg) for arg in out_args],
+        variables = vars,
+        **kwargs
+    )
+
+def _var(target):
+    if target == None:
+        return None
+    return target[VariableInfo].name
+
+# TODO: Consider replacing this with a subrule in the future. However, maybe not
+# for a long time, since it'll break compatibility with all bazel versions < 7.
+def nested_args_provider_from_ctx(ctx):
+    """Gets the nested args provider from a rule that has NESTED_ARGS_ATTRS.
+
+    Args:
+        ctx: The rule context
+    Returns:
+        NestedArgsInfo
+    """
+    variables = collect_provider(ctx.attr.variables, VariableInfo)
+    args = []
+    for arg in ctx.attr.args:
+        arg = json.decode(arg)
+        if "value" in arg:
+            if arg["value"] != None:
+                arg["value"] = variables[arg["value"]]
+        args.append(struct(**arg))
+
+    return nested_args_provider(
+        label = ctx.label,
+        args = args,
+        nested = collect_provider(ctx.attr.nested, NestedArgsInfo),
+        files = collect_files(ctx.attr.data),
+        iterate_over = _var(ctx.attr.iterate_over),
+        requires_not_none = _var(ctx.attr.requires_not_none),
+        requires_none = _var(ctx.attr.requires_none),
+        requires_true = _var(ctx.attr.requires_true),
+        requires_false = _var(ctx.attr.requires_false),
+        requires_equal = _var(ctx.attr.requires_equal),
+        requires_equal_value = ctx.attr.requires_equal_value,
+    )
+
 def raw_string(s):
     """Constructs metadata for creating a raw string.
 
diff --git a/cc/toolchains/nested_args.bzl b/cc/toolchains/nested_args.bzl
new file mode 100644
index 0000000..e4e3d53
--- /dev/null
+++ b/cc/toolchains/nested_args.bzl
@@ -0,0 +1,45 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""All providers for rule-based bazel toolchain config."""
+
+load(
+    "//cc/toolchains/impl:nested_args.bzl",
+    "NESTED_ARGS_ATTRS",
+    "args_wrapper_macro",
+    "nested_args_provider_from_ctx",
+)
+load(
+    ":cc_toolchain_info.bzl",
+    "NestedArgsInfo",
+)
+
+visibility("public")
+
+_cc_nested_args = rule(
+    implementation = lambda ctx: [nested_args_provider_from_ctx(ctx)],
+    attrs = NESTED_ARGS_ATTRS,
+    provides = [NestedArgsInfo],
+    doc = """Declares a list of arguments bound to a set of actions.
+
+Roughly equivalent to ctx.actions.args()
+
+Examples:
+    cc_nested_args(
+        name = "warnings_as_errors",
+        args = ["-Werror"],
+    )
+""",
+)
+
+cc_nested_args = lambda **kwargs: args_wrapper_macro(rule = _cc_nested_args, **kwargs)
diff --git a/tests/rule_based_toolchain/variables/BUILD b/tests/rule_based_toolchain/variables/BUILD
index 80928c7..5f7a5a6 100644
--- a/tests/rule_based_toolchain/variables/BUILD
+++ b/tests/rule_based_toolchain/variables/BUILD
@@ -1,3 +1,5 @@
+load("//cc/toolchains:format.bzl", "format_arg")
+load("//cc/toolchains:nested_args.bzl", "cc_nested_args")
 load("//cc/toolchains/impl:variables.bzl", "cc_builtin_variables", "cc_variable", "types")
 load("//tests/rule_based_toolchain:analysis_test_suite.bzl", "analysis_test_suite")
 load(":variables_test.bzl", "TARGETS", "TESTS")
@@ -8,6 +10,11 @@
 )
 
 cc_variable(
+    name = "optional_list",
+    type = types.option(types.list(types.string)),
+)
+
+cc_variable(
     name = "str_list",
     type = types.list(types.string),
 )
@@ -28,15 +35,104 @@
 
 cc_variable(
     name = "struct_list",
+    actions = ["//tests/rule_based_toolchain/actions:c_compile"],
     type = types.list(types.struct(
         nested_str = types.string,
         nested_str_list = types.list(types.string),
     )),
 )
 
+cc_variable(
+    name = "struct_list.nested_str_list",
+    type = types.unknown,
+)
+
+# Dots in the name confuse the test rules.
+# It would end up generating targets.struct_list.nested_str_list.
+alias(
+    name = "nested_str_list",
+    actual = ":struct_list.nested_str_list",
+)
+
+cc_nested_args(
+    name = "simple_str",
+    args = [format_arg("%s", ":str")],
+)
+
+cc_nested_args(
+    name = "list_not_allowed",
+    args = [format_arg("%s", ":str_list")],
+)
+
+cc_nested_args(
+    name = "iterate_over_list",
+    args = [format_arg("%s")],
+    iterate_over = ":str_list",
+)
+
+cc_nested_args(
+    name = "iterate_over_non_list",
+    args = ["--foo"],
+    iterate_over = ":str",
+)
+
+cc_nested_args(
+    name = "str_not_a_bool",
+    args = ["--foo"],
+    requires_true = ":str",
+)
+
+cc_nested_args(
+    name = "str_equal",
+    args = ["--foo"],
+    requires_equal = ":str",
+    requires_equal_value = "bar",
+)
+
+cc_nested_args(
+    name = "inner_iter",
+    args = [format_arg("%s")],
+    iterate_over = ":struct_list.nested_str_list",
+)
+
+cc_nested_args(
+    name = "outer_iter",
+    iterate_over = ":struct_list",
+    nested = [":inner_iter"],
+)
+
+cc_nested_args(
+    name = "bad_inner_iter",
+    args = [format_arg("%s", ":struct_list.nested_str_list")],
+)
+
+cc_nested_args(
+    name = "bad_outer_iter",
+    iterate_over = ":struct_list",
+    nested = [":bad_inner_iter"],
+)
+
+cc_nested_args(
+    name = "bad_nested_optional",
+    args = [format_arg("%s", ":str_option")],
+)
+
+cc_nested_args(
+    name = "good_nested_optional",
+    args = [format_arg("%s", ":str_option")],
+    requires_not_none = ":str_option",
+)
+
+cc_nested_args(
+    name = "optional_list_iter",
+    args = ["--foo"],
+    iterate_over = ":optional_list",
+)
+
 cc_builtin_variables(
     name = "variables",
     srcs = [
+        ":optional_list",
         ":str",
         ":str_list",
         ":str_option",
diff --git a/tests/rule_based_toolchain/variables/variables_test.bzl b/tests/rule_based_toolchain/variables/variables_test.bzl
index a3cf843..98a64fd 100644
--- a/tests/rule_based_toolchain/variables/variables_test.bzl
+++ b/tests/rule_based_toolchain/variables/variables_test.bzl
@@ -13,13 +13,20 @@
 # limitations under the License.
 """Tests for variables rule."""
 
-load("//cc/toolchains:cc_toolchain_info.bzl", "ActionTypeInfo", "BuiltinVariablesInfo", "VariableInfo")
+load("//cc/toolchains:cc_toolchain_info.bzl", "ActionTypeInfo", "BuiltinVariablesInfo", "NestedArgsInfo", "VariableInfo")
+load("//cc/toolchains/impl:args_utils.bzl", _validate_nested_args = "validate_nested_args")
+load(
+    "//cc/toolchains/impl:nested_args.bzl",
+    "FORMAT_ARGS_ERR",
+    "REQUIRES_TRUE_ERR",
+)
 load("//cc/toolchains/impl:variables.bzl", "types", _get_type = "get_type")
 load("//tests/rule_based_toolchain:subjects.bzl", "result_fn_wrapper", "subjects")
 
 visibility("private")
 
 get_type = result_fn_wrapper(_get_type)
+validate_nested_args = result_fn_wrapper(_validate_nested_args)
 
 _ARGS_LABEL = Label("//:args")
 _NESTED_LABEL = Label("//:nested_vars")
@@ -56,6 +63,7 @@
 
     expect_type("unknown").err().contains(
         """The variable unknown does not exist. Did you mean one of the following?
+optional_list
 str
 str_list
 """,
@@ -110,11 +118,74 @@
         },
     ).ok().equals(types.string)
 
+def _variable_validation_test(env, targets):
+    c_compile = targets.c_compile[ActionTypeInfo]
+    cpp_compile = targets.cpp_compile[ActionTypeInfo]
+    variables = targets.variables[BuiltinVariablesInfo].variables
+
+    def _expect_validated(target, expr = None, actions = []):
+        return env.expect.that_value(
+            validate_nested_args(
+                nested_args = target[NestedArgsInfo],
+                variables = variables,
+                actions = actions,
+                label = _ARGS_LABEL,
+            ),
+            expr = expr,
+            # Type is Result[None]
+            factory = subjects.result(subjects.unknown),
+        )
+
+    _expect_validated(targets.simple_str, expr = "simple_str").ok()
+    _expect_validated(targets.list_not_allowed).err().equals(
+        FORMAT_ARGS_ERR + ", but str_list has type List[string]",
+    )
+    _expect_validated(targets.iterate_over_list, expr = "iterate_over_list").ok()
+    _expect_validated(targets.iterate_over_non_list, expr = "iterate_over_non_list").err().equals(
+        "Attempting to iterate over str, but it was not a list - it was a string",
+    )
+    _expect_validated(targets.str_not_a_bool, expr = "str_not_a_bool").err().equals(
+        REQUIRES_TRUE_ERR + ", but str has type string",
+    )
+    _expect_validated(targets.str_equal, expr = "str_equal").ok()
+    _expect_validated(targets.inner_iter, expr = "inner_iter_standalone").err().equals(
+        'Attempted to access "struct_list.nested_str_list", but "struct_list" was not a struct - it had type List[struct(nested_str=string, nested_str_list=List[string])]. Maybe you meant to use iterate_over.',
+    )
+
+    _expect_validated(targets.outer_iter, actions = [c_compile], expr = "outer_iter_valid_action").ok()
+    _expect_validated(targets.outer_iter, actions = [c_compile, cpp_compile], expr = "outer_iter_missing_action").err().equals(
+        "The variable %s is inaccessible from the action %s. This is required because it is referenced in %s, which is included by %s, which references that action" % (targets.struct_list.label, cpp_compile.label, targets.outer_iter.label, _ARGS_LABEL),
+    )
+
+    _expect_validated(targets.bad_outer_iter, expr = "bad_outer_iter").err().equals(
+        FORMAT_ARGS_ERR + ", but struct_list.nested_str_list has type List[string]",
+    )
+
+    _expect_validated(targets.optional_list_iter, expr = "optional_list_iter").ok()
+
+    _expect_validated(targets.bad_nested_optional, expr = "bad_nested_optional").err().equals(
+        FORMAT_ARGS_ERR + ", but str_option has type Option[string]",
+    )
+    _expect_validated(targets.good_nested_optional, expr = "good_nested_optional").ok()
+
 TARGETS = [
     "//tests/rule_based_toolchain/actions:c_compile",
     "//tests/rule_based_toolchain/actions:cpp_compile",
+    ":bad_nested_optional",
+    ":bad_outer_iter",
+    ":good_nested_optional",
+    ":inner_iter",
+    ":iterate_over_list",
+    ":iterate_over_non_list",
+    ":list_not_allowed",
+    ":nested_str_list",
+    ":optional_list_iter",
+    ":outer_iter",
+    ":simple_str",
     ":str",
+    ":str_equal",
     ":str_list",
+    ":str_not_a_bool",
     ":str_option",
     ":struct",
     ":struct_list",
@@ -125,4 +196,5 @@
 TESTS = {
     "types_represent_correctly_test": _types_represent_correctly_test,
     "get_types_test": _get_types_test,
+    "variable_validation_test": _variable_validation_test,
 }