Convert `RsTypeKind` creation into a Salsa query.
This is one of the more important things I wanted to do with this rewrite -- after this CL, we will no longer recompute the RsTypeKind all over the place, but just look it up or create on first use. (It's possible/probable(?) that there's no particular performance benefit unless/until we change `RsType` to be interned/hashconsed, though -- the point is that this is at least possible to optimize from here, I guess.)
Here I opted for a small/fast way to make `RsType` suitable as a cheaply-copyable query key: rather than holding `RsType` in an `Rc`, I made it relatively cheap to clone. (It is now five words -- we can cut it down to 4 words if `ItemId` were converted to be a nonzero number.) Slightly marginal, but since it's the only parameter to the query it should be fine really.
In the future, we may want to instead use something like `ArcIntern<RsType>` to make it very efficient to use in Salsa.
PiperOrigin-RevId: 459456664
diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs
index 9b96a95..2e4ddbe 100644
--- a/rs_bindings_from_cc/ir.rs
+++ b/rs_bindings_from_cc/ir.rs
@@ -124,8 +124,8 @@
#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
pub struct RsType {
pub name: Option<String>,
- pub lifetime_args: Vec<LifetimeId>,
- pub type_args: Vec<RsType>,
+ pub lifetime_args: Rc<[LifetimeId]>,
+ pub type_args: Rc<[RsType]>,
pub decl_id: Option<ItemId>,
}
diff --git a/rs_bindings_from_cc/ir_from_cc_test.rs b/rs_bindings_from_cc/ir_from_cc_test.rs
index 1db9e05..6b1253c 100644
--- a/rs_bindings_from_cc/ir_from_cc_test.rs
+++ b/rs_bindings_from_cc/ir_from_cc_test.rs
@@ -2669,15 +2669,15 @@
assert_eq!(lifetime_params.iter().map(|p| p.name.as_ref()).collect_vec(), vec!["a", "b"]);
let a_id = lifetime_params[0].id;
let b_id = lifetime_params[1].id;
- assert_eq!(func.return_type.rs_type.lifetime_args, vec![a_id]);
+ assert_eq!(&*func.return_type.rs_type.lifetime_args, &[a_id]);
assert_eq!(func.params[0].identifier, ir_id("__this"));
assert_eq!(func.params[0].type_.rs_type.name, Some("&mut".to_string()));
- assert_eq!(func.params[0].type_.rs_type.lifetime_args, vec![a_id]);
+ assert_eq!(&*func.params[0].type_.rs_type.lifetime_args, &[a_id]);
assert_eq!(func.params[1].identifier, ir_id("i"));
assert_eq!(func.params[1].type_.rs_type.name, Some("&mut".to_string()));
- assert_eq!(func.params[1].type_.rs_type.lifetime_args, vec![b_id]);
+ assert_eq!(&*func.params[1].type_.rs_type.lifetime_args, &[b_id]);
}
fn verify_elided_lifetimes_in_default_constructor(ir: &IR) {
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index 22493f8..9ce0401 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -80,6 +80,8 @@
#[salsa::input]
fn ir(&self) -> Rc<IR>;
+ fn rs_type_kind(&self, rs_type: RsType) -> SalsaResult<RsTypeKind>;
+
fn generate_func(
&self,
func: Rc<Func>,
@@ -635,7 +637,7 @@
.params
.iter()
.map(|p| {
- RsTypeKind::new(&p.type_.rs_type, &ir).with_context(|| {
+ db.rs_type_kind(p.type_.rs_type.clone()).with_context(|| {
format!("Failed to process type of parameter {:?} on {:?}", p, func)
})
})
@@ -648,7 +650,8 @@
return Ok(None);
};
- let return_type_fragment = RsTypeKind::new(&func.return_type.rs_type, &ir)
+ let return_type_fragment = db
+ .rs_type_kind(func.return_type.rs_type.clone())
.with_context(|| format!("Failed to format return type for {:?}", func))?
.format_as_return_type_fragment();
let param_idents =
@@ -1004,8 +1007,8 @@
///
/// For non-Copy union fields, failing to use `ManuallyDrop<T>` would
/// additionally cause a compile-time error until https://github.com/rust-lang/rust/issues/55149 is stabilized.
-fn needs_manually_drop(ty: &ir::RsType, ir: &IR) -> Result<bool> {
- let ty_implements_copy = RsTypeKind::new(ty, ir)?.implements_copy();
+fn needs_manually_drop(db: &mut Database, ty: ir::RsType) -> Result<bool> {
+ let ty_implements_copy = db.rs_type_kind(ty)?.implements_copy();
Ok(!ty_implements_copy)
}
@@ -1213,14 +1216,15 @@
let field_type = match get_field_rs_type_for_layout(field) {
Err(_) => bit_padding(end - field.offset),
Ok(rs_type) => {
- let mut formatted = format_rs_type(&rs_type, &ir).with_context(|| {
+ let type_kind = db.rs_type_kind(rs_type.clone()).with_context(|| {
format!(
"Failed to format type for field {:?} on record {:?}",
field, record
)
})?;
+ let mut formatted = quote! {#type_kind};
if should_implement_drop(record) || record.is_union {
- if needs_manually_drop(rs_type, &ir)? {
+ if needs_manually_drop(db, rs_type.clone())? {
// TODO(b/212690698): Avoid (somewhat unergonomic) ManuallyDrop
// if we can ask Rust to preserve field destruction order if the
// destructor is the SpecialMemberFunc::NontrivialMembers
@@ -1356,7 +1360,7 @@
forward_declare::unsafe_define!(forward_declare::symbol!(#incomplete_symbol), #qualified_ident);
};
- let no_unique_address_accessors = cc_struct_no_unique_address_impl(record, &ir)?;
+ let no_unique_address_accessors = cc_struct_no_unique_address_impl(db, record)?;
let mut record_generated_items = record
.child_item_ids
.iter()
@@ -1481,9 +1485,9 @@
derives
}
-fn generate_enum(enum_: &Enum, ir: &IR) -> Result<TokenStream> {
+fn generate_enum(db: &mut Database, enum_: &Enum) -> Result<TokenStream> {
let name = make_rs_ident(&enum_.identifier.identifier);
- let underlying_type = format_rs_type(&enum_.underlying_type.rs_type, ir)?;
+ let underlying_type = db.rs_type_kind(enum_.underlying_type.rs_type.clone())?;
let enumerator_names =
enum_.enumerators.iter().map(|enumerator| make_rs_ident(&enumerator.identifier.identifier));
let enumerator_values = enum_.enumerators.iter().map(|enumerator| enumerator.value);
@@ -1507,10 +1511,11 @@
})
}
-fn generate_type_alias(type_alias: &TypeAlias, ir: &IR) -> Result<TokenStream> {
+fn generate_type_alias(db: &mut Database, type_alias: &TypeAlias) -> Result<TokenStream> {
let ident = make_rs_ident(&type_alias.identifier.identifier);
let doc_comment = generate_doc_comment(&type_alias.doc_comment);
- let underlying_type = format_rs_type(&type_alias.underlying_type.rs_type, ir)
+ let underlying_type = db
+ .rs_type_kind(type_alias.underlying_type.rs_type.clone())
.with_context(|| format!("Failed to format underlying type for {:?}", type_alias))?;
Ok(quote! {
#doc_comment
@@ -1682,7 +1687,7 @@
{
GeneratedItem::default()
} else {
- GeneratedItem { item: generate_enum(enum_, &ir)?, ..Default::default() }
+ GeneratedItem { item: generate_enum(db, enum_)?, ..Default::default() }
}
}
Item::TypeAlias(type_alias) => {
@@ -1691,7 +1696,7 @@
{
GeneratedItem::default()
} else {
- GeneratedItem { item: generate_type_alias(type_alias, &ir)?, ..Default::default() }
+ GeneratedItem { item: generate_type_alias(db, type_alias)?, ..Default::default() }
}
}
Item::UnsupportedItem(unsupported) => {
@@ -1876,7 +1881,7 @@
}
}
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq, Eq)]
enum RsTypeKind {
Pointer {
pointee: Rc<RsTypeKind>,
@@ -1930,10 +1935,13 @@
}
impl RsTypeKind {
- pub fn new(ty: &ir::RsType, ir: &IR) -> Result<Self> {
+ /// The implementation for the rs_type_kind query. Use that instead, as it
+ /// caches results.
+ fn query_impl(db: &dyn BindingsGenerator, ty: &ir::RsType) -> Result<Self> {
+ let ir = db.ir();
// The lambdas deduplicate code needed by multiple `match` branches.
- let get_type_args = || -> Result<Vec<RsTypeKind>> {
- ty.type_args.iter().map(|type_arg| RsTypeKind::new(type_arg, ir)).collect()
+ let get_type_args = || -> SalsaResult<Vec<RsTypeKind>> {
+ ty.type_args.iter().map(|type_arg| db.rs_type_kind(type_arg.clone())).collect()
};
let get_pointee = || -> Result<Rc<RsTypeKind>> {
if ty.type_args.len() != 1 {
@@ -1963,21 +1971,20 @@
incomplete_record: incomplete_record.clone(),
namespace_qualifier: generate_namespace_qualifier(
incomplete_record.id,
- ir,
+ &ir,
)?
.collect(),
- crate_ident: rs_imported_crate_name(&incomplete_record.owning_target, ir),
+ crate_ident: rs_imported_crate_name(&incomplete_record.owning_target, &ir),
},
- Item::Record(record) => RsTypeKind::new_record(record.clone(), ir)?,
+ Item::Record(record) => RsTypeKind::new_record(record.clone(), &ir)?,
Item::TypeAlias(type_alias) => RsTypeKind::TypeAlias {
type_alias: type_alias.clone(),
- namespace_qualifier: generate_namespace_qualifier(type_alias.id, ir)?
+ namespace_qualifier: generate_namespace_qualifier(type_alias.id, &ir)?
.collect(),
- crate_ident: rs_imported_crate_name(&type_alias.owning_target, ir),
- underlying_type: Rc::new(RsTypeKind::new(
- &type_alias.underlying_type.rs_type,
- ir,
- )?),
+ crate_ident: rs_imported_crate_name(&type_alias.owning_target, &ir),
+ underlying_type: Rc::new(
+ db.rs_type_kind(type_alias.underlying_type.rs_type.clone())?,
+ ),
},
other_item => bail!("Item does not define a type: {:?}", other_item),
}
@@ -2302,10 +2309,8 @@
}
}
-fn format_rs_type(ty: &ir::RsType, ir: &IR) -> Result<TokenStream> {
- RsTypeKind::new(ty, ir)
- .map(|kind| kind.to_token_stream())
- .with_context(|| format!("Failed to format Rust type {:?}", ty))
+fn rs_type_kind(db: &dyn BindingsGenerator, rs_type: ir::RsType) -> SalsaResult<RsTypeKind> {
+ Ok(RsTypeKind::query_impl(db, &rs_type)?)
}
fn cc_type_name_for_item(item: &ir::Item, ir: &IR) -> Result<TokenStream> {
@@ -2456,7 +2461,7 @@
}
// Returns the accessor functions for no_unique_address member variables.
-fn cc_struct_no_unique_address_impl(record: &Record, ir: &IR) -> Result<TokenStream> {
+fn cc_struct_no_unique_address_impl(db: &mut Database, record: &Record) -> Result<TokenStream> {
let mut fields = vec![];
let mut types = vec![];
for field in &record.fields {
@@ -2465,7 +2470,7 @@
}
// Can't use `get_field_rs_type_for_layout` here, because we want to dig into
// no_unique_address fields, despite laying them out as opaque blobs of bytes.
- if let Ok(rs_type) = field.type_.as_ref().map(|t| &t.rs_type) {
+ if let Ok(rs_type) = field.type_.as_ref().map(|t| t.rs_type.clone()) {
fields.push(make_rs_ident(
&field
.identifier
@@ -2473,9 +2478,9 @@
.expect("Unnamed fields can't be annotated with [[no_unique_address]]")
.identifier,
));
- types.push(format_rs_type(rs_type, ir).with_context(|| {
+ types.push(db.rs_type_kind(rs_type).with_context(|| {
format!("Failed to format type for field {:?} on record {:?}", field, record)
- })?)
+ })?);
}
}
@@ -2717,6 +2722,12 @@
super::generate_bindings_tokens(ir, "crubit/rs_bindings_support")
}
+ fn db_from_cc(cc_src: &str) -> Result<Database> {
+ let mut db = Database::default();
+ db.set_ir(ir_from_cc(cc_src)?);
+ Ok(db)
+ }
+
#[test]
fn test_disable_thread_safety_warnings() -> Result<()> {
let ir = ir_from_cc("inline void foo() {}")?;
@@ -5191,7 +5202,7 @@
fn test_thunk_ident_function() -> Result<()> {
let ir = ir_from_cc("inline int foo() {}")?;
let func = retrieve_func(&ir, "foo");
- assert_eq!(thunk_ident(func), make_rs_ident("__rust_thunk___Z3foov"));
+ assert_eq!(thunk_ident(&func), make_rs_ident("__rust_thunk___Z3foov"));
Ok(())
}
@@ -5436,9 +5447,11 @@
"LIFETIMES",
if test.lifetimes { "#pragma clang lifetime_elision" } else { "" },
);
- let ir = ir_from_cc(&cc_input)?;
+ let db = db_from_cc(&cc_input)?;
+ let ir = db.ir();
+
let f = retrieve_func(&ir, "func");
- let t = RsTypeKind::new(&f.params[0].type_.rs_type, &ir)?;
+ let t = db.rs_type_kind(f.params[0].type_.rs_type.clone())?;
let fmt = tokens_to_string(t.to_token_stream())?;
assert_eq!(test.rs, fmt, "Testing: {}", test_name);
@@ -5450,12 +5463,13 @@
#[test]
fn test_rs_type_kind_is_shared_ref_to_with_lifetimes() -> Result<()> {
- let ir = ir_from_cc(
+ let db = db_from_cc(
"#pragma clang lifetime_elision
struct SomeStruct {};
void foo(const SomeStruct& foo_param);
void bar(SomeStruct& bar_param);",
)?;
+ let ir = db.ir();
let record = ir.records().next().unwrap();
let foo_func = retrieve_func(&ir, "foo");
let bar_func = retrieve_func(&ir, "bar");
@@ -5464,7 +5478,7 @@
assert_eq!(foo_func.params.len(), 1);
let foo_param = &foo_func.params[0];
assert_eq!(&foo_param.identifier.identifier, "foo_param");
- let foo_type = RsTypeKind::new(&foo_param.type_.rs_type, &ir)?;
+ let foo_type = db.rs_type_kind(foo_param.type_.rs_type.clone())?;
assert!(foo_type.is_shared_ref_to(record));
assert!(matches!(foo_type, RsTypeKind::Reference { mutability: Mutability::Const, .. }));
@@ -5472,7 +5486,7 @@
assert_eq!(bar_func.params.len(), 1);
let bar_param = &bar_func.params[0];
assert_eq!(&bar_param.identifier.identifier, "bar_param");
- let bar_type = RsTypeKind::new(&bar_param.type_.rs_type, &ir)?;
+ let bar_type = db.rs_type_kind(bar_param.type_.rs_type.clone())?;
assert!(!bar_type.is_shared_ref_to(record));
assert!(matches!(bar_type, RsTypeKind::Reference { mutability: Mutability::Mut, .. }));
@@ -5481,10 +5495,11 @@
#[test]
fn test_rs_type_kind_is_shared_ref_to_without_lifetimes() -> Result<()> {
- let ir = ir_from_cc(
+ let db = db_from_cc(
"struct SomeStruct {};
void foo(const SomeStruct& foo_param);",
)?;
+ let ir = db.ir();
let record = ir.records().next().unwrap();
let foo_func = retrieve_func(&ir, "foo");
@@ -5492,7 +5507,7 @@
assert_eq!(foo_func.params.len(), 1);
let foo_param = &foo_func.params[0];
assert_eq!(&foo_param.identifier.identifier, "foo_param");
- let foo_type = RsTypeKind::new(&foo_param.type_.rs_type, &ir)?;
+ let foo_type = db.rs_type_kind(foo_param.type_.rs_type.clone())?;
assert!(!foo_type.is_shared_ref_to(record));
assert!(matches!(foo_type, RsTypeKind::Pointer { mutability: Mutability::Const, .. }));
@@ -5549,22 +5564,23 @@
#[test]
fn test_rs_type_kind_lifetimes() -> Result<()> {
- let ir = ir_from_cc(
+ let db = db_from_cc(
r#"
#pragma clang lifetime_elision
using TypeAlias = int&;
struct SomeStruct {};
void foo(int a, int& b, int&& c, int* d, int** e, TypeAlias f, SomeStruct g); "#,
)?;
+ let ir = db.ir();
let func = retrieve_func(&ir, "foo");
- let ret = RsTypeKind::new(&func.return_type.rs_type, &ir)?;
- let a = RsTypeKind::new(&func.params[0].type_.rs_type, &ir)?;
- let b = RsTypeKind::new(&func.params[1].type_.rs_type, &ir)?;
- let c = RsTypeKind::new(&func.params[2].type_.rs_type, &ir)?;
- let d = RsTypeKind::new(&func.params[3].type_.rs_type, &ir)?;
- let e = RsTypeKind::new(&func.params[4].type_.rs_type, &ir)?;
- let f = RsTypeKind::new(&func.params[5].type_.rs_type, &ir)?;
- let g = RsTypeKind::new(&func.params[6].type_.rs_type, &ir)?;
+ let ret = db.rs_type_kind(func.return_type.rs_type.clone())?;
+ let a = db.rs_type_kind(func.params[0].type_.rs_type.clone())?;
+ let b = db.rs_type_kind(func.params[1].type_.rs_type.clone())?;
+ let c = db.rs_type_kind(func.params[2].type_.rs_type.clone())?;
+ let d = db.rs_type_kind(func.params[3].type_.rs_type.clone())?;
+ let e = db.rs_type_kind(func.params[4].type_.rs_type.clone())?;
+ let f = db.rs_type_kind(func.params[5].type_.rs_type.clone())?;
+ let g = db.rs_type_kind(func.params[6].type_.rs_type.clone())?;
assert_eq!(0, ret.lifetimes().count()); // No lifetimes on `void`.
assert_eq!(0, a.lifetimes().count()); // No lifetimes on `int`.
@@ -5579,9 +5595,10 @@
#[test]
fn test_rs_type_kind_lifetimes_raw_ptr() -> Result<()> {
- let ir = ir_from_cc("void foo(int* a);")?;
+ let db = db_from_cc("void foo(int* a);")?;
+ let ir = db.ir();
let f = retrieve_func(&ir, "foo");
- let a = RsTypeKind::new(&f.params[0].type_.rs_type, &ir)?;
+ let a = db.rs_type_kind(f.params[0].type_.rs_type.clone())?;
assert_eq!(0, a.lifetimes().count()); // No lifetimes on `int*`.
Ok(())
}