Use salsa in src_code_gen.rs.

This only actually changes generate_func into a query for now -- notice that it's retrieved in two places, so this is actually saving work!

More importantly, this shows the shape of the change. Any given function can be migrated to a query by changing its parameter from a `&mut Database` to a `&dyn BindingsGenerator`.

A quick note on `PtrEq` -- this used to be possible without this; discussion at https://salsa.zulipchat.com/#narrow/stream/145099-general/topic/Return.20values.20that.20aren't.20comparable/near/286482700.

PiperOrigin-RevId: 458114855
diff --git a/rs_bindings_from_cc/BUILD b/rs_bindings_from_cc/BUILD
index 68c7495..3191b0a 100644
--- a/rs_bindings_from_cc/BUILD
+++ b/rs_bindings_from_cc/BUILD
@@ -355,6 +355,16 @@
     ],
 )
 
+rust_library(
+    name = "salsa_utils",
+    srcs = ["salsa_utils.rs"],
+    deps = [
+        ":ir",
+        "@crate_index//:anyhow",
+        "@crate_index//:salsa",
+    ],
+)
+
 cc_library(
     name = "src_code_gen",
     srcs = ["src_code_gen.cc"],
@@ -376,12 +386,14 @@
     srcs = ["src_code_gen.rs"],
     deps = [
         ":ir",
+        ":salsa_utils",
         "//common:ffi_types",
         "//common:token_stream_printer",
         "@crate_index//:anyhow",
         "@crate_index//:itertools",
         "@crate_index//:proc-macro2",
         "@crate_index//:quote",
+        "@crate_index//:salsa",
         "@crate_index//:serde_json",
         "@crate_index//:syn",
     ],
diff --git a/rs_bindings_from_cc/salsa_utils.rs b/rs_bindings_from_cc/salsa_utils.rs
new file mode 100644
index 0000000..06d75e2
--- /dev/null
+++ b/rs_bindings_from_cc/salsa_utils.rs
@@ -0,0 +1,89 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//! Copyable, equality-comparable types for Salsa.
+//!
+//! TODO(jeanpierreda): give this module a better name.
+
+#![feature(backtrace)]
+
+use std::ops::Deref;
+use std::sync::Arc;
+
+/// A wrapper for a smart pointer, which implements `Eq` as pointer equality.
+///
+/// This was directly inspired by Chalk's `ArcEq`, which does the same.
+/// However, unlike Chalk, `PtrEq` does not implement `Deref`: that would
+/// normally imply that it has the same behavior as the underlying
+/// pointee, and it obviously does not, as it implements `Eq` even if the
+/// pointee doesn't.
+///
+/// Instead, to access the underlying value, use `.as_ref()`.
+#[derive(Debug, Clone)]
+#[repr(transparent)]
+pub struct PtrEq<T: Deref>(pub T);
+
+impl<T: Deref> PartialEq<PtrEq<T>> for PtrEq<T> {
+    fn eq(&self, other: &Self) -> bool {
+        std::ptr::eq(&*self.0, &*other.0)
+    }
+}
+
+impl<T: Deref> Eq for PtrEq<T> {}
+
+impl<T: Deref> PtrEq<T> {
+    pub fn as_ref(&self) -> &T::Target {
+        &*self.0
+    }
+}
+
+/// A clonable, equality-comparable error which is interconvertible with
+/// `anyhow::Error`.
+///
+/// Two errors are equal if they are identical (i.e. they both have a common
+/// cloned-from ancestor.)
+///
+/// Salsa queries should return `Result<Rc<T>, SalsaError>`, and not
+/// `Rc<Result<T, anyhow::Error>>`. Because `anyhow::Error` cannot be cloned,
+/// `Rc<Result<T, anyhow::Error>>` is very nearly useless, as one cannot create
+/// a new `Rc<Result<U, anyhow::Error>>` containing the same error.
+/// Error propagation with cached errors requires that the underlying error type
+/// be copyable.
+///
+/// (Implementation note: SalsaError itself uses `Arc`, not `Rc`, because
+/// `anyhow::Error` requires `Send`+`Sync`.)
+#[derive(Clone, Debug)]
+pub struct SalsaError(Arc<dyn std::error::Error + Send + Sync + 'static>);
+
+impl PartialEq for SalsaError {
+    fn eq(&self, other: &Self) -> bool {
+        std::ptr::eq(&*self.0, &*other.0)
+    }
+}
+
+impl Eq for SalsaError {}
+
+impl std::fmt::Display for SalsaError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+        std::fmt::Display::fmt(&*self.0, f)
+    }
+}
+
+impl std::error::Error for SalsaError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.0.source()
+    }
+    fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
+        self.0.backtrace()
+    }
+}
+
+impl From<anyhow::Error> for SalsaError {
+    fn from(e: anyhow::Error) -> Self {
+        let e: Box<dyn std::error::Error + Send + Sync + 'static> = e.into();
+        SalsaError(e.into())
+    }
+}
+
+pub type SalsaResult<T> = Result<T, SalsaError>;
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index 2d39b12..644bd4d 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -8,6 +8,7 @@
 use itertools::Itertools;
 use proc_macro2::{Ident, Literal, TokenStream};
 use quote::{format_ident, quote, ToTokens};
