Support `#[recursively_pinned]` generics.
This is mostly only important for its interaction with lifetimes. We don't expect to support generic types in C++ interop, not as Rust generics -- but we do expect, I think, to support C++ classes with lifetime parameters, which would become generic parameters.
I had put off implementing it because it introduces the possibility of name collision, but now is as good a time as any to implement it. (It will help fix up tests in one of the other CLs I'm working on.)
I'm actually really impressed by the thoroughness of `syn` here. Even if it doesn't automatically give me hygiene, it does have all the facilities I need. It looks ugly, but nonetheless I am impressed by e.g. https://docs.rs/syn/latest/syn/struct.Generics.html#method.split_for_impl .
---
P.S. I have *no idea* how to implement `add_lifetime()` in a nice, copy-free, readable way. I know how I'd do it in Python, but...
This CL contains my best attempt (using `Cow`). Alternate idea was to move the logic into a function, something like:
```rs
fn add_lifetime(generics: &mut syn::Generics, prefix: &str) -> proc_macro2::TokenStream {
let taken_lifetimes: HashSet<&syn::Lifetime> = generics.lifetimes().map(|def| &def.lifetime).collect();
let try_lifetime = |name: &str| -> Option<syn::Lifetime> {
let lifetime = syn::Lifetime::new(name, Span::call_site());
if taken_lifetimes.contains(&lifetime) { None } else { Some(lifetime) }
};
let lifetime = try_lifetime(prefix).unwrap_or_else(|| {
(2..).map(|n| format!("{prefix}_{n}")).find_map(|s| try_lifetime(&s)).unwrap()
});
let quoted_lifetime = quote! {#lifetime};
generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeDef::new(lifetime)));
quoted_lifetime
}
```
But this is repellant to me, tbh.
PiperOrigin-RevId: 454042762
diff --git a/rs_bindings_from_cc/support/ctor_proc_macros.rs b/rs_bindings_from_cc/support/ctor_proc_macros.rs
index 375825c..bc5447e 100644
--- a/rs_bindings_from_cc/support/ctor_proc_macros.rs
+++ b/rs_bindings_from_cc/support/ctor_proc_macros.rs
@@ -5,6 +5,8 @@
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, quote_spanned, ToTokens as _};
+use std::borrow::Cow;
+use std::collections::HashSet;
use syn::parse::Parse;
use syn::spanned::Spanned as _;
use syn::Token;
@@ -110,8 +112,8 @@
///
/// If the input is a union, this returns nothing, and pin-projection is not
/// implemented.
-fn project_pin_impl(s: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
- let is_fieldless = match &s.data {
+fn project_pin_impl(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
+ let is_fieldless = match &input.data {
syn::Data::Struct(data) => data.fields.is_empty(),
syn::Data::Enum(e) => e.variants.iter().all(|variant| variant.fields.is_empty()),
syn::Data::Union(_) => {
@@ -119,29 +121,16 @@
}
};
- let mut projected = s.clone();
+ let mut projected = input.clone();
// TODO(jeanpierreda): check attributes for repr(packed)
projected.attrs.clear();
projected.ident = project_pin_ident(&projected.ident);
- if projected.generics.params.len() != 0 {
- return Err(syn::Error::new(
- projected.generics.span(),
- "projection is currently not implemented for generic structs",
- ));
- }
-
- let lifetime;
- if is_fieldless {
- lifetime = quote! {};
+ let lifetime = if is_fieldless {
+ quote! {}
} else {
- let syn_lifetime = syn::Lifetime::new("'proj", Span::call_site());
- projected
- .generics
- .params
- .push(syn::GenericParam::Lifetime(syn::LifetimeDef::new(syn_lifetime.clone())));
- lifetime = quote! {#syn_lifetime};
- }
+ add_lifetime(&mut projected.generics, "'proj")
+ };
let project_field = |field: &mut syn::Field| {
field.attrs.clear();
@@ -173,7 +162,7 @@
(quote! {{#pat}}, quote! {{#project}})
};
let project_body;
- let original_ident = &s.ident;
+ let input_ident = &input.ident;
let projected_ident = &projected.ident;
match &mut projected.data {
syn::Data::Struct(data) => {
@@ -182,7 +171,7 @@
}
let (pat, project) = pat_project(&mut data.fields);
project_body = quote! {
- let #original_ident #pat = from;
+ let #input_ident #pat = from;
#projected_ident #project
};
}
@@ -195,7 +184,7 @@
let (pat, project) = pat_project(&mut variant.fields);
let variant_ident = &variant.ident;
match_body.extend(quote! {
- #original_ident::#variant_ident #pat => #projected_ident::#variant_ident #project,
+ #input_ident::#variant_ident #pat => #projected_ident::#variant_ident #project,
});
}
project_body = quote! {
@@ -204,15 +193,21 @@
}
};
}
- syn::Data::Union(_) => unreachable!("project_pin_impl should early return when it finds a union"),
+ syn::Data::Union(_) => {
+ unreachable!("project_pin_impl should early return when it finds a union")
+ }
}
+ let (input_impl_generics, input_ty_generics, input_where_clause) =
+ input.generics.split_for_impl();
+ let (_, projected_generics, _) = projected.generics.split_for_impl();
+
Ok(quote! {
#projected
- impl #original_ident {
+ impl #input_impl_generics #input_ident #input_ty_generics #input_where_clause {
#[must_use]
- pub fn project_pin(self: ::std::pin::Pin<&mut Self>) -> #projected_ident {
+ pub fn project_pin<#lifetime>(self: ::std::pin::Pin<& #lifetime mut Self>) -> #projected_ident #projected_generics {
unsafe {
let from = ::std::pin::Pin::into_inner_unchecked(self);
#project_body
@@ -222,6 +217,26 @@
})
}
+/// Adds a new lifetime to `generics`, returning the quoted lifetime name.
+fn add_lifetime(generics: &mut syn::Generics, prefix: &str) -> proc_macro2::TokenStream {
+ let taken_lifetimes: HashSet<&syn::Lifetime> =
+ generics.lifetimes().map(|def| &def.lifetime).collect();
+ let mut name = Cow::Borrowed(prefix);
+ let mut i = 1;
+ let lifetime = loop {
+ let lifetime = syn::Lifetime::new(&name, Span::call_site());
+ if !taken_lifetimes.contains(&lifetime) {
+ break lifetime;
+ }
+
+ i += 1;
+ name = Cow::Owned(format!("{prefix}_{i}"));
+ };
+ let quoted_lifetime = quote! {#lifetime};
+ generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeDef::new(lifetime)));
+ quoted_lifetime
+}
+
#[derive(Default)]
struct RecursivelyPinnedArgs {
is_pinned_drop: bool,
@@ -316,9 +331,12 @@
let name = input.ident.clone();
+ let (input_impl_generics, input_ty_generics, input_where_clause) =
+ input.generics.split_for_impl();
+
let drop_impl = if args.is_pinned_drop {
quote! {
- impl Drop for #name {
+ impl #input_impl_generics Drop for #name #input_ty_generics #input_where_clause {
fn drop(&mut self) {
unsafe {::ctor::PinnedDrop::pinned_drop(::std::pin::Pin::new_unchecked(self))}
}
@@ -326,10 +344,10 @@
}
} else {
quote! {
- impl ::ctor::macro_internal::DoNotImplDrop for #name {}
+ impl #input_impl_generics ::ctor::macro_internal::DoNotImplDrop for #name #input_ty_generics #input_where_clause {}
/// A no-op PinnedDrop that will cause an error if the user also defines PinnedDrop,
/// due to forgetting to pass `PinnedDrop` to #[recursively_pinned(PinnedDrop)]`.
- impl ::ctor::PinnedDrop for #name {
+ impl #input_impl_generics ::ctor::PinnedDrop for #name #input_ty_generics #input_where_clause {
unsafe fn pinned_drop(self: ::std::pin::Pin<&mut Self>) {}
}
}
@@ -341,8 +359,8 @@
#drop_impl
- unsafe impl ::ctor::RecursivelyPinned for #name {}
- impl !Unpin for #name {}
+ unsafe impl #input_impl_generics ::ctor::RecursivelyPinned for #name #input_ty_generics #input_where_clause {}
+ impl #input_impl_generics !Unpin for #name #input_ty_generics #input_where_clause {}
};
TokenStream::from(expanded)
diff --git a/rs_bindings_from_cc/support/ctor_proc_macros_test.rs b/rs_bindings_from_cc/support/ctor_proc_macros_test.rs
index bb19d01..ba66b1b 100644
--- a/rs_bindings_from_cc/support/ctor_proc_macros_test.rs
+++ b/rs_bindings_from_cc/support/ctor_proc_macros_test.rs
@@ -143,6 +143,22 @@
}
#[test]
+fn test_recursively_pinned_generic() {
+ #[::ctor::recursively_pinned]
+ struct S<'proj, 'proj_2: 'proj, 'proj_4, T>
+ where
+ 'proj_4: 'proj_2,
+ {
+ x: T,
+ /// 'proj* are not really used, but exist to try to throw a wrench in
+ /// the works.
+ _phantom: ::std::marker::PhantomData<&'proj &'proj_2 &'proj_4 T>,
+ }
+ let _: ::std::pin::Pin<&mut i32> =
+ Box::pin(S::<i32> { x: 42, _phantom: ::std::marker::PhantomData }).as_mut().project_pin().x;
+}
+
+#[test]
fn test_recursively_pinned_struct_derive_default() {
#[::ctor::recursively_pinned]
#[derive(::ctor::CtorFrom_Default)]