Emit forward declarations if it helps avoid reordering C++ definitions.

PiperOrigin-RevId: 500400594
diff --git a/cc_bindings_from_rs/bindings.rs b/cc_bindings_from_rs/bindings.rs
index 4254cb7..d4575ca 100644
--- a/cc_bindings_from_rs/bindings.rs
+++ b/cc_bindings_from_rs/bindings.rs
@@ -111,19 +111,21 @@
     /// Note that in this particular example the *definition* of `S` does
     /// *not* need to appear earlier (and therefore `defs` will *not*
     /// contain `LocalDefId` corresponding to `S`).
-    // TODO(b/260729464): Implement forward declarations support.
-    _fwd_decls: (),
+    fwd_decls: HashSet<LocalDefId>,
 }
 
 impl CcPrerequisites {
     #[cfg(test)]
     fn is_empty(&self) -> bool {
-        self.includes.is_empty() && self.defs.is_empty()
+        let &Self { ref includes, ref defs, ref fwd_decls } = self;
+        includes.is_empty() && defs.is_empty() && fwd_decls.is_empty()
     }
 }
 
 impl AddAssign for CcPrerequisites {
-    fn add_assign(&mut self, mut rhs: Self) {
+    fn add_assign(&mut self, rhs: Self) {
+        let Self { mut includes, defs, fwd_decls } = rhs;
+
         // `BTreeSet::append` is used because it _seems_ to be more efficient than
         // calling `extend`.  This is because `extend` takes an iterator
         // (processing each `rhs` include one-at-a-time) while `append` steals
@@ -131,9 +133,10 @@
         // speculative, since the (expected / guessed) performance difference is
         // not documented at
         // https://doc.rust-lang.org/std/collections/struct.BTreeSet.html#method.append
-        self.includes.append(&mut rhs.includes);
+        self.includes.append(&mut includes);
 
-        self.defs.extend(rhs.defs);
+        self.defs.extend(defs);
+        self.fwd_decls.extend(fwd_decls);
     }
 }
 
@@ -358,21 +361,21 @@
                 Mutability::Mut => quote!{},
                 Mutability::Not => quote!{ const },
             };
-            let CcSnippet{ tokens, prereqs } = format_ty_for_cc(tcx, *ty)
+            let CcSnippet{ tokens, mut prereqs } = format_ty_for_cc(tcx, *ty)
                 .with_context(|| format!(
                         "Failed to format the pointee of the pointer type `{ty}`"))?;
+            prereqs.fwd_decls.extend(std::mem::take(&mut prereqs.defs));
             CcSnippet {
-                // TODO(b/260729464): Move `prereqs.defs` to `prereqs.fwd_decls`.
                 prereqs,
                 tokens: quote!{ #const_qualifier #tokens * },
             }
         },
 
         // TODO(b/260268230, b/260729464): When recursively processing nested types (e.g. an
-        // element type of an Array, a pointee type of a RawPtr, a referent of a Ref or Slice, a
-        // parameter type of an FnPtr, etc), one should also 1) propagate `CcPrerequisites::defs`,
-        // 2) cover `CcPrerequisites::defs` in `test_format_ty_for_cc...`.  For ptr/ref/slice it
-        // might be also desirable to separately track forward-declaration prerequisites.
+        // element type of an Array, a referent of a Ref or Slice, a parameter type of an FnPtr,
+        // etc), one should also 1) propagate `CcPrerequisites::defs`, 2) cover
+        // `CcPrerequisites::defs` in `test_format_ty_for_cc...`.  For ref/slice it might also be
+        // desirable to separately track forward-declaration prerequisites.
         | ty::TyKind::Array(..)
         | ty::TyKind::Slice(..)
         | ty::TyKind::Ref(..)
@@ -676,16 +679,26 @@
 ///
 /// ```
 /// quote! {
-///     #header {
+///     #keyword #alignment #cc_name final {
 ///         #core
-///         #other_parts  // (e.g. struct fields)
+///         #other_parts  // (e.g. struct fields, methods, etc.)
 ///     }
 /// }
 /// ```