+use salsa_utils::{PtrEq, SalsaResult};
 use std::collections::{BTreeSet, HashSet};
 use std::ffi::{OsStr, OsString};
 use std::iter::{self, Iterator};
@@ -74,6 +75,25 @@
     .unwrap_or_else(|_| process::abort())
 }
 
+#[salsa::query_group(BindingsGeneratorStorage)]
+trait BindingsGenerator {
+    #[salsa::input]
+    fn ir(&self) -> Rc<IR>;
+
+    fn generate_func(
+        &self,
+        func: Rc<Func>,
+    ) -> SalsaResult<Option<PtrEq<Rc<(RsSnippet, RsSnippet, Rc<FunctionId>)>>>>;
+}
+
+#[salsa::database(BindingsGeneratorStorage)]
+#[derive(Default)]
+struct Database {
+    storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for Database {}
+
 /// Source code for generated bindings.
 struct Bindings {
     // Rust source code.
@@ -597,30 +617,42 @@
 ///    destructor might be mapped to no `Drop` impl at all.)
 ///  * `Ok((rs_api, rs_thunk, function_id))`: The Rust function definition,
 ///    thunk FFI definition, and function ID.
-fn generate_func(func: &Func, ir: &IR) -> Result<Option<(RsSnippet, RsSnippet, FunctionId)>> {
+fn generate_func(
+    db: &dyn BindingsGenerator,
+    func: Rc<Func>,
+) -> SalsaResult<Option<PtrEq<Rc<(RsSnippet, RsSnippet, Rc<FunctionId>)>>>> {
+    Ok(generate_func_impl(db, &func)?)
+}
+
+fn generate_func_impl(
+    db: &dyn BindingsGenerator,
+    func: &Func,
+) -> Result<Option<PtrEq<Rc<(RsSnippet, RsSnippet, Rc<FunctionId>)>>>> {
+    let ir = db.ir();
     let param_types = func
         .params
         .iter()
         .map(|p| {
-            RsTypeKind::new(&p.type_.rs_type, ir).with_context(|| {
+            RsTypeKind::new(&p.type_.rs_type, &ir).with_context(|| {
                 format!("Failed to process type of parameter {:?} on {:?}", p, func)
             })
         })
         .collect::<Result<Vec<_>>>()?;
 
-    let (func_name, mut impl_kind) = if let Some(values) = api_func_shape(func, ir, &param_types)? {
+    let (func_name, mut impl_kind) = if let Some(values) = api_func_shape(&func, &ir, &param_types)?
+    {
         values
     } else {
         return Ok(None);
     };
 
-    let return_type_fragment = RsTypeKind::new(&func.return_type.rs_type, ir)
+    let return_type_fragment = RsTypeKind::new(&func.return_type.rs_type, &ir)
         .with_context(|| format!("Failed to format return type for {:?}", func))?
         .format_as_return_type_fragment();
     let param_idents =
         func.params.iter().map(|p| make_rs_ident(&p.identifier.identifier)).collect_vec();
 
-    let thunk = generate_func_thunk(func, &param_idents, &param_types, &return_type_fragment)?;
+    let thunk = generate_func_thunk(&func, &param_idents, &param_types, &return_type_fragment)?;
 
     let api_func_def = {
         let mut return_type_fragment = return_type_fragment;
@@ -699,7 +731,7 @@
         // MaybeUninit<T> in Pin if T is !Unpin. It should understand
         // 'structural pinning', so that we do not need into_inner_unchecked()
         // here.
-        let thunk_ident = thunk_ident(func);
+        let thunk_ident = thunk_ident(&func);
         let func_body = match &impl_kind {
             ImplKind::Trait { trait_name: TraitName::CtorNew(..), .. } => {
                 let thunk_vars = format_tuple_except_singleton(&thunk_args);
@@ -842,7 +874,11 @@
         }
     }
 
-    Ok(Some((RsSnippet { features, tokens: api_func }, thunk.into(), function_id)))
+    Ok(Some(PtrEq(Rc::new((
+        RsSnippet { features, tokens: api_func },
+        thunk.into(),
+        Rc::new(function_id),
+    )))))
 }
 
 fn generate_func_thunk(
@@ -873,7 +909,7 @@
         })?);
     }
 
-    let thunk_ident = thunk_ident(func);
+    let thunk_ident = thunk_ident(&func);
     let lifetimes = func.lifetime_params.iter();
     let generic_params = format_generic_params(lifetimes);
     let param_types = self_param.into_iter().chain(param_types.map(|t| quote! {#t}));
@@ -1034,12 +1070,13 @@
 /// Generates Rust source code for a given `Record` and associated assertions as
 /// a tuple.
 fn generate_record(
+    db: &mut Database,
     record: &Rc<Record>,
-    ir: &IR,
-    overloaded_funcs: &HashSet<FunctionId>,
+    overloaded_funcs: &HashSet<Rc<FunctionId>>,
 ) -> Result<GeneratedItem> {
+    let ir = db.ir();
     let ident = make_rs_ident(&record.rs_name);
-    let namespace_qualifier = generate_namespace_qualifier(record.id, ir)?;
+    let namespace_qualifier = generate_namespace_qualifier(record.id, &ir)?;
     let qualified_ident = {
         quote! { crate:: #(#namespace_qualifier::)* #ident }
     };
@@ -1178,14 +1215,14 @@
             let field_type = match get_field_rs_type_for_layout(field) {
                 Err(_) => bit_padding(end - field.offset),
                 Ok(rs_type) => {
-                    let mut formatted = format_rs_type(&rs_type, ir).with_context(|| {
+                    let mut formatted = format_rs_type(&rs_type, &ir).with_context(|| {
                         format!(
                             "Failed to format type for field {:?} on record {:?}",
                             field, record
                         )
                     })?;
                     if should_implement_drop(record) || record.is_union {
-                        if needs_manually_drop(rs_type, ir)? {
+                        if needs_manually_drop(rs_type, &ir)? {
                             // TODO(b/212690698): Avoid (somewhat unergonomic) ManuallyDrop
                             // if we can ask Rust to preserve field destruction order if the
                             // destructor is the SpecialMemberFunc::NontrivialMembers
@@ -1321,17 +1358,17 @@
         forward_declare::unsafe_define!(forward_declare::symbol!(#incomplete_symbol), #qualified_ident);
     };
 
-    let no_unique_address_accessors = cc_struct_no_unique_address_impl(record, ir)?;
+    let no_unique_address_accessors = cc_struct_no_unique_address_impl(record, &ir)?;
     let mut record_generated_items = record
         .child_item_ids
         .iter()
         .map(|id| {
             let item = ir.find_decl(*id)?;
-            generate_item(item, &ir, overloaded_funcs)
+            generate_item(db, item, overloaded_funcs)
         })
         .collect::<Result<Vec<_>>>()?;
 
-    record_generated_items.push(cc_struct_upcast_impl(record, ir)?);
+    record_generated_items.push(cc_struct_upcast_impl(record, &ir)?);
 
     let mut items = vec![];
     let mut thunks_from_record_items = vec![];
@@ -1371,7 +1408,7 @@
     };
 
     let record_trait_assertions = {
-        let record_type_name = RsTypeKind::new_record(record.clone(), ir)?.to_token_stream();
+        let record_type_name = RsTypeKind::new_record(record.clone(), &ir)?.to_token_stream();
         let mut assertions: Vec<TokenStream> = vec![];
         let mut add_assertion = |assert_impl_macro: TokenStream, trait_name: TokenStream| {
             assertions.push(quote! {
@@ -1508,10 +1545,11 @@
 }
 
 fn generate_namespace(
+    db: &mut Database,
     namespace: &Namespace,
-    ir: &IR,
-    overloaded_funcs: &HashSet<FunctionId>,
+    overloaded_funcs: &HashSet<Rc<FunctionId>>,
 ) -> Result<GeneratedItem> {
+    let ir = db.ir();
     let mut items = vec![];
     let mut thunks = vec![];
     let mut assertions = vec![];
@@ -1520,7 +1558,7 @@
 
     for item_id in namespace.child_item_ids.iter() {
         let item = ir.find_decl(*item_id)?;
-        let generated = generate_item(item, ir, &overloaded_funcs)?;
+        let generated = generate_item(db, item, &overloaded_funcs)?;
         items.push(generated.item);
         if !generated.thunks.is_empty() {
             thunks.push(generated.thunks);
@@ -1591,31 +1629,35 @@
 }
 
 fn generate_item(
+    db: &mut Database,
     item: &Item,
-    ir: &IR,
-    overloaded_funcs: &HashSet<FunctionId>,
+    overloaded_funcs: &HashSet<Rc<FunctionId>>,
 ) -> Result<GeneratedItem> {
+    let ir = db.ir();
     let generated_item = match item {
-        Item::Func(func) => match generate_func(func, ir) {
+        Item::Func(func) => match db.generate_func(func.clone()) {
             Err(e) => GeneratedItem {
-                item: generate_unsupported(&make_unsupported_fn(func, ir, format!("{e}"))?)?,
+                item: generate_unsupported(&make_unsupported_fn(func, &ir, format!("{e}"))?)?,
                 ..Default::default()
             },
             Ok(None) => GeneratedItem::default(),
-            Ok(Some((api_func, thunk, function_id))) => {
-                if overloaded_funcs.contains(&function_id) {
+            Ok(Some(f)) => {
+                let (api_func, thunk, function_id) = f.as_ref();
+                if overloaded_funcs.contains(function_id) {
                     GeneratedItem {
                         item: generate_unsupported(&make_unsupported_fn(
                             func,
-                            ir,
+                            &ir,
                             "Cannot generate bindings for overloaded function",
                         )?)?,
                         ..Default::default()
                     }
                 } else {
+                    // TODO(b/236687702): Use Rc for these, or else split this into a non-query
+                    // and only use the query for Function IDs.
                     GeneratedItem {
-                        item: api_func.tokens,
-                        thunks: thunk.tokens,
+                        item: api_func.tokens.clone(),
+                        thunks: thunk.tokens.clone(),
                         features: api_func.features.union(&thunk.features).cloned().collect(),
                         ..Default::default()
                     }
@@ -1640,7 +1682,7 @@
             {
                 GeneratedItem::default()
             } else {
-                generate_record(record, ir, overloaded_funcs)?
+                generate_record(db, record, overloaded_funcs)?
             }
         }
         Item::Enum(enum_) => {
@@ -1649,7 +1691,7 @@
             {
                 GeneratedItem::default()
             } else {
-                GeneratedItem { item: generate_enum(enum_, ir)?, ..Default::default() }
+                GeneratedItem { item: generate_enum(enum_, &ir)?, ..Default::default() }
             }
         }
         Item::TypeAlias(type_alias) => {
@@ -1658,7 +1700,7 @@
             {
                 GeneratedItem::default()
             } else {
-                GeneratedItem { item: generate_type_alias(type_alias, ir)?, ..Default::default() }
+                GeneratedItem { item: generate_type_alias(type_alias, &ir)?, ..Default::default() }
             }
         }
         Item::UnsupportedItem(unsupported) => {
@@ -1667,7 +1709,7 @@
         Item::Comment(comment) => {
             GeneratedItem { item: generate_comment(comment)?, ..Default::default() }
         }
-        Item::Namespace(namespace) => generate_namespace(namespace, ir, overloaded_funcs)?,
+        Item::Namespace(namespace) => generate_namespace(db, namespace, overloaded_funcs)?,
     };
 
     Ok(generated_item)
@@ -1676,6 +1718,9 @@
 // Returns the Rust code implementing bindings, plus any auxiliary C++ code
 // needed to support it.
 fn generate_bindings_tokens(ir: Rc<IR>, crubit_support_path: &str) -> Result<BindingsTokens> {
+    let mut db = Database::default();
+    db.set_ir(ir.clone());
+
     let mut items = vec![];
     let mut thunks = vec![];
     let mut thunk_impls = vec![generate_rs_api_impl(&ir, crubit_support_path)?];
@@ -1703,16 +1748,17 @@
     let mut seen_funcs = HashSet::new();
     let mut overloaded_funcs = HashSet::new();
     for func in ir.functions() {
-        if let Ok(Some((.., function_id))) = generate_func(func, &ir) {
+        if let Ok(Some(f)) = db.generate_func(func.clone()) {
+            let (.., function_id) = f.as_ref();
             if !seen_funcs.insert(function_id.clone()) {
-                overloaded_funcs.insert(function_id);
+                overloaded_funcs.insert(function_id.clone());
             }
         }
     }
 
     for top_level_item_id in ir.top_level_item_ids() {
         let item = ir.find_decl(*top_level_item_id)?;
-        let generated = generate_item(item, &ir, &overloaded_funcs)?;
+        let generated = generate_item(&mut db, item, &overloaded_funcs)?;
         items.push(generated.item);
         if !generated.thunks.is_empty() {
             thunks.push(generated.thunks);