Replace `tcx.non_blanket_impls_for_ty(...)` with `tcx.associated_items(...)`.
`non_blanket_impls_for_ty` doesn't work well for:
1. Traits with non-mandatory methods (like `Clone::clone_from` or
`PartialEq::ne` where the default implementation is provided by the
trait definition). Desire to support `Clone` in a follow-up CL is
the main motivation for this CL.
2. Blanket impls.
Because of the above this CL changes how `cc_bindings_from_rs` generates
C-ABI thunks that call into a type/`Self`-specific implementation of a
given trait method. Before this CL, the code would use
`non_blanket_impls_for_ty` to iterate over `def_id`s of
type/`Self`-*specialized* method `impl`s. After this CL, the code
iterates over all trait methods reported by `associated_items` which
helps by also covering methods with no specialized `impl`. OTOH this
means that we now iterate over `Self`-*generic* methods and therefore
need to substitute a specific `self_ty` as needed. This in turn means
that we can't reuse `get_symbol_name` and `get_thunk_name` (because
`Instance::mono` wouldn't work for `Self`-generic methods).
This CL should have no impact on the current end-to-end behavior,
because currently only the `Default` trait is supported and 1) it has
only a single, mandatory method and 2) AFAIK it doesn't have any blanket
impls.
PiperOrigin-RevId: 540883898
diff --git a/cc_bindings_from_rs/bindings.rs b/cc_bindings_from_rs/bindings.rs
index dac887a..709e3dd 100644
--- a/cc_bindings_from_rs/bindings.rs
+++ b/cc_bindings_from_rs/bindings.rs
@@ -10,9 +10,8 @@
use itertools::Itertools;
use proc_macro2::{Ident, Literal, TokenStream};
use quote::{format_ident, quote, ToTokens};
-use rustc_hir::{
- AssocItemKind, Impl, ImplItemKind, ImplicitSelfKind, Item, ItemKind, Node, Unsafety,
-};
+use rustc_hir::{AssocItemKind, ImplItemKind, ImplicitSelfKind, Item, ItemKind, Node, Unsafety};
+use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::dep_graph::DepContext;
use rustc_middle::mir::Mutability;
use rustc_middle::ty::{self, Ty, TyCtxt}; // See <internal link>/ty.html#import-conventions
@@ -20,6 +19,7 @@
use rustc_span::symbol::{sym, Symbol};
use rustc_target::abi::{Abi, FieldsShape, Integer, Layout, Primitive, Scalar};
use rustc_target::spec::PanicStrategy;
+use rustc_trait_selection::infer::InferCtxtExt;
use rustc_type_ir::sty::RegionKind;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::iter::once;
@@ -667,7 +667,8 @@
}
fn format_region_as_cc_lifetime(region: &ty::Region) -> TokenStream {
- let name = region.get_name().expect("Anonymous regions should be removed by `get_fn_sig`");
+ let name =
+ region.get_name().expect("Caller should use `liberate_and_deanonymize_late_bound_regions`");
let name = name
.as_str()
.strip_prefix('\'')
@@ -678,7 +679,8 @@
}
fn format_region_as_rs_lifetime(region: &ty::Region) -> TokenStream {
- let name = region.get_name().expect("Anonymous regions should be removed by `get_fn_sig`");
+ let name =
+ region.get_name().expect("Caller should use `liberate_and_deanonymize_late_bound_regions`");
let lifetime = syn::Lifetime::new(name.as_str(), proc_macro2::Span::call_site());
quote! { #lifetime }
}
@@ -713,13 +715,13 @@
}
}
-fn get_fn_sig(tcx: TyCtxt, fn_def_id: LocalDefId) -> ty::FnSig {
- let sig = tcx.fn_sig(fn_def_id).subst_identity();
- let fn_def_id = fn_def_id.to_def_id(); // LocalDefId => DefId
-
- // The `replace_late_bound_regions_uncached` call below is similar to
- // `TyCtxt::liberate_and_name_late_bound_regions` but also replaces anonymous
- // regions with new names.
+/// Similar to `TyCtxt::liberate_and_name_late_bound_regions` but also replaces
+/// anonymous regions with new names.
+fn liberate_and_deanonymize_late_bound_regions<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ sig: ty::PolyFnSig<'tcx>,
+ fn_def_id: DefId,
+) -> ty::FnSig<'tcx> {
let mut anon_count: u32 = 0;
let mut translated_kinds: HashMap<ty::BoundVar, ty::BoundRegionKind> = HashMap::new();
tcx.replace_late_bound_regions_uncached(sig, |br: ty::BoundRegion| {
@@ -735,6 +737,12 @@
})
}
+fn get_fn_sig(tcx: TyCtxt, fn_def_id: LocalDefId) -> ty::FnSig {
+ let fn_def_id = fn_def_id.to_def_id(); // LocalDefId => DefId
+ let sig = tcx.fn_sig(fn_def_id).subst_identity();
+ liberate_and_deanonymize_late_bound_regions(tcx, sig, fn_def_id)
+}
+
/// Formats a C++ function declaration of a thunk that wraps a Rust function
/// identified by `fn_def_id`. `format_thunk_impl` may panic if `fn_def_id`
/// doesn't identify a function.
@@ -872,7 +880,9 @@
_ => panic!("Unexpected region kind: {region}"),
})
.sorted_by_key(|region| {
- region.get_name().expect("`get_fn_sig` should remove anonymous lifetimes")
+ region
+ .get_name()
+ .expect("Caller should use `liberate_and_deanonymize_late_bound_regions`")
})
.dedup()
.collect_vec();
@@ -893,22 +903,6 @@
})
}
-fn get_symbol_name<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> Result<&'tcx str> {
- ensure!(
- tcx.generics_of(def_id).count() == 0,
- "Generic functions are not supported yet (b/259749023) - caller should filter them out",
- );
-
- // Call to `mono` is ok - `generics_of` have been checked above.
- let instance = ty::Instance::mono(tcx, def_id.to_def_id());
-
- Ok(tcx.symbol_name(instance).name)
-}
-
-fn get_thunk_name(symbol_name: &str) -> String {
- format!("__crubit_thunk_{}", &escape_non_identifier_chars(symbol_name))
-}
-
fn check_fn_sig(sig: &ty::FnSig) -> Result<()> {
if sig.c_variadic {
// TODO(b/254097223): Add support for variadic functions.
@@ -977,8 +971,16 @@
check_fn_sig(&sig)?;
let needs_thunk = is_thunk_required(&sig).is_err();
let thunk_name = {
- let symbol_name = get_symbol_name(tcx, local_def_id)?;
- if needs_thunk { get_thunk_name(symbol_name) } else { symbol_name.to_string() }
+ let symbol_name = {
+ // Call to `mono` is ok - `generics_of` have been checked above.
+ let instance = ty::Instance::mono(tcx, def_id);
+ tcx.symbol_name(instance).name
+ };
+ if needs_thunk {
+ format!("__crubit_thunk_{}", &escape_non_identifier_chars(symbol_name))
+ } else {
+ symbol_name.to_string()
+ }
};
let fully_qualified_fn_name = FullyQualifiedName::new(tcx, def_id);
@@ -1554,36 +1556,92 @@
ApiSnippets { main_api, cc_details, rs_details }
}
-/// Finds the `Impl` of a trait impl for `self_ty`. Returns an error if the
-/// impl wasn't found.
-///
-/// `self_ty` should specify a *local* type (i.e. type defined in the crate
-/// being "compiled").
-///
-/// `trait_name` should specify the name of a `core` trait - e.g.
-/// [`sym::Default`](https://doc.rust-lang.org/beta/nightly-rustc/rustc_span/symbol/sym/constant.Default.html) is a valid
-/// argument.
-fn find_core_trait_impl<'tcx>(
- tcx: TyCtxt<'tcx>,
- self_ty: Ty<'tcx>,
- trait_name: Symbol,
-) -> Result<&'tcx Impl<'tcx>> {
- let trait_id = tcx
- .get_diagnostic_item(trait_name)
- .expect("`find_core_trait_impl` should only be called with `core`, always-present traits");
- // TODO(b/275387739): Eventually we might need to support blanket impls.
- let mut impls = tcx.non_blanket_impls_for_ty(trait_id, self_ty);
- let impl_id = impls.next();
- if impl_id.is_some() {
- assert_eq!(None, impls.next(), "Expecting only a single trait impl");
+struct TraitThunks {
+ method_name_to_cc_thunk_name: HashMap<Symbol, TokenStream>,
+ cc_thunk_decls: CcSnippet,
+ rs_thunk_impls: TokenStream,
+}
+
+fn format_trait_thunks(
+ input: &Input,
+ trait_id: DefId,
+ adt: &AdtCoreBindings,
+) -> Result<TraitThunks> {
+ let tcx = input.tcx;
+ let self_ty = tcx.type_of(adt.def_id).subst_identity();
+
+ let does_adt_implement_trait = {
+ let generics = tcx.generics_of(trait_id);
+ assert!(generics.has_self);
+ assert_eq!(
+ generics.count(),
+ 1, // Only `Self`
+ "Generic traits are not supported yet (b/286941486)",
+ );
+ let substs = [self_ty];
+
+ tcx.infer_ctxt()
+ .build()
+ .type_implements_trait(trait_id, substs, tcx.param_env(trait_id))
+ .must_apply_modulo_regions()
+ };
+ if !does_adt_implement_trait {
+ let trait_name = tcx.item_name(trait_id);
+ bail!("`{self_ty}` doesn't implement the `{trait_name}` trait");
}
- let impl_id =
- impl_id.ok_or_else(|| anyhow!("`{self_ty}` doesn't implement the `{trait_name}` trait"))?;
- let impl_id = impl_id.expect_local(); // Expecting that `self_ty` is a local type.
- match &tcx.hir().expect_item(impl_id).kind {
- ItemKind::Impl(impl_) => Ok(impl_),
- other => panic!("Unexpected `ItemKind` from `non_blanket_impls_for_ty`: {other:?}"),
+
+ let mut method_name_to_cc_thunk_name = HashMap::new();
+ let mut cc_thunk_decls = CcSnippet::default();
+ let mut rs_thunk_impls = quote! {};
+ let methods = tcx
+ .associated_items(trait_id)
+ .in_definition_order()
+ .filter(|item| item.kind == ty::AssocKind::Fn);
+ for method in methods {
+ let substs = {
+ let generics = tcx.generics_of(method.def_id);
+ if generics.params.iter().any(|p| p.kind.is_ty_or_const()) {
+ // Note that lifetime-generic methods are ok:
+ // * they are handled by `format_thunk_decl` and `format_thunk_impl`
+ // * the lifetimes are erased by `ty::Instance::mono` and *seem* to be erased by
+ // `ty::Instance::new`
+ panic!(
+ "So far callers of `format_trait_thunks` didn't need traits with \
+ methods that are type-generic or const-generic"
+ );
+ }
+ assert!(generics.has_self);
+ tcx.mk_substs_trait(self_ty, std::iter::empty())
+ };
+
+ let thunk_name = {
+ let instance = ty::Instance::new(method.def_id, substs);
+ let symbol = tcx.symbol_name(instance);
+ format!("__crubit_thunk_{}", &escape_non_identifier_chars(symbol.name))
+ };
+ method_name_to_cc_thunk_name.insert(method.name, format_cc_ident(&thunk_name)?);
+
+ let sig = tcx.fn_sig(method.def_id).subst(tcx, substs);
+ let sig = liberate_and_deanonymize_late_bound_regions(tcx, sig, method.def_id);
+
+ cc_thunk_decls.add_assign({
+ let thunk_name = format_cc_ident(&thunk_name)?;
+ format_thunk_decl(input, method.def_id, &sig, &thunk_name)?
+ });
+
+ rs_thunk_impls.extend({
+ let fully_qualified_fn_name = {
+ let struct_name = &adt.rs_fully_qualified_name;
+ let fully_qualified_trait_name =
+ FullyQualifiedName::new(tcx, trait_id).format_for_rs();
+ let method_name = make_rs_ident(method.name.as_str());
+ quote! { <#struct_name as #fully_qualified_trait_name>::#method_name }
+ };
+ format_thunk_impl(tcx, method.def_id, &sig, &thunk_name, fully_qualified_fn_name)?
+ });
}
+
+ Ok(TraitThunks { method_name_to_cc_thunk_name, cc_thunk_decls, rs_thunk_impls })
}
/// Formats a default constructor for an ADT if possible (i.e. if the `Default`
@@ -1591,37 +1649,33 @@
/// there is no `Default` impl).
fn format_default_ctor(input: &Input, core: &AdtCoreBindings) -> Result<ApiSnippets> {
let tcx = input.tcx;
- let ty = tcx.type_of(core.def_id).subst_identity();
+ let trait_id =
+ tcx.get_diagnostic_item(sym::Default).expect("`Default` trait should always be present");
+ let TraitThunks { method_name_to_cc_thunk_name, cc_thunk_decls, rs_thunk_impls: rs_details } =
+ format_trait_thunks(input, trait_id, core)?;
- let trait_impl = find_core_trait_impl(input.tcx, ty, sym::Default)?;
- assert_eq!(trait_impl.items.len(), 1, "Only the `default` method is expected");
- assert_eq!(trait_impl.items[0].ident.name.as_str(), "default");
let cc_struct_name = &core.cc_short_name;
let main_api = CcSnippet::new(quote! {
__NEWLINE__ __COMMENT__ "Default::default"
inline #cc_struct_name(); __NEWLINE__ __NEWLINE__
});
- let fn_def_id = trait_impl.items[0].id.owner_id.def_id;
- let sig = get_fn_sig(tcx, fn_def_id);
- let thunk_name = get_thunk_name(get_symbol_name(tcx, fn_def_id)?);
let cc_details = {
- let thunk_name = format_cc_ident(&thunk_name)?;
- let CcSnippet { tokens: thunk_decl, prereqs } =
- format_thunk_decl(input, fn_def_id.to_def_id(), &sig, &thunk_name)?;
+ let thunk_name = method_name_to_cc_thunk_name
+ .into_values()
+ .exactly_one()
+ .expect("Expecting a single `default` method");
+
+ let mut prereqs = CcPrerequisites::default();
+ let cc_thunk_decls = cc_thunk_decls.into_tokens(&mut prereqs);
+
let tokens = quote! {
- #thunk_decl
+ #cc_thunk_decls
#cc_struct_name::#cc_struct_name() {
__crubit_internal::#thunk_name(this);
}
};
CcSnippet { tokens, prereqs }
};
- let rs_details = {
- let struct_name = &core.rs_fully_qualified_name;
- let fully_qualified_fn_name =
- quote! { <#struct_name as ::core::default::Default>::default };
- format_thunk_impl(tcx, fn_def_id.to_def_id(), &sig, &thunk_name, fully_qualified_fn_name)?
- };
Ok(ApiSnippets { main_api, cc_details, rs_details })
}
@@ -1664,7 +1718,12 @@
// TODO(b/259741191): Implement bindings for `Clone::clone` and
// `Clone::clone_from`.
- let _trait_impl = find_core_trait_impl(tcx, ty, sym::Clone)?;
+ let trait_id =
+ tcx.lang_items().clone_trait().ok_or_else(|| anyhow!("Can't find the `Clone` trait"))?;
+ let self_ty = tcx.type_of(core.def_id).subst_identity();
+ tcx.non_blanket_impls_for_ty(trait_id, self_ty)
+ .next()
+ .ok_or_else(|| anyhow!("`{self_ty}` doesn't implement the `Clone` trait"))?;
bail!("Bindings for the `Clone` trait are not supported yet (b/259741191)");
}
diff --git a/cc_bindings_from_rs/cc_bindings_from_rs.rs b/cc_bindings_from_rs/cc_bindings_from_rs.rs
index ab24ce3..700f0bb 100644
--- a/cc_bindings_from_rs/cc_bindings_from_rs.rs
+++ b/cc_bindings_from_rs/cc_bindings_from_rs.rs
@@ -12,12 +12,14 @@
extern crate rustc_errors;
extern crate rustc_feature;
extern crate rustc_hir;
+extern crate rustc_infer;
extern crate rustc_interface;
extern crate rustc_lint_defs;
extern crate rustc_middle;
extern crate rustc_session;
extern crate rustc_span;
extern crate rustc_target;
+extern crate rustc_trait_selection;
extern crate rustc_type_ir;
// TODO(b/254679226): `bindings`, `cmdline`, and `run_compiler` should be