Support for function references.
Earlier CLs added support for function *pointers*. This CL adds support
for function *references.
PiperOrigin-RevId: 432488140
diff --git a/rs_bindings_from_cc/importer.cc b/rs_bindings_from_cc/importer.cc
index 0490ad7..3096532 100644
--- a/rs_bindings_from_cc/importer.cc
+++ b/rs_bindings_from_cc/importer.cc
@@ -948,21 +948,22 @@
if (auto maybe_mapped_type = MapKnownCcTypeToRsType(type_string);
maybe_mapped_type.has_value()) {
type = MappedType::Simple(std::string(*maybe_mapped_type), type_string);
- } else if (const auto* pointer_type =
- qual_type->getAs<clang::PointerType>()) {
+ } else if (qual_type->isPointerType() || qual_type->isLValueReferenceType()) {
+ clang::QualType pointee_type = qual_type->getPointeeType();
+ std::optional<LifetimeId> lifetime;
+ if (lifetimes.has_value()) {
+ CHECK(!lifetimes->empty());
+ lifetime = LifetimeId(lifetimes->back().Id());
+ lifetimes->pop_back();
+ }
if (const auto* func_type =
- pointer_type->getPointeeType()->getAs<clang::FunctionProtoType>()) {
- std::optional<LifetimeId> lifetime;
- if (lifetimes.has_value()) {
- CHECK(!lifetimes->empty());
- if (lifetimes->back() != devtools_rust::Lifetime::Static()) {
- return absl::UnimplementedError(
- absl::StrCat("Function pointers with non-'static lifetimes are "
- "not supported: ",
- type_string));
- }
- lifetime = LifetimeId(lifetimes->back().Id());
- lifetimes->pop_back();
+ pointee_type->getAs<clang::FunctionProtoType>()) {
+ if (lifetime.has_value() &&
+ lifetime->value() != devtools_rust::Lifetime::Static().Id()) {
+ return absl::UnimplementedError(
+ absl::StrCat("Function pointers with non-'static lifetimes are "
+ "not supported: ",
+ type_string));
}
do {
clang::StringRef cc_call_conv =
@@ -981,34 +982,26 @@
param_types.push_back(*param_type_status);
}
- type = MappedType::FuncPtr(cc_call_conv, *rs_abi, lifetime,
- *return_type, param_types);
+ if (qual_type->isPointerType()) {
+ type = MappedType::FuncPtr(cc_call_conv, *rs_abi, lifetime,
+ *return_type, param_types);
+ } else {
+ DCHECK(qual_type->isLValueReferenceType());
+ type = MappedType::FuncRef(cc_call_conv, *rs_abi, lifetime,
+ *return_type, param_types);
+ }
} while (false);
} else {
- std::optional<LifetimeId> lifetime;
- if (lifetimes.has_value()) {
- CHECK(!lifetimes->empty());
- lifetime = LifetimeId(lifetimes->back().Id());
- lifetimes->pop_back();
+ auto mapped_pointee_type = ConvertType(pointee_type, lifetimes);
+ if (mapped_pointee_type.ok()) {
+ if (qual_type->isPointerType()) {
+ type =
+ MappedType::PointerTo(*mapped_pointee_type, lifetime, nullable);
+ } else {
+ DCHECK(qual_type->isLValueReferenceType());
+ type = MappedType::LValueReferenceTo(*mapped_pointee_type, lifetime);
+ }
}
- auto pointee_type =
- ConvertType(pointer_type->getPointeeType(), lifetimes);
- if (pointee_type.ok()) {
- type = MappedType::PointerTo(*pointee_type, lifetime, nullable);
- }
- }
- } else if (const auto* lvalue_ref_type =
- qual_type->getAs<clang::LValueReferenceType>()) {
- std::optional<LifetimeId> lifetime;
- if (lifetimes.has_value()) {
- CHECK(!lifetimes->empty());
- lifetime = LifetimeId(lifetimes->back().Id());
- lifetimes->pop_back();
- }
- auto pointee_type =
- ConvertType(lvalue_ref_type->getPointeeType(), lifetimes);
- if (pointee_type.ok()) {
- type = MappedType::LValueReferenceTo(*pointee_type, lifetime);
}
} else if (const auto* builtin_type =
// Use getAsAdjusted instead of getAs so we don't desugar
diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc
index 900afc0..ec60826 100644
--- a/rs_bindings_from_cc/ir.cc
+++ b/rs_bindings_from_cc/ir.cc
@@ -128,6 +128,26 @@
std::optional<LifetimeId> lifetime,
MappedType return_type,
std::vector<MappedType> param_types) {
+ MappedType result = FuncRef(cc_call_conv, rs_abi, lifetime,
+ std::move(return_type), std::move(param_types));
+
+ DCHECK_EQ(result.cc_type.name, internal::kCcLValueRef);
+ result.cc_type.name = std::string(internal::kCcPtr);
+
+ RsType rs_func_ptr_type = std::move(result.rs_type);
+ DCHECK_EQ(rs_func_ptr_type.name.substr(0, internal::kRustFuncPtr.length()),
+ internal::kRustFuncPtr);
+ result.rs_type =
+ RsType{.name = "Option", .type_args = {std::move(rs_func_ptr_type)}};
+
+ return result;
+}
+
+MappedType MappedType::FuncRef(absl::string_view cc_call_conv,
+ absl::string_view rs_abi,
+ std::optional<LifetimeId> lifetime,
+ MappedType return_type,
+ std::vector<MappedType> param_types) {
std::vector<MappedType> type_args = std::move(param_types);
type_args.push_back(std::move(return_type));
@@ -144,7 +164,7 @@
.name = absl::StrCat(internal::kCcFuncValue, " ", cc_call_conv),
.type_args = std::move(cc_type_args),
};
- CcType cc_func_ptr_type = CcType{.name = std::string(internal::kCcPtr),
+ CcType cc_func_ref_type = CcType{.name = std::string(internal::kCcLValueRef),
.type_args = {cc_func_value_type}};
// Rust cannot express a function *value* type, only function pointer types.
@@ -155,15 +175,9 @@
if (lifetime.has_value())
rs_func_ptr_type.lifetime_args.push_back(*std::move(lifetime));
- // `fn() -> ()` is a *non-nullable* pointer in Rust. Since function pointers
- // in C++ *can* be null, we need to wrap Rust's function pointer in
- // `Option<...>`.
- RsType rs_option_type =
- RsType{.name = "Option", .type_args = {rs_func_ptr_type}};
-
return MappedType{
- .rs_type = std::move(rs_option_type),
- .cc_type = std::move(cc_func_ptr_type),
+ .rs_type = std::move(rs_func_ptr_type),
+ .cc_type = std::move(cc_func_ref_type),
};
}
diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h
index 16c6f93..dd83002 100644
--- a/rs_bindings_from_cc/ir.h
+++ b/rs_bindings_from_cc/ir.h
@@ -209,6 +209,11 @@
std::optional<LifetimeId> lifetime,
MappedType return_type,
std::vector<MappedType> param_types);
+ static MappedType FuncRef(absl::string_view cc_call_conv,
+ absl::string_view rs_abi,
+ std::optional<LifetimeId> lifetime,
+ MappedType return_type,
+ std::vector<MappedType> param_types);
bool IsVoid() const { return rs_type.name == "()"; }
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index 8517c8c..291803c 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -2270,6 +2270,22 @@
}
#[test]
+ fn test_func_ref() -> Result<()> {
+ let ir = ir_from_cc(r#" int (&get_ref_to_func())(float, double); "#)?;
+ let rs_api = generate_rs_api(&ir)?;
+ assert_rs_matches!(
+ rs_api,
+ quote! {
+ #[inline(always)]
+ pub fn get_ref_to_func() -> extern "C" fn (f32, f64) -> i32 {
+ unsafe { crate::detail::__rust_thunk___Z15get_ref_to_funcv() }
+ }
+ }
+ );
+ Ok(())
+ }
+
+ #[test]
fn test_func_ptr_with_non_static_lifetime() -> Result<()> {
let ir = ir_from_cc(
r#"
diff --git a/rs_bindings_from_cc/test/golden/lifetimes_rs_api.rs b/rs_bindings_from_cc/test/golden/lifetimes_rs_api.rs
index c4f7917..4b440b5 100644
--- a/rs_bindings_from_cc/test/golden/lifetimes_rs_api.rs
+++ b/rs_bindings_from_cc/test/golden/lifetimes_rs_api.rs
@@ -23,17 +23,17 @@
unsafe { crate::detail::__rust_thunk___Z18AddHookWithTypedefPFvvE(hook) }
}
-// rs_bindings_from_cc/test/golden/lifetimes.h;l=9
-// Error while generating bindings for item 'AddAnotherHook':
-// Parameter #0 is not supported: Unsupported type 'void (&)(void)'
+#[inline(always)]
+pub fn AddAnotherHook(__param_0: extern "C" fn()) {
+ unsafe { crate::detail::__rust_thunk___Z14AddAnotherHookRFvvE(__param_0) }
+}
-// rs_bindings_from_cc/test/golden/lifetimes.h;l=11
-// Error while generating bindings for item 'FunctionReference':
-// Unsupported type 'void (&)(void)'
+pub type FunctionReference = extern "C" fn();
-// rs_bindings_from_cc/test/golden/lifetimes.h;l=12
-// Error while generating bindings for item 'AddAnotherHookWithTypedef':
-// Parameter #0 is not supported: Unsupported type 'FunctionReference'
+#[inline(always)]
+pub fn AddAnotherHookWithTypedef(hook: extern "C" fn()) {
+ unsafe { crate::detail::__rust_thunk___Z25AddAnotherHookWithTypedefRFvvE(hook) }
+}
#[inline(always)]
pub unsafe fn ConsumeArray(pair: *mut i32) {
@@ -59,6 +59,10 @@
pub(crate) fn __rust_thunk___Z7AddHookPFvvE(__param_0: Option<extern "C" fn()>);
#[link_name = "_Z18AddHookWithTypedefPFvvE"]
pub(crate) fn __rust_thunk___Z18AddHookWithTypedefPFvvE(hook: Option<extern "C" fn()>);
+ #[link_name = "_Z14AddAnotherHookRFvvE"]
+ pub(crate) fn __rust_thunk___Z14AddAnotherHookRFvvE(__param_0: extern "C" fn());
+ #[link_name = "_Z25AddAnotherHookWithTypedefRFvvE"]
+ pub(crate) fn __rust_thunk___Z25AddAnotherHookWithTypedefRFvvE(hook: extern "C" fn());
#[link_name = "_Z12ConsumeArrayPi"]
pub(crate) fn __rust_thunk___Z12ConsumeArrayPi(pair: *mut i32);
#[link_name = "_Z23ConsumeArrayWithTypedefPi"]