Gather variable metadata for the new rule-based toolchain.

BEGIN_PUBLIC
Gather variable metadata for the new rule-based toolchain.
END_PUBLIC

PiperOrigin-RevId: 622000877
Change-Id: I5b2ea6c363fc43fd44e60ffc8fa7ae041545337e
diff --git a/cc/toolchains/cc_toolchain_info.bzl b/cc/toolchains/cc_toolchain_info.bzl
index 0429ad2..3a499f6 100644
--- a/cc/toolchains/cc_toolchain_info.bzl
+++ b/cc/toolchains/cc_toolchain_info.bzl
@@ -73,6 +73,7 @@
         "files": "(depset[File]) The files required to use this variable",
         "requires_types": "(dict[str, str]) A mapping from variables to their expected type name (not type). This means that we can require the generic type Option, rather than an Option[T]",
         "legacy_flag_group": "(flag_group) The flag_group this corresponds to",
+        "unwrap_options": "(List[str]) A list of variables for which we should unwrap the option. For example, if a user writes `requires_not_none = \":foo\"`, then we change the type of foo from Option[str] to str",
     },
 )
 
@@ -83,7 +84,7 @@
         "label": "(Label) The label defining this provider. Place in error messages to simplify debugging",
         "actions": "(depset[ActionTypeInfo]) The set of actions this is associated with",
         "requires_any_of": "(Sequence[FeatureConstraintInfo]) This will be enabled if any of the listed predicates are met. Equivalent to with_features",
-        "nested": "(Optional[NestedArgsInfo]) The args to expand. Equivalent to a flag group.",
+        "nested": "(Optional[NestedArgsInfo]) The args expand. Equivalent to a flag group.",
         "files": "(depset[File]) Files required for the args",
         "env": "(dict[str, str]) Environment variables to apply",
     },
