Migrate repeated expand_if_(all|none)_available into nested flag_groups

Crosstool in Starlark assumes these fields as singular, not collections. This cl updates the migration script to prepare crosstool in proto for this.

https://github.com/bazelbuild/bazel/issues/6861
https://github.com/bazelbuild/bazel/issues/5883

RELNOTES: None.
PiperOrigin-RevId: 233041028
diff --git a/tools/migration/legacy_fields_migration_lib.py b/tools/migration/legacy_fields_migration_lib.py
index a7f64be..a1da3cf 100644
--- a/tools/migration/legacy_fields_migration_lib.py
+++ b/tools/migration/legacy_fields_migration_lib.py
@@ -87,6 +87,8 @@
   for toolchain in crosstool.toolchain:
     _ = [_migrate_expand_if_all_available(f) for f in toolchain.feature]
     _ = [_migrate_expand_if_all_available(ac) for ac in toolchain.action_config]
+    _ = [_migrate_repeated_expands(f) for f in toolchain.feature]
+    _ = [_migrate_repeated_expands(ac) for ac in toolchain.action_config]
 
     if (toolchain.dynamic_library_linker_flag or
         _contains_dynamic_flags(toolchain)) and not _contains_feature(
@@ -437,7 +439,7 @@
 
 
 def _migrate_expand_if_all_available(message):
-  """Move expand_if_all_available fields to flag_groups."""
+  """Move expand_if_all_available field to flag_groups."""
   for flag_set in message.flag_set:
     if flag_set.expand_if_all_available:
       for flag_group in flag_set.flag_group:
@@ -448,6 +450,45 @@
       flag_set.ClearField("expand_if_all_available")
 
 
+def _migrate_repeated_expands(message):
+  """Replace repeated legacy fields with nesting."""
+  todo_queue = []
+  for flag_set in message.flag_set:
+    todo_queue.extend(flag_set.flag_group)
+  while todo_queue:
+    flag_group = todo_queue.pop()
+    todo_queue.extend(flag_group.flag_group)
+    if len(flag_group.expand_if_all_available) <= 1 and len(
+        flag_group.expand_if_none_available) <= 1:
+      continue
+
+    current_children = flag_group.flag_group
+    current_flags = flag_group.flag
+    flag_group.ClearField("flag_group")
+    flag_group.ClearField("flag")
+
+    new_flag_group = flag_group.flag_group.add()
+    new_flag_group.flag_group.extend(current_children)
+    new_flag_group.flag.extend(current_flags)
+
+    if len(flag_group.expand_if_all_available) > 1:
+      expands_to_move = flag_group.expand_if_all_available[1:]
+      flag_group.expand_if_all_available[:] = [
+          flag_group.expand_if_all_available[0]
+      ]
+      new_flag_group.expand_if_all_available.extend(expands_to_move)
+
+    if len(flag_group.expand_if_none_available) > 1:
+      expands_to_move = flag_group.expand_if_none_available[1:]
+      flag_group.expand_if_none_available[:] = [
+          flag_group.expand_if_none_available[0]
+      ]
+      new_flag_group.expand_if_none_available.extend(expands_to_move)
+
+    todo_queue.append(new_flag_group)
+    todo_queue.append(flag_group)
+
+
 def _contains_dynamic_flags(toolchain):
   for lmf in toolchain.linking_mode_flags:
     mode = crosstool_config_pb2.LinkingMode.Name(lmf.mode)
diff --git a/tools/migration/legacy_fields_migration_lib_test.py b/tools/migration/legacy_fields_migration_lib_test.py
index 298b72e..d6603a3 100644
--- a/tools/migration/legacy_fields_migration_lib_test.py
+++ b/tools/migration/legacy_fields_migration_lib_test.py
@@ -1023,8 +1023,7 @@
             flag_group {
               flag: '%{foo}'
             }
-            flag_group {
-              expand_if_all_available: 'bar'
+            flag_group {              
               flag: 'bar'
             }
           }
@@ -1038,7 +1037,6 @@
               flag: '%{foo}'
             }
             flag_group {
-              expand_if_all_available: 'bar'
               flag: 'bar'
             }
           }
@@ -1056,7 +1054,7 @@
         .expand_if_all_available, ["foo"])
     self.assertEqual(
         output.action_config[0].flag_set[0].flag_group[1]
-        .expand_if_all_available, ["bar", "foo"])
+        .expand_if_all_available, ["foo"])
 
     self.assertEqual(output.feature[0].name, "something_else")
     self.assertEqual(len(output.feature[0].flag_set), 1)
