Emit forward declarations if it helps avoid reordering C++ definitions.
PiperOrigin-RevId: 500400594
diff --git a/cc_bindings_from_rs/bindings.rs b/cc_bindings_from_rs/bindings.rs
index 4254cb7..d4575ca 100644
--- a/cc_bindings_from_rs/bindings.rs
+++ b/cc_bindings_from_rs/bindings.rs
@@ -111,19 +111,21 @@
/// Note that in this particular example the *definition* of `S` does
/// *not* need to appear earlier (and therefore `defs` will *not*
/// contain `LocalDefId` corresponding to `S`).
- // TODO(b/260729464): Implement forward declarations support.
- _fwd_decls: (),
+ fwd_decls: HashSet<LocalDefId>,
}
impl CcPrerequisites {
#[cfg(test)]
fn is_empty(&self) -> bool {
- self.includes.is_empty() && self.defs.is_empty()
+ let &Self { ref includes, ref defs, ref fwd_decls } = self;
+ includes.is_empty() && defs.is_empty() && fwd_decls.is_empty()
}
}
impl AddAssign for CcPrerequisites {
- fn add_assign(&mut self, mut rhs: Self) {
+ fn add_assign(&mut self, rhs: Self) {
+ let Self { mut includes, defs, fwd_decls } = rhs;
+
// `BTreeSet::append` is used because it _seems_ to be more efficient than
// calling `extend`. This is because `extend` takes an iterator
// (processing each `rhs` include one-at-a-time) while `append` steals
@@ -131,9 +133,10 @@
// speculative, since the (expected / guessed) performance difference is
// not documented at
// https://doc.rust-lang.org/std/collections/struct.BTreeSet.html#method.append
- self.includes.append(&mut rhs.includes);
+ self.includes.append(&mut includes);
- self.defs.extend(rhs.defs);
+ self.defs.extend(defs);
+ self.fwd_decls.extend(fwd_decls);
}
}
@@ -358,21 +361,21 @@
Mutability::Mut => quote!{},
Mutability::Not => quote!{ const },
};
- let CcSnippet{ tokens, prereqs } = format_ty_for_cc(tcx, *ty)
+ let CcSnippet{ tokens, mut prereqs } = format_ty_for_cc(tcx, *ty)
.with_context(|| format!(
"Failed to format the pointee of the pointer type `{ty}`"))?;
+ prereqs.fwd_decls.extend(std::mem::take(&mut prereqs.defs));
CcSnippet {
- // TODO(b/260729464): Move `prereqs.defs` to `prereqs.fwd_decls`.
prereqs,
tokens: quote!{ #const_qualifier #tokens * },
}
},
// TODO(b/260268230, b/260729464): When recursively processing nested types (e.g. an
- // element type of an Array, a pointee type of a RawPtr, a referent of a Ref or Slice, a
- // parameter type of an FnPtr, etc), one should also 1) propagate `CcPrerequisites::defs`,
- // 2) cover `CcPrerequisites::defs` in `test_format_ty_for_cc...`. For ptr/ref/slice it
- // might be also desirable to separately track forward-declaration prerequisites.
+ // element type of an Array, a referent of a Ref or Slice, a parameter type of an FnPtr,
+ // etc), one should also 1) propagate `CcPrerequisites::defs`, 2) cover
+ // `CcPrerequisites::defs` in `test_format_ty_for_cc...`. For ptr/ref/slice it might be
+ // also desirable to separately track forward-declaration prerequisites.
| ty::TyKind::Array(..)
| ty::TyKind::Slice(..)
| ty::TyKind::Ref(..)
@@ -676,16 +679,26 @@
///
/// ```
/// quote! {
-/// #header {
+/// #keyword #alignment #name final {
/// #core
-/// #other_parts // (e.g. struct fields)
+/// #other_parts // (e.g. struct fields, methods, etc.)
/// }
/// }
/// ```
+///
+/// `keyword`, `name` are stored separately, to support formatting them as a
+/// forward declaration - e.g. `struct SomeStruct`.
struct AdtCoreBindings {
- /// `header` of the C++ declaration of the ADT.
- /// Example: `struct alignas(4) SomeStruct final`
- header: TokenStream,
+ /// C++ tag - e.g. `struct`, `class`, `enum`, or `union`. This isn't always
+ /// a direct mapping from Rust (e.g. a Rust `enum` might end up being
+ /// represented as an opaque C++ `struct`).
+ keyword: TokenStream,
+
+ /// Alignment declaration - e.g. `alignas(4)`.
+ alignment: TokenStream,
+
+ /// C++ translation of the ADT identifier - e.g. `SomeStruct`.
+ cc_name: TokenStream,
/// `core` contains declarations of
/// - the default constructor
@@ -755,7 +768,6 @@
Literal::u64_unsuffixed(size)
};
- let header = quote! { struct alignas(#alignment) #cc_name final };
let core = quote! {
public:
// TODO(b/258249980): If the wrapped type implements the `Default` trait, then we
@@ -824,7 +836,14 @@
const _: () = assert!(::std::mem::align_of::<#rs_type>() == #alignment);
}
};
- Ok(AdtCoreBindings { header, core, cc_assertions, rs_assertions })
+ Ok(AdtCoreBindings {
+ keyword: quote! { struct },
+ alignment: quote! { alignas(#alignment) },
+ cc_name,
+ core,
+ cc_assertions,
+ rs_assertions,
+ })
}
/// Formats the data (e.g. the fields) of an algebraic data type (an ADT - a
@@ -832,6 +851,9 @@
///
/// This function needs to remain infallible (see the doc comment of
/// `format_adt_core`).
+///
+/// Will panic if `def_id` doesn't identify an ADT that can be successfully
+/// handled by `format_adt_core`.
fn format_adt_data(tcx: TyCtxt, def_id: LocalDefId) -> TokenStream {
let def_id = def_id.to_def_id(); // LocalDefId -> DefId conversion.
let size = get_adt_layout(tcx, def_id)
@@ -849,18 +871,16 @@
/// Formats an algebraic data type (an ADT - a struct, an enum, or a union)
/// represented by `def_id`.
///
-/// Will panic if `def_id`
-/// - is invalid
-/// - doesn't identify an ADT,
+/// Will panic if `def_id` is invalid or doesn't identify an ADT.
fn format_adt(tcx: TyCtxt, local_def_id: LocalDefId) -> Result<MixedSnippet> {
- let AdtCoreBindings { header, core, cc_assertions, rs_assertions: rs} =
+ let AdtCoreBindings { keyword, alignment, cc_name, core, cc_assertions, rs_assertions: rs} =
format_adt_core(tcx, local_def_id.to_def_id())?;
let data = format_adt_data(tcx, local_def_id);
let doc_comment = format_doc_comment(tcx, local_def_id);
let cc = CcSnippet::new(quote! {
__NEWLINE__ #doc_comment
- #header {
+ #keyword #alignment #cc_name final {
#core
#data
};
@@ -870,6 +890,24 @@
Ok(MixedSnippet { cc, rs })
}
+/// Formats the forward declaration of an algebraic data type (an ADT - a
+/// struct, an enum, or a union), returning something like
+/// `quote!{ struct SomeStruct; }`.
+///
+/// Will panic if `def_id` doesn't identify an ADT that can be successfully
+/// handled by `format_adt_core`.
+fn format_fwd_decl(tcx: TyCtxt, def_id: LocalDefId) -> TokenStream {
+ let def_id = def_id.to_def_id(); // LocalDefId -> DefId conversion.
+
+ // `format_fwd_decl` should only be called for items from
+ // `CcPrerequisites::fwd_decls` and `fwd_decls` should only contain ADTs
+ // that `format_adt_core` succeeds for.
+ let AdtCoreBindings { keyword, cc_name, .. } = format_adt_core(tcx, def_id)
+ .expect("`format_fwd_decl` should only be called if `format_adt_core` succeeded");
+
+ quote! { #keyword #cc_name; }
+}
+
/// Formats the doc comment associated with the item identified by
/// `local_def_id`.
/// If there is no associated doc comment, an empty `TokenStream` is returned.
@@ -970,25 +1008,48 @@
// Destructure/rebuild `bindings` (in the same order as `ordered_ids`) into
// `includes`, and into separate C++ snippets and Rust snippets.
- let mut includes = BTreeSet::new();
- let mut ordered_cc = Vec::new();
- let mut rs_body = quote! {};
- for local_def_id in ordered_ids.into_iter() {
- let mod_path = FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path;
- let MixedSnippet {
- rs: inner_rs,
- cc: CcSnippet {
- tokens: cc_tokens,
- prereqs: CcPrerequisites {
- includes: mut inner_includes,
- .. // `defs` have already been utilized by `toposort` above
+ let (includes, ordered_cc, rs_body) = {
+ let mut already_declared = HashSet::new();
+ let mut fwd_decls = HashSet::new();
+ let mut includes = BTreeSet::new();
+ let mut ordered_cc = Vec::new();
+ let mut rs_body = quote! {};
+ for local_def_id in ordered_ids.into_iter() {
+ let mod_path = FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path;
+ let MixedSnippet {
+ rs: inner_rs,
+ cc: CcSnippet {
+ tokens: cc_tokens,
+ prereqs: CcPrerequisites {
+ includes: mut inner_includes,
+ fwd_decls: inner_fwd_decls,
+ .. // `defs` have already been utilized by `toposort` above
+ }
}
- }
- } = bindings.remove(&local_def_id).unwrap();
- includes.append(&mut inner_includes);
- ordered_cc.push((mod_path, cc_tokens));
- rs_body.extend(inner_rs);
- }
+ } = bindings.remove(&local_def_id).unwrap();
+
+ fwd_decls.extend(inner_fwd_decls.difference(&already_declared).copied());
+ already_declared.insert(local_def_id);
+ already_declared.extend(inner_fwd_decls.into_iter());
+
+ includes.append(&mut inner_includes);
+ ordered_cc.push((mod_path, cc_tokens));
+ rs_body.extend(inner_rs);
+ }
+
+ // Prepend `fwd_decls` (in the original source order) to `ordered_cc`.
+ let fwd_decls = fwd_decls
+ .into_iter()
+ .sorted_by_key(|def_id| tcx.def_span(*def_id))
+ .map(|local_def_id| {
+ let mod_path = FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path;
+ (mod_path, format_fwd_decl(tcx, local_def_id))
+ })
+ .collect_vec();
+ let ordered_cc = fwd_decls.into_iter().chain(ordered_cc.into_iter()).collect_vec();
+
+ (includes, ordered_cc, rs_body)
+ };
// Generate top-level elements of the C++ header file.
let h_body = {
@@ -1238,6 +1299,141 @@
});
}
+ /// Tests that a forward declaration is present when it is required to
+ /// preserve the original source order.
+ #[test]
+ fn test_generated_bindings_prereq_fwd_decls_required() {
+ let test_src = r#"
+ // To preserve original API order we need to forward declare S.
+ pub fn f(_: *const S) {}
+ pub struct S(bool);
+ "#;
+ test_generated_bindings(test_src, |bindings| {
+ let bindings = bindings.unwrap();
+ assert_cc_matches!(
+ bindings.h_body,
+ quote! {
+ namespace rust_out {
+ ...
+ // Verifing the presence of this forward declaration
+ // it the essence of this test.
+ struct S;
+ ...
+ inline void f(const ::rust_out::S* __param_0) { ... }
+ ...
+ struct alignas(...) S final { ... }
+ ...
+ } // namespace rust_out
+ }
+ );
+ });
+ }
+
+ /// This test verifies that a forward declaration for a given ADT is only
+ /// emitted once (and not once for every API item that requires the
+ /// forward declaration as a prerequisite).
+ #[test]
+ fn test_generated_bindings_prereq_fwd_decls_no_duplication() {
+ let test_src = r#"
+ // All three functions below require a forward declaration of S.
+ pub fn f1(_: *const S) {}
+ pub fn f2(_: *const S) {}
+ pub fn f3(_: *const S) {}
+
+ pub struct S(bool);
+
+ // This function also includes S in its CcPrerequisites::fwd_decls
+ // (although here it is not required, because the definition of S
+ // is already available above).
+ pub fn f4(_: *const S) {}
+ "#;
+ test_generated_bindings(test_src, |bindings| {
+ let bindings = bindings.unwrap().h_body.to_string();
+
+ // Only a single forward declaration is expected.
+ assert_eq!(1, bindings.matches("struct S ;").count(), "bindings = {bindings}");
+ });
+ }
+
+ /// This test verifies that forward declarations are emitted in a
+ /// deterministic order. The particular order doesn't matter _that_
+ /// much, but it definitely shouldn't change every time
+ /// `cc_bindings_from_rs` is invoked again. The current order preserves
+ /// the original source order of the Rust API items.
+ #[test]
+ fn test_generated_bindings_prereq_fwd_decls_deterministic_order() {
+ let test_src = r#"
+ // To try to mix things up, the bindings for the functions below
+ // will *ask* for forward declarations in a different order:
+ // * Different from the order in which the forward declarations
+ // are expected to be *emitted* (the original source order).
+ // * Different from alphabetical order.
+ pub fn f1(_: *const b::S3) {}
+ pub fn f2(_: *const a::S2) {}
+ pub fn f3(_: *const a::S1) {}
+
+ pub mod a {
+ pub struct S1(bool);
+ pub struct S2(bool);
+ }
+
+ pub mod b {
+ pub struct S3(bool);
+ }
+ "#;
+ test_generated_bindings(test_src, |bindings| {
+ let bindings = bindings.unwrap();
+ assert_cc_matches!(
+ bindings.h_body,
+ quote! {
+ namespace rust_out {
+ ...
+ // Verifying that we get the same order in each test
+ // run is the essence of this test.
+ namespace a {
+ struct S1;
+ struct S2;
+ }
+ namespace b {
+ struct S3;
+ }
+ ...
+ inline void f1 ...
+ inline void f2 ...
+ inline void f3 ...
+
+ namespace a { ...
+ struct alignas(...) S1 final { ... } ...
+ struct alignas(...) S2 final { ... } ...
+ } ...
+ namespace b { ...
+ struct alignas(...) S3 final { ... } ...
+ } ...
+ } // namespace rust_out
+ }
+ );
+ });
+ }
+
+ /// This test verifies that forward declarations are not emitted if they are
+ /// not needed (e.g. if bindings the given `struct` or other ADT have
+ /// already been defined earlier). In particular, we don't want to emit
+ /// forward declarations for *all* `structs` (regardless if they are
+ /// needed or not).
+ #[test]
+ fn test_generated_bindings_prereq_fwd_decls_not_needed() {
+ let test_src = r#"
+ pub struct S(bool);
+
+ // S is already defined above - no need for forward declaration in C++.
+ pub fn f(_s: *const S) {}
+ "#;
+ test_generated_bindings(test_src, |bindings| {
+ let bindings = bindings.unwrap();
+ assert_cc_not_matches!(bindings.h_body, quote! { struct S; });
+ });
+ }
+
#[test]
fn test_generated_bindings_module_basics() {
let test_src = r#"
@@ -2628,30 +2824,35 @@
#[test]
fn test_format_ty_for_cc_successes() {
let testcases = [
- // ( <Rust type>, (<expected C++ type>, <expected #include>, <expected prereq def>) )
- ("bool", ("bool", "", "")),
- ("f32", ("float", "", "")),
- ("f64", ("double", "", "")),
- ("i8", ("std::int8_t", "cstdint", "")),
- ("i16", ("std::int16_t", "cstdint", "")),
- ("i32", ("std::int32_t", "cstdint", "")),
- ("i64", ("std::int64_t", "cstdint", "")),
- ("isize", ("std::intptr_t", "cstdint", "")),
- ("u8", ("std::uint8_t", "cstdint", "")),
- ("u16", ("std::uint16_t", "cstdint", "")),
- ("u32", ("std::uint32_t", "cstdint", "")),
- ("u64", ("std::uint64_t", "cstdint", "")),
- ("usize", ("std::uintptr_t", "cstdint", "")),
- ("char", ("std::uint32_t", "cstdint", "")),
- ("SomeStruct", ("::rust_out::SomeStruct", "", "SomeStruct")),
- ("SomeEnum", ("::rust_out::SomeEnum", "", "SomeEnum")),
- ("SomeUnion", ("::rust_out::SomeUnion", "", "SomeUnion")),
- ("*const i32", ("const std::int32_t*", "cstdint", "")),
- ("*mut i32", ("std::int32_t*", "cstdint", "")),
- // TODO(b/260729464): Move `prereqs.defs` expectation to `prereqs.fwd_decls`.
- ("*mut SomeStruct", ("::rust_out::SomeStruct*", "", "SomeStruct")),
+ // ( <Rust type>, (<expected C++ type>,
+ // <expected #include>,
+ // <expected prereq def>,
+ // <expected prereq fwd decl>) )
+ ("bool", ("bool", "", "", "")),
+ ("f32", ("float", "", "", "")),
+ ("f64", ("double", "", "", "")),
+ ("i8", ("std::int8_t", "cstdint", "", "")),
+ ("i16", ("std::int16_t", "cstdint", "", "")),
+ ("i32", ("std::int32_t", "cstdint", "", "")),
+ ("i64", ("std::int64_t", "cstdint", "", "")),
+ ("isize", ("std::intptr_t", "cstdint", "", "")),
+ ("u8", ("std::uint8_t", "cstdint", "", "")),
+ ("u16", ("std::uint16_t", "cstdint", "", "")),
+ ("u32", ("std::uint32_t", "cstdint", "", "")),
+ ("u64", ("std::uint64_t", "cstdint", "", "")),
+ ("usize", ("std::uintptr_t", "cstdint", "", "")),
+ ("char", ("std::uint32_t", "cstdint", "", "")),
+ ("SomeStruct", ("::rust_out::SomeStruct", "", "SomeStruct", "")),
+ ("SomeEnum", ("::rust_out::SomeEnum", "", "SomeEnum", "")),
+ ("SomeUnion", ("::rust_out::SomeUnion", "", "SomeUnion", "")),
+ ("*const i32", ("const std::int32_t*", "cstdint", "", "")),
+ ("*mut i32", ("std::int32_t*", "cstdint", "", "")),
+ // `SomeStruct` is a `fwd_decls` prerequisite (not `defs` prerequisite):
+ ("*mut SomeStruct", ("::rust_out::SomeStruct*", "", "", "SomeStruct")),
+ // Testing propagation of deeper/nested `fwd_decls`:
+ ("*mut *mut SomeStruct", (":: rust_out :: SomeStruct * *", "", "", "SomeStruct")),
// Extra parens/sugar are expected to be ignored:
- ("(bool)", ("bool", "", "")),
+ ("(bool)", ("bool", "", "", "")),
];
let preamble = quote! {
#![allow(unused_parens)]
@@ -2672,24 +2873,27 @@
test_ty(
&testcases,
preamble,
- |desc, tcx, ty, (expected_tokens, expected_include, expected_prereq_def)| {
- let (actual_tokens, actual_includes, actual_prereq_defs) = {
+ |desc, tcx, ty,
+ (expected_tokens, expected_include, expected_prereq_def, expected_prereq_fwd_decl)| {
+ let (actual_tokens, actual_prereqs) = {
let s = format_ty_for_cc(tcx, ty).unwrap();
- (s.tokens.to_string(), s.prereqs.includes, s.prereqs.defs)
- };
+ (s.tokens.to_string(), s.prereqs)
+ };
+ let (actual_includes, actual_prereq_defs, actual_prereq_fwd_decls) =
+ (actual_prereqs.includes, actual_prereqs.defs, actual_prereqs.fwd_decls);
let expected_tokens = expected_tokens.parse::<TokenStream>().unwrap().to_string();
assert_eq!(actual_tokens, expected_tokens, "{desc}");
- if expected_include.is_empty() {
- assert!(actual_includes.is_empty());
- } else {
- let expected_header = format_cc_ident(expected_include).unwrap();
- assert_cc_matches!(
- format_cc_includes(&actual_includes),
- quote! { include <#expected_header> }
- );
- }
+ if expected_include.is_empty() {
+ assert!(actual_includes.is_empty());
+ } else {
+ let expected_header = format_cc_ident(expected_include).unwrap();
+ assert_cc_matches!(
+ format_cc_includes(&actual_includes),
+ quote! { include <#expected_header> }
+ );
+ }
if expected_prereq_def.is_empty() {
assert!(actual_prereq_defs.is_empty());
@@ -2698,6 +2902,15 @@
assert_eq!(1, actual_prereq_defs.len());
assert_eq!(expected_def_id, actual_prereq_defs.into_iter().next().unwrap());
}
+
+ if expected_prereq_fwd_decl.is_empty() {
+ assert!(actual_prereq_fwd_decls.is_empty());
+ } else {
+ let expected_def_id = find_def_id_by_name(tcx, expected_prereq_fwd_decl);
+ assert_eq!(1, actual_prereq_fwd_decls.len());
+ assert_eq!(expected_def_id,
+ actual_prereq_fwd_decls.into_iter().next().unwrap());
+ }
},
);
}
diff --git a/cc_bindings_from_rs/test/structs/structs.rs b/cc_bindings_from_rs/test/structs/structs.rs
index 8102b65..a1b4c2a 100644
--- a/cc_bindings_from_rs/test/structs/structs.rs
+++ b/cc_bindings_from_rs/test/structs/structs.rs
@@ -93,3 +93,45 @@
}
}
}
+
+/// This module provides coverage for emitting forward declarations. In
+/// particular, if we assume that the C++ bindings are emitted in the same order
+/// as the Rust items below, then `S1` needs to be forward-declared (because
+/// `get_int_from_s1` is *before* `S1`).
+///
+/// TODO(b/260725687): Using a cycle below should avoid the assumption above
+/// about preserving the same order (because a cycle can't be
+/// toposorted/reordered). OTOH forming a cycle seems to depend on supporting
+/// bindings for additional language features - either static methods
+/// (b/260725279):
+/// ```
+/// // Cycle!:
+/// pub struct S1(i32);
+/// pub struct S2(i32);
+/// impl S1 {
+/// pub fn get_int_from_s2(s2: *const S2) { ... }
+/// }
+/// impl S2 {
+/// pub fn get_int_from_s1(s1: *const S1) { ... }
+/// }
+/// ```
+/// or fields (b/258233850):
+/// ```
+/// // Cycle!:
+/// pub struct S1 {
+/// ptr_to_s2: *const S2,
+/// }
+/// pub struct S2 {
+/// ptr_to_s1: *const S1,
+/// }
+/// ```
+pub mod fwd_decls {
+ pub fn get_int_from_s1(s1: *const S1) -> i32 {
+ #![allow(clippy::not_unsafe_ptr_arg_deref)]
+ unsafe { (*s1).0 }
+ }
+ pub fn create_s1() -> S1 {
+ S1(456)
+ }
+ pub struct S1(i32);
+}
diff --git a/cc_bindings_from_rs/test/structs/structs_test.cc b/cc_bindings_from_rs/test/structs/structs_test.cc
index 9bca5d7..b952fea 100644
--- a/cc_bindings_from_rs/test/structs/structs_test.cc
+++ b/cc_bindings_from_rs/test/structs/structs_test.cc
@@ -32,5 +32,11 @@
EXPECT_EQ(123, m1::get_int_from_s2(std::move(s2)));
}
+TEST(StructsTest, FwdDecls) {
+ namespace fwd_decls = structs::fwd_decls;
+ fwd_decls::S1 s1 = fwd_decls::create_s1();
+ EXPECT_EQ(456, fwd_decls::get_int_from_s1(&s1));
+}
+
} // namespace
} // namespace crubit