diff --git a/cc/toolchains/format.bzl b/cc/toolchains/format.bzl
new file mode 100644
index 0000000..bdbb0c8
--- /dev/null
+++ b/cc/toolchains/format.bzl
@@ -0,0 +1,26 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions to format arguments for the cc toolchain"""
+
+def format_arg(format, value = None):
+    """Generate metadata to format a variable with a given value.
+
+    Args:
+      format: (str) The format string
+      value: (Optional[Label]) The variable to format. Any is used because it can
+        be any representation of a variable.
+    Returns:
+      A struct corresponding to the formatted variable.
+    """
+    return struct(format_type = "format_arg", format = format, value = value)
diff --git a/cc/toolchains/impl/nested_args.bzl b/cc/toolchains/impl/nested_args.bzl
new file mode 100644
index 0000000..dda7498
--- /dev/null
+++ b/cc/toolchains/impl/nested_args.bzl
@@ -0,0 +1,265 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper functions for working with args."""
+
+load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
+load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo")
+
+visibility([
+    "//cc/toolchains",
+    "//tests/rule_based_toolchain/...",
+])
+
+REQUIRES_MUTUALLY_EXCLUSIVE_ERR = "requires_none, requires_not_none, requires_true, requires_false, and requires_equal are mutually exclusive"
+REQUIRES_NOT_NONE_ERR = "requires_not_none only works on options"
+REQUIRES_NONE_ERR = "requires_none only works on options"
+REQUIRES_TRUE_ERR = "requires_true only works on bools"
+REQUIRES_FALSE_ERR = "requires_false only works on bools"
+REQUIRES_EQUAL_ERR = "requires_equal only works on strings"
+REQUIRES_EQUAL_VALUE_ERR = "When requires_equal is provided, you must also provide requires_equal_value to specify what it should be equal to"
+FORMAT_ARGS_ERR = "format_args can only format strings, files, or directories"
+
+_NOT_ESCAPED_FMT = "%% should always either of the form %%s, or escaped with %%%%. Instead, got %r"
+
+_EXAMPLE = """
+
+cc_args(
+    ...,
+    args = [format_arg("--foo=%s", "//cc/toolchains/variables:foo")]
+)
+
+or
+
+cc_args(
+    ...,
+    # If foo_list contains ["a", "b"], then this expands to ["--foo", "+a", "--foo", "+b"].
+    args = ["--foo", format_arg("+%s")],
+    iterate_over = "//toolchains/variables:foo_list",
+"""
+
+def raw_string(s):
+    """Constructs metadata for creating a raw string.
+
+    Args:
+      s: (str) The string to input.
+    Returns:
+      Metadata suitable for format_variable.
+    """
+    return struct(format_type = "raw", format = s)
+
+def format_string_indexes(s, fail = fail):
+    """Gets the index of a '%s' in a string.
+
+    Args:
+      s: (str) The string
+      fail: The fail function. Used for tests
+
+    Returns:
+      List[int] The indexes of the '%s' in the string
+    """
+    indexes = []
+    escaped = False
+    for i in range(len(s)):
+        if not escaped and s[i] == "%":
+            escaped = True
+        elif escaped:
+            if s[i] == "{":
+                fail('Using the old mechanism for variables, %%{variable}, but we instead use format_arg("--foo=%%s", "//cc/toolchains/variables:<variable>"). Got %r' % s)
+            elif s[i] == "s":
+                indexes.append(i - 1)
+            elif s[i] != "%":
+                fail(_NOT_ESCAPED_FMT % s)
+            escaped = False
+    if escaped:
+        return fail(_NOT_ESCAPED_FMT % s)
+    return indexes
+
+def format_variable(arg, iterate_over, fail = fail):
+    """Lists all of the variables referenced by an argument.
+
+    Eg: referenced_variables([
+        format_arg("--foo", None),
+        format_arg("--bar=%s", ":bar")
+    ]) => ["--foo", "--bar=%{bar}"]
+
+    Args:
+      arg: [Formatted] The command-line arguments, as created by the format_arg function.
+      iterate_over: (Optional[str]) The name of the variable we're iterating over.
+      fail: The fail function. Used for tests
+
+    Returns:
+      A string defined to be compatible with flag groups.
+    """
+    indexes = format_string_indexes(arg.format, fail = fail)
+    if arg.format_type == "raw":
+        if indexes:
+            return fail("Can't use %s with a raw string. Either escape it with %%s or use format_arg, like the following examples:" + _EXAMPLE)
+        return arg.format
+    else:
+        if len(indexes) == 0:
+            return fail('format_arg requires a "%%s" in the format string, but got %r' % arg.format)
+        elif len(indexes) > 1:
+            return fail("Only one %%s can be used in a format string, but got %r" % arg.format)
+
+        if arg.value == None:
+            if iterate_over == None:
+                return fail("format_arg requires either a variable to format, or iterate_over must be provided. For example:" + _EXAMPLE)
+            var = iterate_over
+        else:
+            var = arg.value.name
+
+        index = indexes[0]
+        return arg.format[:index] + "%{" + var + "}" + arg.format[index + 2:]
+
+def nested_args_provider(
+        *,
+        label,
+        args = [],
+        nested = [],
+        files = depset([]),
+        iterate_over = None,
+        requires_not_none = None,
+        requires_none = None,
+        requires_true = None,
+        requires_false = None,
+        requires_equal = None,
+        requires_equal_value = "",
+        fail = fail):
+    """Creates a validated NestedArgsInfo.
+
+    Does not validate types, as you can't know the type of a variable until
+    you have a cc_args wrapping it, because the outer layers can change that
+    type using iterate_over.
+
+    Args:
+        label: (Label) The context we are currently evaluating in. Used for
+          error messages.
+        args: (List[str]) The command-line arguments to add.
+        nested: (List[NestedArgsInfo]) command-line arguments to expand.
+        files: (depset[File]) Files required for this set of command-line args.
+        iterate_over: (Optional[str]) Variable to iterate over
+        requires_not_none: (Optional[str]) If provided, this NestedArgsInfo will
+          be ignored if the variable is None
+        requires_none: (Optional[str]) If provided, this NestedArgsInfo will
+          be ignored if the variable is not None
+        requires_true: (Optional[str]) If provided, this NestedArgsInfo will
+          be ignored if the variable is false
+        requires_false: (Optional[str]) If provided, this NestedArgsInfo will
+          be ignored if the variable is true
+        requires_equal: (Optional[str]) If provided, this NestedArgsInfo will
+          be ignored if the variable is not equal to requires_equal_value.
+        requires_equal_value: (str) The value to compare the requires_equal
+          variable with
+        fail: A fail function. Use only for testing.
+    Returns:
+        NestedArgsInfo
+    """
+    if bool(args) == bool(nested):
+        fail("Exactly one of args and nested must be provided")
+
+    transitive_files = [ea.files for ea in nested]
+    transitive_files.append(files)
+
+    has_value = [attr for attr in [
+        requires_not_none,
+        requires_none,
+        requires_true,
+        requires_false,
+        requires_equal,
+    ] if attr != None]
+
+    # We may want to reconsider this down the line, but it's easier to open up
+    # an API than to lock down an API.
+    if len(has_value) > 1:
+        fail(REQUIRES_MUTUALLY_EXCLUSIVE_ERR)
+
+    kwargs = {}
+    requires_types = {}
+    if nested:
+        kwargs["flag_groups"] = [ea.legacy_flag_group for ea in nested]
+
+    unwrap_options = []
+
+    if iterate_over:
+        kwargs["iterate_over"] = iterate_over
+
+    if requires_not_none:
+        kwargs["expand_if_available"] = requires_not_none
+        requires_types.setdefault(requires_not_none, []).append(struct(
+            msg = REQUIRES_NOT_NONE_ERR,
+            valid_types = ["option"],
+            after_option_unwrap = False,
+        ))
+        unwrap_options.append(requires_not_none)
+    elif requires_none:
+        kwargs["expand_if_not_available"] = requires_none
+        requires_types.setdefault(requires_none, []).append(struct(
+            msg = REQUIRES_NONE_ERR,
+            valid_types = ["option"],
+            after_option_unwrap = False,
+        ))
+    elif requires_true:
+        kwargs["expand_if_true"] = requires_true
+        requires_types.setdefault(requires_true, []).append(struct(
+            msg = REQUIRES_TRUE_ERR,
+            valid_types = ["bool"],
+            after_option_unwrap = True,
+        ))
+        unwrap_options.append(requires_true)
+    elif requires_false:
+        kwargs["expand_if_false"] = requires_false
+        requires_types.setdefault(requires_false, []).append(struct(
+            msg = REQUIRES_FALSE_ERR,
+            valid_types = ["bool"],
+            after_option_unwrap = True,
+        ))
+        unwrap_options.append(requires_false)
+    elif requires_equal:
+        if not requires_equal_value:
+            fail(REQUIRES_EQUAL_VALUE_ERR)
+        kwargs["expand_if_equal"] = variable_with_value(
+            name = requires_equal,
+            value = requires_equal_value,
+        )
+        unwrap_options.append(requires_equal)
+        requires_types.setdefault(requires_equal, []).append(struct(
+            msg = REQUIRES_EQUAL_ERR,
+            valid_types = ["string"],
+            after_option_unwrap = True,
+        ))
+
+    for arg in args:
+        if arg.format_type != "raw":
+            var_name = arg.value.name if arg.value != None else iterate_over
+            requires_types.setdefault(var_name, []).append(struct(
+                msg = FORMAT_ARGS_ERR,
+                valid_types = ["string", "file", "directory"],
+                after_option_unwrap = True,
+            ))
+
+    if args:
+        kwargs["flags"] = [
+            format_variable(arg, iterate_over = iterate_over, fail = fail)
+            for arg in args
+        ]
+
+    return NestedArgsInfo(
+        label = label,
+        nested = nested,
+        files = depset(transitive = transitive_files),
+        iterate_over = iterate_over,
+        unwrap_options = unwrap_options,
+        requires_types = requires_types,
+        legacy_flag_group = flag_group(**kwargs),
+    )
diff --git a/tests/rule_based_toolchain/nested_args/BUILD b/tests/rule_based_toolchain/nested_args/BUILD
new file mode 100644
index 0000000..30e75ed
--- /dev/null
+++ b/tests/rule_based_toolchain/nested_args/BUILD
@@ -0,0 +1,14 @@
+load("//cc/toolchains/impl:variables.bzl", "cc_variable", "types")
+load("//tests/rule_based_toolchain:analysis_test_suite.bzl", "analysis_test_suite")
+load(":nested_args_test.bzl", "TARGETS", "TESTS")
+
+cc_variable(
+    name = "foo",
+    type = types.string,
+)
+
+analysis_test_suite(
+    name = "test_suite",
+    targets = TARGETS,
+    tests = TESTS,
+)
diff --git a/tests/rule_based_toolchain/nested_args/nested_args_test.bzl b/tests/rule_based_toolchain/nested_args/nested_args_test.bzl
new file mode 100644
index 0000000..96a361c
--- /dev/null
+++ b/tests/rule_based_toolchain/nested_args/nested_args_test.bzl
@@ -0,0 +1,205 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the cc_args rule."""
+
+load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
+load("//cc/toolchains:cc_toolchain_info.bzl", "VariableInfo")
+load("//cc/toolchains:format.bzl", "format_arg")
+load(
+    "//cc/toolchains/impl:nested_args.bzl",
+    "FORMAT_ARGS_ERR",
+    "REQUIRES_EQUAL_ERR",
+    "REQUIRES_MUTUALLY_EXCLUSIVE_ERR",
+    "REQUIRES_NONE_ERR",
+    "format_string_indexes",
+    "format_variable",
+    "nested_args_provider",
+    "raw_string",
+)
+load("//tests/rule_based_toolchain:subjects.bzl", "result_fn_wrapper", "subjects")
+
+visibility("private")
+
+def _expect_that_nested(env, expr = None, **kwargs):
+    return env.expect.that_value(
+        expr = expr,
+        value = result_fn_wrapper(nested_args_provider)(
+            label = Label("//:args"),
+            **kwargs
+        ),
+        factory = subjects.result(subjects.NestedArgsInfo),
+    )
+
+def _expect_that_formatted(env, var, iterate_over = None, expr = None):
+    return env.expect.that_value(
+        result_fn_wrapper(format_variable)(var, iterate_over),
+        factory = subjects.result(subjects.str),
+        expr = expr or "format_variable(var=%r, iterate_over=%r" % (var, iterate_over),
+    )
+
+def _expect_that_format_string_indexes(env, var, expr = None):
+    return env.expect.that_value(
+        result_fn_wrapper(format_string_indexes)(var),
+        factory = subjects.result(subjects.collection),
+        expr = expr or "format_string_indexes(%r)" % var,
+    )
+
+def _format_string_indexes_test(env, _):
+    _expect_that_format_string_indexes(env, "foo").ok().contains_exactly([])
+    _expect_that_format_string_indexes(env, "%%").ok().contains_exactly([])
+    _expect_that_format_string_indexes(env, "%").err().equals(
+        '% should always either of the form %s, or escaped with %%. Instead, got "%"',
+    )
+    _expect_that_format_string_indexes(env, "%a").err().equals(
+        '% should always either of the form %s, or escaped with %%. Instead, got "%a"',
+    )
+    _expect_that_format_string_indexes(env, "%s").ok().contains_exactly([0])
+    _expect_that_format_string_indexes(env, "%%%s%s").ok().contains_exactly([2, 4])
+    _expect_that_format_string_indexes(env, "%%{").ok().contains_exactly([])
+    _expect_that_format_string_indexes(env, "%%s").ok().contains_exactly([])
+    _expect_that_format_string_indexes(env, "%{foo}").err().equals(
+        'Using the old mechanism for variables, %{variable}, but we instead use format_arg("--foo=%s", "//cc/toolchains/variables:<variable>"). Got "%{foo}"',
+    )
+
+def _formats_raw_strings_test(env, _):
+    _expect_that_formatted(
+        env,
+        raw_string("foo"),
+    ).ok().equals("foo")
+    _expect_that_formatted(
+        env,
+        raw_string("%s"),
+    ).err().contains("Can't use %s with a raw string. Either escape it with %%s or use format_arg")
+
+def _formats_variables_test(env, targets):
+    _expect_that_formatted(
+        env,
+        format_arg("ab %s cd", targets.foo[VariableInfo]),
+    ).ok().equals("ab %{foo} cd")
+
+    _expect_that_formatted(
+        env,
+        format_arg("foo", targets.foo[VariableInfo]),
+    ).err().equals('format_arg requires a "%s" in the format string, but got "foo"')
+    _expect_that_formatted(
+        env,
+        format_arg("%s%s", targets.foo[VariableInfo]),
+    ).err().equals('Only one %s can be used in a format string, but got "%s%s"')
+
+    _expect_that_formatted(
+        env,
+        format_arg("%s"),
+        iterate_over = "foo",
+    ).ok().equals("%{foo}")
+    _expect_that_formatted(
+        env,
+        format_arg("%s"),
+    ).err().contains("format_arg requires either a variable to format, or iterate_over must be provided")
+
+def _iterate_over_test(env, _):
+    inner = _expect_that_nested(
+        env,
+        args = [raw_string("--foo")],
+    ).ok().actual
+    env.expect.that_str(inner.legacy_flag_group).equals(flag_group(flags = ["--foo"]))
+
+    nested = _expect_that_nested(
+        env,
+        nested = [inner],
+        iterate_over = "my_list",
+    ).ok()
+    nested.iterate_over().some().equals("my_list")
+    nested.legacy_flag_group().equals(flag_group(
+        iterate_over = "my_list",
+        flag_groups = [inner.legacy_flag_group],
+    ))
+    nested.requires_types().contains_exactly({})
+
+def _requires_types_test(env, targets):
+    _expect_that_nested(
+        env,
+        requires_not_none = "abc",
+        requires_none = "def",
+        args = [raw_string("--foo")],
+        expr = "mutually_exclusive",
+    ).err().equals(REQUIRES_MUTUALLY_EXCLUSIVE_ERR)
+
+    _expect_that_nested(
+        env,
+        requires_none = "var",
+        args = [raw_string("--foo")],
+        expr = "requires_none",
+    ).ok().requires_types().contains_exactly(
+        {"var": [struct(
+            msg = REQUIRES_NONE_ERR,
+            valid_types = ["option"],
+            after_option_unwrap = False,
+        )]},
+    )
+
+    _expect_that_nested(
+        env,
+        args = [raw_string("foo %s baz")],
+        expr = "no_variable",
+    ).err().contains("Can't use %s with a raw string")
+
+    _expect_that_nested(
+        env,
+        args = [format_arg("foo %s baz", targets.foo[VariableInfo])],
+        expr = "type_validation",
+    ).ok().requires_types().contains_exactly(
+        {"foo": [struct(
+            msg = FORMAT_ARGS_ERR,
+            valid_types = ["string", "file", "directory"],
+            after_option_unwrap = True,
+        )]},
+    )
+
+    nested = _expect_that_nested(
+        env,
+        requires_equal = "foo",
+        requires_equal_value = "value",
+        args = [format_arg("--foo=%s", targets.foo[VariableInfo])],
+        expr = "type_and_requires_equal_validation",
+    ).ok()
+    nested.requires_types().contains_exactly(
+        {"foo": [
+            struct(
+                msg = REQUIRES_EQUAL_ERR,
+                valid_types = ["string"],
+                after_option_unwrap = True,
+            ),
+            struct(
+                msg = FORMAT_ARGS_ERR,
+                valid_types = ["string", "file", "directory"],
+                after_option_unwrap = True,
+            ),
+        ]},
+    )
+    nested.legacy_flag_group().equals(flag_group(
+        expand_if_equal = variable_with_value(name = "foo", value = "value"),
+        flags = ["--foo=%{foo}"],
+    ))
+
+TARGETS = [
+    ":foo",
+]
+
+TESTS = {
+    "format_string_indexes_test": _format_string_indexes_test,
+    "formats_raw_strings_test": _formats_raw_strings_test,
+    "formats_variables_test": _formats_variables_test,
+    "iterate_over_test": _iterate_over_test,
+    "requires_types_test": _requires_types_test,
+}
diff --git a/tests/rule_based_toolchain/subjects.bzl b/tests/rule_based_toolchain/subjects.bzl
index 628996e..f42d5d7 100644
--- a/tests/rule_based_toolchain/subjects.bzl
+++ b/tests/rule_based_toolchain/subjects.bzl
@@ -112,6 +112,7 @@
     iterate_over = optional_subject(_subjects.str),
     legacy_flag_group = unknown_subject,
     requires_types = _subjects.dict,
+    unwrap_options = _subjects.collection,
 )
 
 # buildifier: disable=name-conventions