Use salsa in src_code_gen.rs.
This only actually changes generate_func into a query for now -- notice that it's retrieved in two places, so this is actually saving work!
More importantly, this shows the shape of the change. Any given function can be migrated to a query by changing its parameter from a `&mut Database` to a `&dyn BindingsGenerator`.
A quick note on `PtrEq` -- this used to be possible without this; discussion at https://salsa.zulipchat.com/#narrow/stream/145099-general/topic/Return.20values.20that.20aren't.20comparable/near/286482700.
PiperOrigin-RevId: 458114855
diff --git a/rs_bindings_from_cc/BUILD b/rs_bindings_from_cc/BUILD
index 68c7495..3191b0a 100644
--- a/rs_bindings_from_cc/BUILD
+++ b/rs_bindings_from_cc/BUILD
@@ -355,6 +355,16 @@
],
)
+rust_library(
+ name = "salsa_utils",
+ srcs = ["salsa_utils.rs"],
+ deps = [
+ ":ir",
+ "@crate_index//:anyhow",
+ "@crate_index//:salsa",
+ ],
+)
+
cc_library(
name = "src_code_gen",
srcs = ["src_code_gen.cc"],
@@ -376,12 +386,14 @@
srcs = ["src_code_gen.rs"],
deps = [
":ir",
+ ":salsa_utils",
"//common:ffi_types",
"//common:token_stream_printer",
"@crate_index//:anyhow",
"@crate_index//:itertools",
"@crate_index//:proc-macro2",
"@crate_index//:quote",
+ "@crate_index//:salsa",
"@crate_index//:serde_json",
"@crate_index//:syn",
],
diff --git a/rs_bindings_from_cc/salsa_utils.rs b/rs_bindings_from_cc/salsa_utils.rs
new file mode 100644
index 0000000..06d75e2
--- /dev/null
+++ b/rs_bindings_from_cc/salsa_utils.rs
@@ -0,0 +1,89 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//! Copyable, equality-comparable types for Salsa.
+//!
+//! TODO(jeanpierreda): give this module a better name.
+
+#![feature(backtrace)]
+
+use std::ops::Deref;
+use std::sync::Arc;
+
+/// A wrapper for a smart pointer, which implements `Eq` as pointer equality.
+///
+/// This was directly inspired by Chalk's `ArcEq`, which does the same.
+/// However, unlike Chalk, `PtrEq` does not implement `Deref`: that would
+/// normally imply that it has the same behavior as the underlying
+/// pointee, and it obviously does not, as it implements `Eq` even if the
+/// pointee doesn't.
+///
+/// Instead, to access the underlying value, use `.as_ref()`.
+#[derive(Debug, Clone)]
+#[repr(transparent)]
+pub struct PtrEq<T: Deref>(pub T);
+
+impl<T: Deref> PartialEq<PtrEq<T>> for PtrEq<T> {
+ fn eq(&self, other: &Self) -> bool {
+ std::ptr::eq(&*self.0, &*other.0)
+ }
+}
+
+impl<T: Deref> Eq for PtrEq<T> {}
+
+impl<T: Deref> PtrEq<T> {
+ pub fn as_ref(&self) -> &T::Target {
+ &*self.0
+ }
+}
+
+/// A clonable, equality-comparable error which is interconvertible with
+/// `anyhow::Error`.
+///
+/// Two errors are equal if they are identical (i.e. they both have a common
+/// cloned-from ancestor.)
+///
+/// Salsa queries should return `Result<Rc<T>, SalsaError>`, and not
+/// `Rc<Result<T, anyhow::Error>>`. Because `anyhow::Error` cannot be cloned,
+/// `Rc<Result<T, anyhow::Error>>` is very nearly useless, as one cannot create
+/// a new `Rc<Result<U, anyhow::Error>>` containing the same error.
+/// Error propagation with cached errors requires that the underlying error type
+/// be copyable.
+///
+/// (Implementation note: SalsaError itself uses `Arc`, not `Rc`, because
+/// `anyhow::Error` requires `Send`+`Sync`.)
+#[derive(Clone, Debug)]
+pub struct SalsaError(Arc<dyn std::error::Error + Send + Sync + 'static>);
+
+impl PartialEq for SalsaError {
+ fn eq(&self, other: &Self) -> bool {
+ std::ptr::eq(&*self.0, &*other.0)
+ }
+}
+
+impl Eq for SalsaError {}
+
+impl std::fmt::Display for SalsaError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ std::fmt::Display::fmt(&*self.0, f)
+ }
+}
+
+impl std::error::Error for SalsaError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ self.0.source()
+ }
+ fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
+ self.0.backtrace()
+ }
+}
+
+impl From<anyhow::Error> for SalsaError {
+ fn from(e: anyhow::Error) -> Self {
+ let e: Box<dyn std::error::Error + Send + Sync + 'static> = e.into();
+ SalsaError(e.into())
+ }
+}
+
+pub type SalsaResult<T> = Result<T, SalsaError>;
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index 2d39b12..644bd4d 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -8,6 +8,7 @@
use itertools::Itertools;
use proc_macro2::{Ident, Literal, TokenStream};
use quote::{format_ident, quote, ToTokens};
+use salsa_utils::{PtrEq, SalsaResult};
use std::collections::{BTreeSet, HashSet};
use std::ffi::{OsStr, OsString};
use std::iter::{self, Iterator};
@@ -74,6 +75,25 @@
.unwrap_or_else(|_| process::abort())
}
+#[salsa::query_group(BindingsGeneratorStorage)]
+trait BindingsGenerator {
+ #[salsa::input]
+ fn ir(&self) -> Rc<IR>;
+
+ fn generate_func(
+ &self,
+ func: Rc<Func>,
+ ) -> SalsaResult<Option<PtrEq<Rc<(RsSnippet, RsSnippet, Rc<FunctionId>)>>>>;
+}
+
+#[salsa::database(BindingsGeneratorStorage)]
+#[derive(Default)]
+struct Database {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for Database {}
+
/// Source code for generated bindings.
struct Bindings {
// Rust source code.
@@ -597,30 +617,42 @@
/// destructor might be mapped to no `Drop` impl at all.)
/// * `Ok((rs_api, rs_thunk, function_id))`: The Rust function definition,
/// thunk FFI definition, and function ID.
-fn generate_func(func: &Func, ir: &IR) -> Result<Option<(RsSnippet, RsSnippet, FunctionId)>> {
+fn generate_func(
+ db: &dyn BindingsGenerator,
+ func: Rc<Func>,
+) -> SalsaResult<Option<PtrEq<Rc<(RsSnippet, RsSnippet, Rc<FunctionId>)>>>> {
+ Ok(generate_func_impl(db, &func)?)
+}
+
+fn generate_func_impl(
+ db: &dyn BindingsGenerator,
+ func: &Func,
+) -> Result<Option<PtrEq<Rc<(RsSnippet, RsSnippet, Rc<FunctionId>)>>>> {
+ let ir = db.ir();
let param_types = func
.params
.iter()
.map(|p| {
- RsTypeKind::new(&p.type_.rs_type, ir).with_context(|| {
+ RsTypeKind::new(&p.type_.rs_type, &ir).with_context(|| {
format!("Failed to process type of parameter {:?} on {:?}", p, func)
})
})
.collect::<Result<Vec<_>>>()?;
- let (func_name, mut impl_kind) = if let Some(values) = api_func_shape(func, ir, ¶m_types)? {
+ let (func_name, mut impl_kind) = if let Some(values) = api_func_shape(&func, &ir, ¶m_types)?
+ {
values
} else {
return Ok(None);
};
- let return_type_fragment = RsTypeKind::new(&func.return_type.rs_type, ir)
+ let return_type_fragment = RsTypeKind::new(&func.return_type.rs_type, &ir)
.with_context(|| format!("Failed to format return type for {:?}", func))?
.format_as_return_type_fragment();
let param_idents =
func.params.iter().map(|p| make_rs_ident(&p.identifier.identifier)).collect_vec();
- let thunk = generate_func_thunk(func, ¶m_idents, ¶m_types, &return_type_fragment)?;
+ let thunk = generate_func_thunk(&func, ¶m_idents, ¶m_types, &return_type_fragment)?;
let api_func_def = {
let mut return_type_fragment = return_type_fragment;
@@ -699,7 +731,7 @@
// MaybeUninit<T> in Pin if T is !Unpin. It should understand
// 'structural pinning', so that we do not need into_inner_unchecked()
// here.
- let thunk_ident = thunk_ident(func);
+ let thunk_ident = thunk_ident(&func);
let func_body = match &impl_kind {
ImplKind::Trait { trait_name: TraitName::CtorNew(..), .. } => {
let thunk_vars = format_tuple_except_singleton(&thunk_args);
@@ -842,7 +874,11 @@
}
}
- Ok(Some((RsSnippet { features, tokens: api_func }, thunk.into(), function_id)))
+ Ok(Some(PtrEq(Rc::new((
+ RsSnippet { features, tokens: api_func },
+ thunk.into(),
+ Rc::new(function_id),
+ )))))
}
fn generate_func_thunk(
@@ -873,7 +909,7 @@
})?);
}
- let thunk_ident = thunk_ident(func);
+ let thunk_ident = thunk_ident(&func);
let lifetimes = func.lifetime_params.iter();
let generic_params = format_generic_params(lifetimes);
let param_types = self_param.into_iter().chain(param_types.map(|t| quote! {#t}));
@@ -1034,12 +1070,13 @@
/// Generates Rust source code for a given `Record` and associated assertions as
/// a tuple.
fn generate_record(
+ db: &mut Database,
record: &Rc<Record>,
- ir: &IR,
- overloaded_funcs: &HashSet<FunctionId>,
+ overloaded_funcs: &HashSet<Rc<FunctionId>>,
) -> Result<GeneratedItem> {
+ let ir = db.ir();
let ident = make_rs_ident(&record.rs_name);
- let namespace_qualifier = generate_namespace_qualifier(record.id, ir)?;
+ let namespace_qualifier = generate_namespace_qualifier(record.id, &ir)?;
let qualified_ident = {
quote! { crate:: #(#namespace_qualifier::)* #ident }
};
@@ -1178,14 +1215,14 @@
let field_type = match get_field_rs_type_for_layout(field) {
Err(_) => bit_padding(end - field.offset),
Ok(rs_type) => {
- let mut formatted = format_rs_type(&rs_type, ir).with_context(|| {
+ let mut formatted = format_rs_type(&rs_type, &ir).with_context(|| {
format!(
"Failed to format type for field {:?} on record {:?}",
field, record
)
})?;
if should_implement_drop(record) || record.is_union {
- if needs_manually_drop(rs_type, ir)? {
+ if needs_manually_drop(rs_type, &ir)? {
// TODO(b/212690698): Avoid (somewhat unergonomic) ManuallyDrop
// if we can ask Rust to preserve field destruction order if the
// destructor is the SpecialMemberFunc::NontrivialMembers
@@ -1321,17 +1358,17 @@
forward_declare::unsafe_define!(forward_declare::symbol!(#incomplete_symbol), #qualified_ident);
};
- let no_unique_address_accessors = cc_struct_no_unique_address_impl(record, ir)?;
+ let no_unique_address_accessors = cc_struct_no_unique_address_impl(record, &ir)?;
let mut record_generated_items = record
.child_item_ids
.iter()
.map(|id| {
let item = ir.find_decl(*id)?;
- generate_item(item, &ir, overloaded_funcs)
+ generate_item(db, item, overloaded_funcs)
})
.collect::<Result<Vec<_>>>()?;
- record_generated_items.push(cc_struct_upcast_impl(record, ir)?);
+ record_generated_items.push(cc_struct_upcast_impl(record, &ir)?);
let mut items = vec![];
let mut thunks_from_record_items = vec![];
@@ -1371,7 +1408,7 @@
};
let record_trait_assertions = {
- let record_type_name = RsTypeKind::new_record(record.clone(), ir)?.to_token_stream();
+ let record_type_name = RsTypeKind::new_record(record.clone(), &ir)?.to_token_stream();
let mut assertions: Vec<TokenStream> = vec![];
let mut add_assertion = |assert_impl_macro: TokenStream, trait_name: TokenStream| {
assertions.push(quote! {
@@ -1508,10 +1545,11 @@
}
fn generate_namespace(
+ db: &mut Database,
namespace: &Namespace,
- ir: &IR,
- overloaded_funcs: &HashSet<FunctionId>,
+ overloaded_funcs: &HashSet<Rc<FunctionId>>,
) -> Result<GeneratedItem> {
+ let ir = db.ir();
let mut items = vec![];
let mut thunks = vec![];
let mut assertions = vec![];
@@ -1520,7 +1558,7 @@
for item_id in namespace.child_item_ids.iter() {
let item = ir.find_decl(*item_id)?;
- let generated = generate_item(item, ir, &overloaded_funcs)?;
+ let generated = generate_item(db, item, &overloaded_funcs)?;
items.push(generated.item);
if !generated.thunks.is_empty() {
thunks.push(generated.thunks);
@@ -1591,31 +1629,35 @@
}
fn generate_item(
+ db: &mut Database,
item: &Item,
- ir: &IR,
- overloaded_funcs: &HashSet<FunctionId>,
+ overloaded_funcs: &HashSet<Rc<FunctionId>>,
) -> Result<GeneratedItem> {
+ let ir = db.ir();
let generated_item = match item {
- Item::Func(func) => match generate_func(func, ir) {
+ Item::Func(func) => match db.generate_func(func.clone()) {
Err(e) => GeneratedItem {
- item: generate_unsupported(&make_unsupported_fn(func, ir, format!("{e}"))?)?,
+ item: generate_unsupported(&make_unsupported_fn(func, &ir, format!("{e}"))?)?,
..Default::default()
},
Ok(None) => GeneratedItem::default(),
- Ok(Some((api_func, thunk, function_id))) => {
- if overloaded_funcs.contains(&function_id) {
+ Ok(Some(f)) => {
+ let (api_func, thunk, function_id) = f.as_ref();
+ if overloaded_funcs.contains(function_id) {
GeneratedItem {
item: generate_unsupported(&make_unsupported_fn(
func,
- ir,
+ &ir,
"Cannot generate bindings for overloaded function",
)?)?,
..Default::default()
}
} else {
+ // TODO(b/236687702): Use Rc for these, or else split this into a non-query
+ // and only use the query for Function IDs.
GeneratedItem {
- item: api_func.tokens,
- thunks: thunk.tokens,
+ item: api_func.tokens.clone(),
+ thunks: thunk.tokens.clone(),
features: api_func.features.union(&thunk.features).cloned().collect(),
..Default::default()
}
@@ -1640,7 +1682,7 @@
{
GeneratedItem::default()
} else {
- generate_record(record, ir, overloaded_funcs)?
+ generate_record(db, record, overloaded_funcs)?
}
}
Item::Enum(enum_) => {
@@ -1649,7 +1691,7 @@
{
GeneratedItem::default()
} else {
- GeneratedItem { item: generate_enum(enum_, ir)?, ..Default::default() }
+ GeneratedItem { item: generate_enum(enum_, &ir)?, ..Default::default() }
}
}
Item::TypeAlias(type_alias) => {
@@ -1658,7 +1700,7 @@
{
GeneratedItem::default()
} else {
- GeneratedItem { item: generate_type_alias(type_alias, ir)?, ..Default::default() }
+ GeneratedItem { item: generate_type_alias(type_alias, &ir)?, ..Default::default() }
}
}
Item::UnsupportedItem(unsupported) => {
@@ -1667,7 +1709,7 @@
Item::Comment(comment) => {
GeneratedItem { item: generate_comment(comment)?, ..Default::default() }
}
- Item::Namespace(namespace) => generate_namespace(namespace, ir, overloaded_funcs)?,
+ Item::Namespace(namespace) => generate_namespace(db, namespace, overloaded_funcs)?,
};
Ok(generated_item)
@@ -1676,6 +1718,9 @@
// Returns the Rust code implementing bindings, plus any auxiliary C++ code
// needed to support it.
fn generate_bindings_tokens(ir: Rc<IR>, crubit_support_path: &str) -> Result<BindingsTokens> {
+ let mut db = Database::default();
+ db.set_ir(ir.clone());
+
let mut items = vec![];
let mut thunks = vec![];
let mut thunk_impls = vec![generate_rs_api_impl(&ir, crubit_support_path)?];
@@ -1703,16 +1748,17 @@
let mut seen_funcs = HashSet::new();
let mut overloaded_funcs = HashSet::new();
for func in ir.functions() {
- if let Ok(Some((.., function_id))) = generate_func(func, &ir) {
+ if let Ok(Some(f)) = db.generate_func(func.clone()) {
+ let (.., function_id) = f.as_ref();
if !seen_funcs.insert(function_id.clone()) {
- overloaded_funcs.insert(function_id);
+ overloaded_funcs.insert(function_id.clone());
}
}
}
for top_level_item_id in ir.top_level_item_ids() {
let item = ir.find_decl(*top_level_item_id)?;
- let generated = generate_item(item, &ir, &overloaded_funcs)?;
+ let generated = generate_item(&mut db, item, &overloaded_funcs)?;
items.push(generated.item);
if !generated.thunks.is_empty() {
thunks.push(generated.thunks);