Support for function pointer types.
PiperOrigin-RevId: 429608696
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index a6c0d73..7c846de 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -231,13 +231,9 @@
let lifetime_to_name = HashMap::<LifetimeId, String>::from_iter(
func.lifetime_params.iter().map(|l| (l.id, l.name.clone())),
);
- let return_type_fragment = if func.return_type.rs_type.is_unit_type() {
- quote! {}
- } else {
- let return_type_name = format_rs_type(&func.return_type.rs_type, ir, &lifetime_to_name)
+ let return_type_fragment = RsTypeKind::new(&func.return_type.rs_type, ir)
+ .and_then(|t| t.format_as_return_type_fragment(ir, &lifetime_to_name))
.with_context(|| format!("Failed to format return type for {:?}", func))?;
- quote! { -> #return_type_name }
- };
let param_idents =
func.params.iter().map(|p| make_rs_ident(&p.identifier.identifier)).collect_vec();
@@ -1065,6 +1061,7 @@
enum RsTypeKind<'ir> {
Pointer { pointee: Box<RsTypeKind<'ir>>, mutability: Mutability },
Reference { referent: Box<RsTypeKind<'ir>>, mutability: Mutability, lifetime_id: LifetimeId },
+ FuncPtr { abi: &'ir str, return_type: Box<RsTypeKind<'ir>>, param_types: Vec<RsTypeKind<'ir>> },
Record(&'ir Record),
TypeAlias { type_alias: &'ir TypeAlias, underlying_type: Box<RsTypeKind<'ir>> },
Unit,
@@ -1132,7 +1129,21 @@
mutability: Mutability::Const,
lifetime_id: get_lifetime()?,
},
- name => RsTypeKind::Other { name, type_args: get_type_args()? },
+ name => {
+ let mut type_args = get_type_args()?;
+ match name.strip_prefix("#funcPtr ") {
+ None => RsTypeKind::Other { name, type_args },
+ Some(abi) => {
+ // TODO(b/217419782): Consider enforcing `'static` lifetime.
+ ensure!(!type_args.is_empty(), "No return type in fn type: {:?}", ty);
+ RsTypeKind::FuncPtr {
+ abi,
+ return_type: Box::new(type_args.remove(type_args.len() - 1)),
+ param_types: type_args,
+ }
+ },
+ }
+ },
},
};
Ok(result)
@@ -1155,6 +1166,14 @@
let nested_type = referent.format(ir, lifetime_to_name)?;
quote! {& #lifetime #mutability #nested_type}
}
+ RsTypeKind::FuncPtr { abi, return_type, param_types } => {
+ let return_frag = return_type.format_as_return_type_fragment(ir, lifetime_to_name)?;
+ let param_types = param_types
+ .iter()
+ .map(|t| t.format(ir, lifetime_to_name))
+ .collect::<Result<Vec<_>>>()?;
+ quote!{ extern #abi fn( #( #param_types ),* ) #return_frag }
+ },
RsTypeKind::Record(record) => rs_type_name_for_target_and_identifier(
&record.owning_target,
&record.identifier,
@@ -1180,6 +1199,20 @@
Ok(result)
}
+ /// Formats this type as the `-> T` suffix of a fn signature. `RsTypeKind::Unit`
+ /// yields an empty fragment, so `-> ()` is never emitted.
+ pub fn format_as_return_type_fragment(
+ &self,
+ ir: &IR,
+ lifetime_to_name: &HashMap<LifetimeId, String>,
+ ) -> Result<TokenStream> {
+ if let RsTypeKind::Unit = self {
+ return Ok(quote! {});
+ }
+ let formatted = self.format(ir, lifetime_to_name)?;
+ Ok(quote! { -> #formatted })
+ }
+
/// Formats this RsTypeKind as `&'a mut MaybeUninit<SomeStruct>`. This is
/// used to format `__this` parameter in a constructor thunk.
pub fn format_mut_ref_as_uninitialized(
@@ -1249,6 +1282,7 @@
match self {
RsTypeKind::Unit => true,
RsTypeKind::Pointer { .. } => true,
+ RsTypeKind::FuncPtr { .. } => true,
RsTypeKind::Reference { mutability: Mutability::Const, .. } => true,
RsTypeKind::Reference { mutability: Mutability::Mut, .. } => false,
RsTypeKind::Record(record) => should_derive_copy(record),
@@ -1334,6 +1368,10 @@
RsTypeKind::Pointer { pointee, .. } => self.todo.push(pointee),
RsTypeKind::Reference { referent, .. } => self.todo.push(referent),
RsTypeKind::TypeAlias { underlying_type: t, .. } => self.todo.push(t),
+ RsTypeKind::FuncPtr { return_type, param_types, .. } => {
+ self.todo.push(return_type);
+ self.todo.extend(param_types.iter().rev());
+ },
RsTypeKind::Other { type_args, .. } => self.todo.extend(type_args.iter().rev()),
};
Some(curr)
@@ -2044,6 +2082,117 @@
}
#[test]
+ fn test_func_ptr_where_params_are_primitive_types() -> Result<()> {
+ let ir = ir_from_cc(r#" int (*get_ptr_to_func())(float, double); "#)?;
+ let rs_api = generate_rs_api(&ir)?;
+ let rs_api_impl = generate_rs_api_impl(&ir)?;
+ assert_rs_matches!(
+ rs_api,
+ quote! {
+ #[inline(always)]
+ pub fn get_ptr_to_func() -> Option<extern "C" fn (f32, f64) -> i32> {
+ unsafe { crate::detail::__rust_thunk___Z15get_ptr_to_funcv() }
+ }
+ }
+ );
+ assert_rs_matches!(
+ rs_api,
+ quote! {
+ mod detail {
+ #[allow(unused_imports)]
+ use super::*;
+ extern "C" {
+ #[link_name = "_Z15get_ptr_to_funcv"]
+ pub(crate) fn __rust_thunk___Z15get_ptr_to_funcv()
+ -> Option<extern "C" fn(f32, f64) -> i32>;
+ }
+ }
+ }
+ );
+ // The extern decl links straight to the mangled symbol, so no C++ thunk is expected.
+ assert_cc_not_matches!(rs_api_impl, quote! { __rust_thunk___Z15get_ptr_to_funcv });
+
+ // TODO(b/217419782): Add another test for more exotic calling conventions /
+ // ABIs (only `extern "C"` is covered here).
+
+ // TODO(b/217419782): Add another test for pointer to a function that
+ // takes/returns non-trivially-movable types by value. See also
+ // <internal link>
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_func_ptr_with_non_static_lifetime() -> Result<()> {
+ let ir = ir_from_cc(
+ r#"
+ [[clang::annotate("lifetimes = -> a")]]
+ int (*get_ptr_to_func())(float, double); "#,
+ )?;
+ let rs_api = generate_rs_api(&ir)?;
+ // `//` comments inside `quote!` are discarded by the Rust lexer, so matching
+ // against them would check nothing. The expected error is kept here for
+ // reference only:
+ // Error while generating bindings for item 'get_ptr_to_func':
+ // Return type is not supported: Function pointers with non-'static lifetimes are not supported: int (*)(float, double)
+ // Verify instead that no binding for the function was generated.
+ assert_rs_not_matches!(rs_api, quote! { fn get_ptr_to_func });
+ Ok(())
+ }
+
+ #[test]
+ fn test_func_ptr_where_params_are_raw_ptrs() -> Result<()> {
+ let ir = ir_from_cc(r#" const int* (*get_ptr_to_func())(const int*); "#)?;
+ let rs_api = generate_rs_api(&ir)?;
+ let rs_api_impl = generate_rs_api_impl(&ir)?;
+ assert_rs_matches!(
+ rs_api,
+ quote! {
+ #[inline(always)]
+ pub fn get_ptr_to_func() -> Option<extern "C" fn (*const i32) -> *const i32> {
+ unsafe { crate::detail::__rust_thunk___Z15get_ptr_to_funcv() }
+ }
+ }
+ );
+ assert_rs_matches!(
+ rs_api,
+ quote! {
+ mod detail {
+ #[allow(unused_imports)]
+ use super::*;
+ extern "C" {
+ #[link_name = "_Z15get_ptr_to_funcv"]
+ pub(crate) fn __rust_thunk___Z15get_ptr_to_funcv()
+ -> Option<extern "C" fn(*const i32) -> *const i32>;
+ }
+ }
+ }
+ );
+ // The extern decl links straight to the mangled symbol, so no C++ thunk is expected.
+ assert_cc_not_matches!(rs_api_impl, quote! { __rust_thunk___Z15get_ptr_to_funcv });
+
+ // TODO(b/217419782): Add another test where params (and the return
+ // type) are references with lifetimes. Something like this:
+ // #pragma clang lifetime_elision
+ // const int& (*get_ptr_to_func())(const int&, const int&);
+ // 1) Investigate why this fails: the generated Rust uses raw pointers,
+ // which suggests that no lifetimes are present at the `importer.cc`
+ // level. Maybe lifetime elision doesn't support this scenario? Unclear
+ // how to explicitly apply [[clang::annotate("lifetimes = a, b -> a")]]
+ // to the _inner_ function.
+ // 2) Use 2 reference parameters, to check whether the problem
+ // of passing `lifetimes` by value would have been caught - see:
+ // cl/428079010/depot/rs_bindings_from_cc/
+ // importer.cc?version=s6#823
+
+ // TODO(b/217419782): Decide what to do if the C++ pointer is *not*
+ // annotated with a lifetime - emit `unsafe fn(...) -> ...` in that
+ // case?
+
+ Ok(())
+ }
+
+ #[test]
fn test_item_order() -> Result<()> {
let ir = ir_from_cc(
"int first_func();
@@ -3078,6 +3227,26 @@
}
#[test]
+ fn test_rs_type_kind_dfs_iter_ordering_for_func_ptr() {
+ // Input `fn(A, B) -> C`: DFS must yield the fn, then params in order, then C.
+ let f = {
+ let a = RsTypeKind::Other { name: "A", type_args: vec![] };
+ let b = RsTypeKind::Other { name: "B", type_args: vec![] };
+ let c = RsTypeKind::Other { name: "C", type_args: vec![] };
+ RsTypeKind::FuncPtr { abi: "blah", param_types: vec![a, b], return_type: Box::new(c) }
+ };
+ let dfs_names = f
+ .dfs_iter()
+ .map(|t| match t {
+ RsTypeKind::FuncPtr { .. } => "fn",
+ RsTypeKind::Other { name, .. } => *name,
+ _ => unreachable!("Only FuncPtr and Other kinds are used in this test"),
+ })
+ .collect_vec();
+ assert_eq!(vec!["fn", "A", "B", "C"], dfs_names);
+ }
+
+ #[test]
fn test_rs_type_kind_lifetimes() -> Result<()> {
let ir = ir_from_cc(
r#"