+///
+/// `keyword` and `cc_name` are stored separately, to support formatting them
+/// as a forward declaration - e.g. `struct SomeStruct;`.
 struct AdtCoreBindings {
-    /// `header` of the C++ declaration of the ADT.
-    /// Example: `struct alignas(4) SomeStruct final`
-    header: TokenStream,
+    /// C++ tag - e.g. `struct`, `class`, `enum`, or `union`.  This isn't always
+    /// a direct mapping from Rust (e.g. a Rust `enum` might end up being
+    /// represented as an opaque C++ `struct`).
+    keyword: TokenStream,
+
+    /// Alignment declaration - e.g. `alignas(4)`.
+    alignment: TokenStream,
+
+    /// C++ translation of the ADT identifier - e.g. `SomeStruct`.
+    cc_name: TokenStream,
 
     /// `core` contains declarations of
     /// - the default constructor
@@ -755,7 +768,6 @@
         Literal::u64_unsuffixed(size)
     };
 
-    let header = quote! { struct alignas(#alignment) #cc_name final };
     let core = quote! {
         public:
             // TODO(b/258249980): If the wrapped type implements the `Default` trait, then we
@@ -824,7 +836,14 @@
             const _: () = assert!(::std::mem::align_of::<#rs_type>() == #alignment);
         }
     };
-    Ok(AdtCoreBindings { header, core, cc_assertions, rs_assertions })
+    Ok(AdtCoreBindings {
+        keyword: quote! { struct },
+        alignment: quote! { alignas(#alignment) },
+        cc_name,
+        core,
+        cc_assertions,
+        rs_assertions,
+    })
 }
 
 /// Formats the data (e.g. the fields) of an algebraic data type (an ADT - a
@@ -832,6 +851,9 @@
 ///
 /// This function needs to remain infallible (see the doc comment of
 /// `format_adt_core`).
+///
+/// Will panic if `def_id` doesn't identify an ADT that can be successfully
+/// handled by `format_adt_core`.
 fn format_adt_data(tcx: TyCtxt, def_id: LocalDefId) -> TokenStream {
     let def_id = def_id.to_def_id(); // LocalDefId -> DefId conversion.
     let size = get_adt_layout(tcx, def_id)
@@ -849,18 +871,16 @@
 /// Formats an algebraic data type (an ADT - a struct, an enum, or a union)
 /// represented by `def_id`.
 ///
-/// Will panic if `def_id`
-/// - is invalid
-/// - doesn't identify an ADT,
+/// Will panic if `local_def_id` is invalid or doesn't identify an ADT.
 fn format_adt(tcx: TyCtxt, local_def_id: LocalDefId) -> Result<MixedSnippet> {
-    let AdtCoreBindings { header, core, cc_assertions, rs_assertions: rs} =
+    let AdtCoreBindings { keyword, alignment, cc_name, core, cc_assertions, rs_assertions: rs} =
         format_adt_core(tcx, local_def_id.to_def_id())?;
 
     let data = format_adt_data(tcx, local_def_id);
     let doc_comment = format_doc_comment(tcx, local_def_id);
     let cc = CcSnippet::new(quote! {
         __NEWLINE__ #doc_comment
-        #header {
+        #keyword #alignment #cc_name final {
             #core
             #data
         };
@@ -870,6 +890,24 @@
     Ok(MixedSnippet { cc, rs })
 }
 
+/// Formats the forward declaration of an algebraic data type (an ADT - a
+/// struct, an enum, or a union), returning something like
+/// `quote!{ struct SomeStruct; }`.
+///
+/// Will panic if `def_id` doesn't identify an ADT that can be successfully
+/// handled by `format_adt_core`.
+fn format_fwd_decl(tcx: TyCtxt, def_id: LocalDefId) -> TokenStream {
+    let def_id = def_id.to_def_id(); // LocalDefId -> DefId conversion.
+
+    // `format_fwd_decl` should only be called for items from
+    // `CcPrerequisites::fwd_decls` and `fwd_decls` should only contain ADTs
+    // that `format_adt_core` succeeds for.
+    let AdtCoreBindings { keyword, cc_name, .. } = format_adt_core(tcx, def_id)
+        .expect("`format_fwd_decl` should only be called if `format_adt_core` succeeded");
+
+    quote! { #keyword #cc_name; }
+}
+
 /// Formats the doc comment associated with the item identified by
 /// `local_def_id`.
 /// If there is no associated doc comment, an empty `TokenStream` is returned.
@@ -970,25 +1008,48 @@
 
     // Destructure/rebuild `bindings` (in the same order as `ordered_ids`) into
     // `includes`, and into separate C++ snippets and Rust snippets.
-    let mut includes = BTreeSet::new();
-    let mut ordered_cc = Vec::new();
-    let mut rs_body = quote! {};
-    for local_def_id in ordered_ids.into_iter() {
-        let mod_path = FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path;
-        let MixedSnippet {
-            rs: inner_rs,
-            cc: CcSnippet {
-                tokens: cc_tokens,
-                prereqs: CcPrerequisites {
-                    includes: mut inner_includes,
-                    .. // `defs` have already been utilized by `toposort` above
+    let (includes, ordered_cc, rs_body) = {
+        let mut already_declared = HashSet::new();
+        let mut fwd_decls = HashSet::new();
+        let mut includes = BTreeSet::new();
+        let mut ordered_cc = Vec::new();
+        let mut rs_body = quote! {};
+        for local_def_id in ordered_ids.into_iter() {
+            let mod_path = FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path;
+            let MixedSnippet {
+                rs: inner_rs,
+                cc: CcSnippet {
+                    tokens: cc_tokens,
+                    prereqs: CcPrerequisites {
+                        includes: mut inner_includes,
+                        fwd_decls: inner_fwd_decls,
+                        .. // `defs` have already been utilized by `toposort` above
+                    }
                 }
-            }
-        } = bindings.remove(&local_def_id).unwrap();
-        includes.append(&mut inner_includes);
-        ordered_cc.push((mod_path, cc_tokens));
-        rs_body.extend(inner_rs);
-    }
+            } = bindings.remove(&local_def_id).unwrap();
+
+            fwd_decls.extend(inner_fwd_decls.difference(&already_declared).copied());
+            already_declared.insert(local_def_id);
+            already_declared.extend(inner_fwd_decls.into_iter());
+
+            includes.append(&mut inner_includes);
+            ordered_cc.push((mod_path, cc_tokens));
+            rs_body.extend(inner_rs);
+        }
+
+        // Prepend `fwd_decls` (in the original source order) to `ordered_cc`.
+        let fwd_decls = fwd_decls
+            .into_iter()
+            .sorted_by_key(|def_id| tcx.def_span(*def_id))
+            .map(|local_def_id| {
+                let mod_path = FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path;
+                (mod_path, format_fwd_decl(tcx, local_def_id))
+            })
+            .collect_vec();
+        let ordered_cc = fwd_decls.into_iter().chain(ordered_cc.into_iter()).collect_vec();
+
+        (includes, ordered_cc, rs_body)
+    };
 
     // Generate top-level elements of the C++ header file.
     let h_body = {
@@ -1238,6 +1299,141 @@
         });
     }
 
+    /// Tests that a forward declaration is present when it is required to
+    /// preserve the original source order.
+    #[test]
+    fn test_generated_bindings_prereq_fwd_decls_required() {
+        let test_src = r#"
+                // To preserve original API order we need to forward declare S.
+                pub fn f(_: *const S) {}
+                pub struct S(bool);
+            "#;
+        test_generated_bindings(test_src, |bindings| {
+            let bindings = bindings.unwrap();
+            assert_cc_matches!(
+                bindings.h_body,
+                quote! {
+                    namespace rust_out {
+                        ...
+                        // Verifying the presence of this forward declaration
+                        // is the essence of this test.
+                        struct S;
+                        ...
+                        inline void f(const ::rust_out::S* __param_0) { ... }
+                        ...
+                        struct alignas(...) S final { ... }
+                        ...
+                    }  // namespace rust_out
+                }
+            );
+        });
+    }
+
+    /// This test verifies that a forward declaration for a given ADT is only
+    /// emitted once (and not once for every API item that requires the
+    /// forward declaration as a prerequisite).
+    #[test]
+    fn test_generated_bindings_prereq_fwd_decls_no_duplication() {
+        let test_src = r#"
+                // All three functions below require a forward declaration of S.
+                pub fn f1(_: *const S) {}
+                pub fn f2(_: *const S) {}
+                pub fn f3(_: *const S) {}
+
+                pub struct S(bool);
+
+                // This function also includes S in its CcPrerequisites::fwd_decls
+                // (although here it is not required, because the definition of S
+                // is already available above).
+                pub fn f4(_: *const S) {}
+            "#;
+        test_generated_bindings(test_src, |bindings| {
+            let bindings = bindings.unwrap().h_body.to_string();
+
+            // Only a single forward declaration is expected.
+            assert_eq!(1, bindings.matches("struct S ;").count(), "bindings = {bindings}");
+        });
+    }
+
+    /// This test verifies that forward declarations are emitted in a
+    /// deterministic order. The particular order doesn't matter _that_
+    /// much, but it definitely shouldn't change every time
+    /// `cc_bindings_from_rs` is invoked again.  The current order preserves
+    /// the original source order of the Rust API items.
+    #[test]
+    fn test_generated_bindings_prereq_fwd_decls_deterministic_order() {
+        let test_src = r#"
+                // To try to mix things up, the bindings for the functions below
+                // will *ask* for forward declarations in a different order:
+                // * Different from the order in which the forward declarations
+                //   are expected to be *emitted* (the original source order).
+                // * Different from alphabetical order.
+                pub fn f1(_: *const b::S3) {}
+                pub fn f2(_: *const a::S2) {}
+                pub fn f3(_: *const a::S1) {}
+
+                pub mod a {
+                    pub struct S1(bool);
+                    pub struct S2(bool);
+                }
+
+                pub mod b {
+                    pub struct S3(bool);
+                }
+            "#;
+        test_generated_bindings(test_src, |bindings| {
+            let bindings = bindings.unwrap();
+            assert_cc_matches!(
+                bindings.h_body,
+                quote! {
+                    namespace rust_out {
+                        ...
+                        // Verifying that we get the same order in each test
+                        // run is the essence of this test.
+                        namespace a {
+                        struct S1;
+                        struct S2;
+                        }
+                        namespace b {
+                        struct S3;
+                        }
+                        ...
+                        inline void f1 ...
+                        inline void f2 ...
+                        inline void f3 ...
+
+                        namespace a { ...
+                        struct alignas(...) S1 final { ... } ...
+                        struct alignas(...) S2 final { ... } ...
+                        } ...
+                        namespace b { ...
+                        struct alignas(...) S3 final { ... } ...
+                        } ...
+                    }  // namespace rust_out
+                }
+            );
+        });
+    }
+
+    /// This test verifies that forward declarations are not emitted if they are
+    /// not needed (e.g. if bindings for the given `struct` or other ADT have
+    /// already been defined earlier).  In particular, we don't want to emit
+    /// forward declarations for *all* `structs` (regardless if they are
+    /// needed or not).
+    #[test]
+    fn test_generated_bindings_prereq_fwd_decls_not_needed() {
+        let test_src = r#"
+                pub struct S(bool);
+
+                // S is already defined above - no need for forward declaration in C++.
+                pub fn f(_s: *const S) {}
+            "#;
+        test_generated_bindings(test_src, |bindings| {
+            let bindings = bindings.unwrap();
+            assert_cc_not_matches!(bindings.h_body, quote! { struct S; });
+        });
+    }
+
     #[test]
     fn test_generated_bindings_module_basics() {
         let test_src = r#"
@@ -2628,30 +2824,35 @@
     #[test]
     fn test_format_ty_for_cc_successes() {
         let testcases = [
-            // ( <Rust type>, (<expected C++ type>, <expected #include>, <expected prereq def>) )
-            ("bool", ("bool", "", "")),
-            ("f32", ("float", "", "")),
-            ("f64", ("double", "", "")),
-            ("i8", ("std::int8_t", "cstdint", "")),
-            ("i16", ("std::int16_t", "cstdint", "")),
-            ("i32", ("std::int32_t", "cstdint", "")),
-            ("i64", ("std::int64_t", "cstdint", "")),
-            ("isize", ("std::intptr_t", "cstdint", "")),
-            ("u8", ("std::uint8_t", "cstdint", "")),
-            ("u16", ("std::uint16_t", "cstdint", "")),
-            ("u32", ("std::uint32_t", "cstdint", "")),
-            ("u64", ("std::uint64_t", "cstdint", "")),
-            ("usize", ("std::uintptr_t", "cstdint", "")),
-            ("char", ("std::uint32_t", "cstdint", "")),
-            ("SomeStruct", ("::rust_out::SomeStruct", "", "SomeStruct")),
-            ("SomeEnum", ("::rust_out::SomeEnum", "", "SomeEnum")),
-            ("SomeUnion", ("::rust_out::SomeUnion", "", "SomeUnion")),
-            ("*const i32", ("const std::int32_t*", "cstdint", "")),
-            ("*mut i32", ("std::int32_t*", "cstdint", "")),
-            // TODO(b/260729464): Move `prereqs.defs` expectation to `prereqs.fwd_decls`.
-            ("*mut SomeStruct", ("::rust_out::SomeStruct*", "", "SomeStruct")),
+            // ( <Rust type>, (<expected C++ type>,
+            //                 <expected #include>,
+            //                 <expected prereq def>,
+            //                 <expected prereq fwd decl>) )
+            ("bool", ("bool", "", "", "")),
+            ("f32", ("float", "", "", "")),
+            ("f64", ("double", "", "", "")),
+            ("i8", ("std::int8_t", "cstdint", "", "")),
+            ("i16", ("std::int16_t", "cstdint", "", "")),
+            ("i32", ("std::int32_t", "cstdint", "", "")),
+            ("i64", ("std::int64_t", "cstdint", "", "")),
+            ("isize", ("std::intptr_t", "cstdint", "", "")),
+            ("u8", ("std::uint8_t", "cstdint", "", "")),
+            ("u16", ("std::uint16_t", "cstdint", "", "")),
+            ("u32", ("std::uint32_t", "cstdint", "", "")),
+            ("u64", ("std::uint64_t", "cstdint", "", "")),
+            ("usize", ("std::uintptr_t", "cstdint", "", "")),
+            ("char", ("std::uint32_t", "cstdint", "", "")),
+            ("SomeStruct", ("::rust_out::SomeStruct", "", "SomeStruct", "")),
+            ("SomeEnum", ("::rust_out::SomeEnum", "", "SomeEnum", "")),
+            ("SomeUnion", ("::rust_out::SomeUnion", "", "SomeUnion", "")),
+            ("*const i32", ("const std::int32_t*", "cstdint", "", "")),
+            ("*mut i32", ("std::int32_t*", "cstdint", "", "")),
+            // `SomeStruct` is a `fwd_decls` prerequisite (not `defs` prerequisite):
+            ("*mut SomeStruct", ("::rust_out::SomeStruct*", "", "", "SomeStruct")),
+            // Testing propagation of deeper/nested `fwd_decls`:
+            ("*mut *mut SomeStruct", (":: rust_out :: SomeStruct * *", "", "", "SomeStruct")),
             // Extra parens/sugar are expected to be ignored:
-            ("(bool)", ("bool", "", "")),
+            ("(bool)", ("bool", "", "", "")),
         ];
         let preamble = quote! {
             #![allow(unused_parens)]
@@ -2672,24 +2873,27 @@
         test_ty(
             &testcases,
             preamble,
-            |desc, tcx, ty, (expected_tokens, expected_include, expected_prereq_def)| {
-                let (actual_tokens, actual_includes, actual_prereq_defs) = {
+            |desc, tcx, ty,
+             (expected_tokens, expected_include, expected_prereq_def, expected_prereq_fwd_decl)| {
+                let (actual_tokens, actual_prereqs) = {
                     let s = format_ty_for_cc(tcx, ty).unwrap();
-                    (s.tokens.to_string(), s.prereqs.includes, s.prereqs.defs)
-            };
+                    (s.tokens.to_string(), s.prereqs)
+                };
+                let (actual_includes, actual_prereq_defs, actual_prereq_fwd_decls) =
+                    (actual_prereqs.includes, actual_prereqs.defs, actual_prereqs.fwd_decls);
 
                 let expected_tokens = expected_tokens.parse::<TokenStream>().unwrap().to_string();
                 assert_eq!(actual_tokens, expected_tokens, "{desc}");
 
-            if expected_include.is_empty() {
-                assert!(actual_includes.is_empty());
-            } else {
-                let expected_header = format_cc_ident(expected_include).unwrap();
-                assert_cc_matches!(
-                    format_cc_includes(&actual_includes),
-                    quote! { include <#expected_header> }
-                );
-            }
+                if expected_include.is_empty() {
+                    assert!(actual_includes.is_empty());
+                } else {
+                    let expected_header = format_cc_ident(expected_include).unwrap();
+                    assert_cc_matches!(
+                        format_cc_includes(&actual_includes),
+                        quote! { include <#expected_header> }
+                    );
+                }
 
                 if expected_prereq_def.is_empty() {
                     assert!(actual_prereq_defs.is_empty());
@@ -2698,6 +2902,15 @@
                     assert_eq!(1, actual_prereq_defs.len());
                     assert_eq!(expected_def_id, actual_prereq_defs.into_iter().next().unwrap());
                 }
+
+                if expected_prereq_fwd_decl.is_empty() {
+                    assert!(actual_prereq_fwd_decls.is_empty());
+                } else {
+                    let expected_def_id = find_def_id_by_name(tcx, expected_prereq_fwd_decl);
+                    assert_eq!(1, actual_prereq_fwd_decls.len());
+                    assert_eq!(expected_def_id,
+                               actual_prereq_fwd_decls.into_iter().next().unwrap());
+                }
             },
         );
     }
diff --git a/cc_bindings_from_rs/test/structs/structs.rs b/cc_bindings_from_rs/test/structs/structs.rs
index 8102b65..a1b4c2a 100644
--- a/cc_bindings_from_rs/test/structs/structs.rs
+++ b/cc_bindings_from_rs/test/structs/structs.rs
@@ -93,3 +93,45 @@
         }
     }
 }
+
+/// This module provides coverage for emitting forward declarations.  In
+/// particular, if we assume that the C++ bindings are emitted in the same order
+/// as the Rust items below, then `S1` needs to be forward-declared (because
+/// `get_int_from_s1` is *before* `S1`).
+///
+/// TODO(b/260725687): Using a cycle below should avoid the assumption above
+/// about preserving the same order (because a cycle can't be
+/// toposorted/reordered).  OTOH forming a cycle seems to depend on supporting
+/// bindings for additional language features - either static methods
+/// (b/260725279):
+/// ```
+///     // Cycle!:
+///     pub struct S1(i32);
+///     pub struct S2(i32);
+///     impl S1 {
+///         pub fn get_int_from_s2(s2: *const S2) { ... }
+///     }
+///     impl S2 {
+///         pub fn get_int_from_s1(s1: *const S1) { ... }
+///     }
+/// ```
+/// or fields (b/258233850):
+/// ```
+///     // Cycle!:
+///     pub struct S1 {
+///         ptr_to_s2: *const S2,
+///     }
+///     pub struct S2 {
+///         ptr_to_s1: *const S1,
+///     }
+/// ```
+pub mod fwd_decls {
+    pub fn get_int_from_s1(s1: *const S1) -> i32 {
+        #![allow(clippy::not_unsafe_ptr_arg_deref)]
+        unsafe { (*s1).0 }
+    }
+    pub fn create_s1() -> S1 {
+        S1(456)
+    }
+    pub struct S1(i32);
+}
diff --git a/cc_bindings_from_rs/test/structs/structs_test.cc b/cc_bindings_from_rs/test/structs/structs_test.cc
index 9bca5d7..b952fea 100644
--- a/cc_bindings_from_rs/test/structs/structs_test.cc
+++ b/cc_bindings_from_rs/test/structs/structs_test.cc
@@ -32,5 +32,11 @@
   EXPECT_EQ(123, m1::get_int_from_s2(std::move(s2)));
 }
 
+TEST(StructsTest, FwdDecls) {
+  namespace fwd_decls = structs::fwd_decls;
+  fwd_decls::S1 s1 = fwd_decls::create_s1();
+  EXPECT_EQ(456, fwd_decls::get_int_from_s1(&s1));
+}
+
 }  // namespace
 }  // namespace crubit