@@ -1068,7 +1066,139 @@
         ["foo"])
     self.assertEqual(
         output.feature[0].flag_set[0].flag_group[1].expand_if_all_available,
-        ["bar", "foo"])
+        ["foo"])
+
+  def test_migrate_repeated_expand_if_all_available_from_flag_groups(self):
+    crosstool = make_crosstool("""
+          action_config {
+            action_name: 'something'
+            config_name: 'something'
+            flag_set {
+              flag_group {
+                expand_if_all_available: 'foo'
+                expand_if_all_available: 'bar'
+                flag: '%{foo}'
+              }
+              flag_group {
+                expand_if_none_available: 'foo'
+                expand_if_none_available: 'bar'
+                flag: 'bar'
+              }
+            }
+          }
+          feature {
+            name: 'something_else'
+            flag_set {
+              action: 'c-compile'
+              flag_group {
+                expand_if_all_available: 'foo'
+                expand_if_all_available: 'bar'
+                flag: '%{foo}'
+              }
+              flag_group {
+                expand_if_none_available: 'foo'
+                expand_if_none_available: 'bar'
+                flag: 'bar'
+              }
+            }
+          }
+          """)
+    migrate_legacy_fields(crosstool)
+    output = crosstool.toolchain[0]
+    self.assertEqual(output.action_config[0].action_name, "something")
+    self.assertEqual(len(output.action_config[0].flag_set), 1)
+    self.assertEqual(
+        len(output.action_config[0].flag_set[0].expand_if_all_available), 0)
+    self.assertEqual(len(output.action_config[0].flag_set[0].flag_group), 2)
+    self.assertEqual(
+        output.action_config[0].flag_set[0].flag_group[0]
+        .expand_if_all_available, ["foo"])
+    self.assertEqual(
+        output.action_config[0].flag_set[0].flag_group[0].flag_group[0]
+        .expand_if_all_available, ["bar"])
+    self.assertEqual(
+        output.action_config[0].flag_set[0].flag_group[1]
+        .expand_if_none_available, ["foo"])
+    self.assertEqual(
+        output.action_config[0].flag_set[0].flag_group[1].flag_group[0]
+        .expand_if_none_available, ["bar"])
+
+    self.assertEqual(output.feature[0].name, "something_else")
+    self.assertEqual(len(output.feature[0].flag_set), 1)
+    self.assertEqual(
+        len(output.feature[0].flag_set[0].expand_if_all_available), 0)
+    self.assertEqual(len(output.feature[0].flag_set[0].flag_group), 2)
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[0].expand_if_all_available,
+        ["foo"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[0].flag_group[0]
+        .expand_if_all_available, ["bar"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[1].expand_if_none_available,
+        ["foo"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[1].flag_group[0]
+        .expand_if_none_available, ["bar"])
+
+  def test_migrate_repeated_expands_from_nested_flag_groups(self):
+    crosstool = make_crosstool("""
+          feature {
+            name: 'something'
+            flag_set {
+              action: 'c-compile'
+              flag_group {
+                flag_group {
+                  expand_if_all_available: 'foo'
+                  expand_if_all_available: 'bar'
+                  flag: '%{foo}'
+                }
+              }
+              flag_group {
+                flag_group {
+                  expand_if_all_available: 'foo'
+                  expand_if_all_available: 'bar'
+                  expand_if_none_available: 'foo'
+                  expand_if_none_available: 'bar'
+                  flag: '%{foo}'
+                }
+              }
+            }
+          }
+          """)
+    migrate_legacy_fields(crosstool)
+    output = crosstool.toolchain[0]
+
+    self.assertEqual(output.feature[0].name, "something")
+    self.assertEqual(len(output.feature[0].flag_set[0].flag_group), 2)
+    self.assertEqual(
+        len(output.feature[0].flag_set[0].flag_group[0].expand_if_all_available
+           ), 0)
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[0].flag_group[0]
+        .expand_if_all_available, ["foo"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[0].flag_group[0].flag_group[0]
+        .expand_if_all_available, ["bar"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[0].flag_group[0].flag_group[0]
+        .flag, ["%{foo}"])
+
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[1].flag_group[0]
+        .expand_if_all_available, ["foo"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[1].flag_group[0]
+        .expand_if_none_available, ["foo"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[1].flag_group[0].flag_group[0]
+        .expand_if_none_available, ["bar"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[1].flag_group[0].flag_group[0]
+        .expand_if_all_available, ["bar"])
+    self.assertEqual(
+        output.feature[0].flag_set[0].flag_group[1].flag_group[0].flag_group[0]
+        .flag, ["%{foo}"])
 
 
 if __name__ == "__main__":