blob: 7aded22017f4631dcd1404d8431bbd4ded3578f4 [file] [log] [blame]
// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#![cfg_attr(not(test), no_std)]
extern crate alloc;
use alloc::borrow::Cow;
use alloc::collections::BTreeSet;
use alloc::format;
use alloc::vec;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, quote_spanned, ToTokens as _};
use syn::parse::Parse;
use syn::spanned::Spanned as _;
use syn::Token;
// TODO(jeanpierreda): derive constructors and assignment for copy and move.
/// Name of the private marker field that `forbid_initialization` appends to structs with named
/// fields. `derive_default` skips a field with this name when generating field initializers.
const FIELD_FOR_MUST_USE_CTOR: &str = "__must_use_ctor_to_initialize";
/// Derives `::ctor::CtorNew<()>` for a struct by default-constructing every field.
///
/// The expansion defines a helper constructor type implementing `::ctor::Ctor`, whose `ctor`
/// method initializes the struct through `::ctor::ctor!`, building each field with
/// `<FieldType as ::ctor::CtorNew<()>>::ctor_new(())`. The struct's `CtorNew<()>` impl then
/// returns that helper type.
///
/// A field named `__must_use_ctor_to_initialize` is skipped rather than initialized.
///
/// Enums and unions are rejected with a compile error.
#[proc_macro_derive(CtorFrom_Default)]
pub fn derive_default(item: TokenStream) -> TokenStream {
    let input = syn::parse_macro_input!(item as syn::DeriveInput);
    let struct_name = input.ident;
    // `Ident::new` rejects strings that contain the raw-identifier marker, so a struct named
    // e.g. `r#type` must contribute only `type` to the generated helper name.
    let struct_ctor_name = Ident::new(
        &format!("_ctor_derive_{}_CtorType_Default", syn::ext::IdentExt::unraw(&struct_name)),
        Span::call_site(),
    );

    // Only structs are supported; bail out with a spanned error for anything else.
    let data = match &input.data {
        syn::Data::Struct(data) => data,
        syn::Data::Enum(e) => {
            return syn::Error::new(e.enum_token.span, "Enums are not supported")
                .into_compile_error()
                .into();
        }
        syn::Data::Union(u) => {
            return syn::Error::new(u.union_token.span, "Unions are not supported")
                .into_compile_error()
                .into();
        }
    };

    // The braced initializer list passed to `ctor!`; empty for unit structs.
    let fields: proc_macro2::TokenStream = match &data.fields {
        syn::Fields::Unit => quote! {},
        _ => {
            let initializers = data.fields.iter().enumerate().filter_map(|(position, field)| {
                // This logic is here in case you derive default on the output of
                // `#[recursively_pinned]`, but it's obviously not very flexible. For example,
                // maybe we want to compute a non-colliding field name, and maybe there are
                // other ordering problems.
                let field_name = match &field.ident {
                    Some(name) if name == FIELD_FOR_MUST_USE_CTOR => return None,
                    Some(name) => quote! {#name},
                    // Tuple-struct fields are addressed by position (`0: ...`).
                    None => {
                        let index = syn::Index::from(position);
                        quote! {#index}
                    }
                };
                let field_type = &field.ty;
                Some(quote_spanned! {field.span() =>
                    #field_name: <#field_type as ::ctor::CtorNew<()>>::ctor_new(())
                })
            });
            quote! {{ #(#initializers),* }}
        }
    };

    let expanded = quote! {
        struct #struct_ctor_name();

        impl ::ctor::Ctor for #struct_ctor_name {
            type Output = #struct_name;
            unsafe fn ctor(self, dest: ::core::pin::Pin<&mut ::core::mem::MaybeUninit<Self::Output>>) {
                ::ctor::ctor!(
                    #struct_name #fields
                ).ctor(dest)
            }
        }

        impl !::core::marker::Unpin for #struct_ctor_name {}

        impl ::ctor::CtorNew<()> for #struct_name {
            type CtorType = #struct_ctor_name;

            fn ctor_new(_args: ()) -> #struct_ctor_name { #struct_ctor_name() }
        }
    };
    TokenStream::from(expanded)
}
/// `project_pin_type!(foo::T)` is the name of the type returned by
/// `foo::T::project_pin()`.
///
/// If `foo::T` is not `#[recursively_pinned]`, then this returns the name it
/// would have used, but is essentially useless.
#[proc_macro]
pub fn project_pin_type(name: TokenStream) -> TokenStream {
let mut name = syn::parse_macro_input!(name as syn::Path);
match name.segments.last_mut() {
None => {
return syn::Error::new(name.span(), "Path must have at least one element")
.into_compile_error()
.into();
}
Some(last) => {
if let syn::PathArguments::Parenthesized(p) = &last.arguments {
return syn::Error::new(
p.span(),
"Parenthesized paths (e.g. fn, Fn) do not have projected equivalents.",
)
.into_compile_error()
.into();
}
last.ident = project_pin_ident(&last.ident);
}
}
TokenStream::from(quote! { #name })
}
fn project_pin_ident(ident: &Ident) -> Ident {
Ident::new(&format!("__CrubitProjectPin{}", ident), Span::call_site())
}
/// Defines the `project_pin` function, and its return value.
///
/// The return value is a copy of the input type, renamed via `project_pin_ident`, in which
/// every field of type `T` is replaced by `::core::pin::Pin<&'proj mut T>`. All attributes on
/// the type and on its fields are dropped from the copy.
///
/// If the input is a union, this returns nothing, and pin-projection is not
/// implemented.
fn project_pin_impl(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
// A type is "fieldless" when no struct field / no enum variant field exists. Such a type
// needs no projection lifetime, because the projected type holds no references.
let is_fieldless = match &input.data {
syn::Data::Struct(data) => data.fields.is_empty(),
syn::Data::Enum(e) => e.variants.iter().all(|variant| variant.fields.is_empty()),
syn::Data::Union(_) => {
return Ok(quote! {});
}
};
// Start from a copy of the input and rewrite it in place into the projected type.
let mut projected = input.clone();
// TODO(jeanpierreda): check attributes for repr(packed)
projected.attrs.clear();
projected.ident = project_pin_ident(&projected.ident);
// The quoted projection lifetime (empty for fieldless types). `add_lifetime` also appends
// the lifetime to the projected type's generic parameters.
let lifetime = if is_fieldless {
quote! {}
} else {
add_lifetime(&mut projected.generics, "'proj")
};
// Rewrites a single field from `T` to `Pin<&'proj mut T>`, dropping its attributes.
let project_field = |field: &mut syn::Field| {
field.attrs.clear();
let field_ty = &field.ty;
let pin_ty = syn::parse_quote!(::core::pin::Pin<& #lifetime mut #field_ty>);
field.ty = syn::Type::Path(pin_ty);
};
// returns the braced parts of a projection pattern and return value.
// e.g. {foo, bar, ..}, {foo: Pin::new_unchecked(foo), bar:
// Pin::new_unchecked(bar)}
let pat_project = |fields: &mut syn::Fields| {
let mut pat = quote! {};
let mut project = quote! {};
for (i, field) in fields.iter_mut().enumerate() {
// TODO(jeanpierreda): check attributes for e.g. #[unpin]
field.attrs.clear();
// `lhs` is the field selector (name or tuple index); `rhs` is the local binding
// that the pattern introduces for that field.
let lhs;
let rhs;
if let Some(ident) = &field.ident {
lhs = quote! {#ident};
rhs = ident.clone();
pat.extend(quote! {#lhs,});
} else {
// Unnamed fields are matched with brace syntax (`0: item_0`).
lhs = proc_macro2::Literal::usize_unsuffixed(i).into_token_stream();
rhs = Ident::new(&format!("item_{i}"), Span::call_site());
pat.extend(quote! {#lhs: #rhs,});
}
project.extend(quote! {#lhs: ::core::pin::Pin::new_unchecked(#rhs),});
}
// Also ignore the __must_use_ctor_to_initialize field, if present.
pat.extend(quote! {..});
(quote! {{#pat}}, quote! {{#project}})
};
// The body of `project_pin`: destructures `from` and rebuilds it as the projected type.
let project_body;
let input_ident = &input.ident;
let projected_ident = &projected.ident;
match &mut projected.data {
syn::Data::Struct(data) => {
for field in &mut data.fields {
project_field(field);
}
let (pat, project) = pat_project(&mut data.fields);
project_body = quote! {
let #input_ident #pat = from;
#projected_ident #project
};
}
syn::Data::Enum(e) => {
// One match arm per variant, mapping the input variant onto the projected variant
// of the same name.
let mut match_body = quote! {};
for variant in &mut e.variants {
for field in &mut variant.fields {
project_field(field);
}
let (pat, project) = pat_project(&mut variant.fields);
let variant_ident = &variant.ident;
match_body.extend(quote! {
#input_ident::#variant_ident #pat => #projected_ident::#variant_ident #project,
});
}
project_body = quote! {
match from {
#match_body
}
};
}
syn::Data::Union(_) => {
unreachable!("project_pin_impl should early return when it finds a union")
}
}
// The impl block uses the input's generics; only the return type of `project_pin` uses the
// projected generics (which include the extra lifetime, if one was added).
let (input_impl_generics, input_ty_generics, input_where_clause) =
input.generics.split_for_impl();
let (_, projected_generics, _) = projected.generics.split_for_impl();
Ok(quote! {
#projected
impl #input_impl_generics #input_ident #input_ty_generics #input_where_clause {
#[must_use]
pub fn project_pin<#lifetime>(self: ::core::pin::Pin<& #lifetime mut Self>) -> #projected_ident #projected_generics {
unsafe {
let from = ::core::pin::Pin::into_inner_unchecked(self);
#project_body
}
}
}
})
}
/// Appends a fresh lifetime parameter to `generics` and returns its quoted name.
///
/// The first candidate is `prefix` itself. If a lifetime with that name is already declared in
/// `generics`, the candidates `{prefix}_2`, `{prefix}_3`, ... are tried in order until an unused
/// name is found.
fn add_lifetime(generics: &mut syn::Generics, prefix: &str) -> proc_macro2::TokenStream {
    let in_use: BTreeSet<&syn::Lifetime> =
        generics.lifetimes().map(|def| &def.lifetime).collect();

    let mut candidate_name = Cow::Borrowed(prefix);
    let mut attempt = 1;
    let fresh = loop {
        let candidate = syn::Lifetime::new(&candidate_name, Span::call_site());
        if in_use.contains(&candidate) {
            // Name is taken: move on to the next numbered candidate.
            attempt += 1;
            candidate_name = Cow::Owned(format!("{prefix}_{attempt}"));
        } else {
            break candidate;
        }
    };

    // Quote the lifetime before it is moved into the generics list.
    let quoted = quote! {#fresh};
    let param = syn::GenericParam::Lifetime(syn::LifetimeDef::new(fresh));
    generics.params.push(param);
    quoted
}
/// The parsed arguments of `#[recursively_pinned(...)]`.
#[derive(Default)]
struct RecursivelyPinnedArgs {
/// True when `PinnedDrop` was passed as the argument. In that case the generated code
/// implements `Drop` by forwarding to `::ctor::PinnedDrop::pinned_drop`.
is_pinned_drop: bool,
}
impl Parse for RecursivelyPinnedArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let args = <syn::punctuated::Punctuated<Ident, Token![,]>>::parse_terminated(input)?;
if args.len() > 1 {
return Err(syn::Error::new(
input.span(), // not args.span(), as that is only for the first argument.
&format!("expected at most 1 argument, got: {}", args.len()),
));
}
let is_pinned_drop = if let Some(arg) = args.first() {
if arg != "PinnedDrop" {
return Err(syn::Error::new(
arg.span(),
"unexpected argument (wasn't `PinnedDrop`)",
));
}
true
} else {
false
};
Ok(RecursivelyPinnedArgs { is_pinned_drop })
}
}
/// Makes it impossible for safe code outside of this crate to create the type
/// directly.
///
/// The mechanism depends on the shape of the type:
///
/// * Enums: every variant is marked `#[non_exhaustive]`. The resulting error
///   messages are unfortunate, but there is no other way to prevent creation
///   of an enum variant at this time.
/// * Unit structs and tuple structs: the struct is marked `#[non_exhaustive]`.
///   For unit structs there is no alternative. For tuple structs it is no
///   worse than adding a private field: both lead to indirect error messages,
///   but `#[non_exhaustive]` is the more likely of the two to ever get custom
///   error message support.
/// * Structs with named fields: `#[non_exhaustive]` *cannot* be used, because
///   it would make the struct not FFI-safe, and structs with named fields are
///   specifically supported for C++ interop. Instead, a private field whose
///   name indicates the error (`__must_use_ctor_to_initialize`) is appended.
/// * Unions: not yet implemented properly; the union is left unchanged.
///
/// ---
///
/// Note that the use of `#[non_exhaustive]` also has other effects. At the
/// least: tuple variants and tuple structs marked with `#[non_exhaustive]`
/// cannot be pattern matched using the "normal" syntax. Instead, one must use
/// curly braces. (Broken: `T(x, ..)`; working: `T{0: x, ..}`).
///
/// (This does not seem very intentional, and with all luck will be fixed before
/// too long.)
fn forbid_initialization(s: &mut syn::DeriveInput) {
    let non_exhaustive: syn::Attribute = syn::parse_quote!(#[non_exhaustive]);
    match &mut s.data {
        // TODO(b/232969667): prevent creation of unions from safe code.
        // (E.g. hide inside a struct.)
        syn::Data::Union(_) => {}
        syn::Data::Enum(e) => {
            // Enums can't have private fields. Instead, we need to add #[non_exhaustive] to
            // every variant -- this makes it impossible to construct the
            // variants.
            for variant in e.variants.iter_mut() {
                variant.attrs.insert(0, non_exhaustive.clone());
            }
        }
        syn::Data::Struct(data) => {
            if let syn::Fields::Named(fields) = &mut data.fields {
                // TODO(jeanpierreda): better hygiene: work even if a field has the same name.
                let marker_name = Ident::new(FIELD_FOR_MUST_USE_CTOR, Span::call_site());
                let marker_field = syn::Field {
                    attrs: vec![],
                    vis: syn::Visibility::Inherited,
                    ident: Some(marker_name),
                    colon_token: Some(<syn::Token![:]>::default()),
                    ty: syn::parse_quote!([u8; 0]),
                };
                fields.named.push(marker_field);
            } else {
                // Unit and tuple structs.
                s.attrs.insert(0, non_exhaustive);
            }
        }
    }
}
/// `#[recursively_pinned]` pins every field, similar to `#[pin_project]`, and
/// marks the struct `!Unpin`.
///
/// Example:
///
/// ```
/// #[recursively_pinned]
/// struct S {
/// field: i32,
/// }
/// ```
///
/// This is analogous to using pin_project, pinning every field, as so:
///
/// ```
/// #[pin_project(!Unpin)]
/// struct S {
/// #[pin]
/// field: i32,
/// }
/// ```
///
/// ## Arguments
///
/// ### `PinnedDrop`
///
/// To define a destructor for a recursively-pinned struct, pass `PinnedDrop`
/// and implement the `PinnedDrop` trait.
///
/// `#[recursively_pinned]` prohibits implementing `Drop`, as that would make it
/// easy to violate the `Pin` guarantee. Instead, to define a destructor, one
/// must define a `PinnedDrop` impl, as so:
///
/// ```
/// #[recursively_pinned(PinnedDrop)]
/// struct S {
/// field: i32,
/// }
///
/// impl PinnedDrop for S {
/// unsafe fn pinned_drop(self: Pin<&mut Self>) {
/// println!("I am being destroyed!");
/// }
/// }
/// ```
///
/// (This is analogous to `#[pin_project(PinnedDrop)]`.)
///
/// ## Direct initialization
///
/// Use the `ctor!` macro to instantiate recursively pinned types. For example:
///
/// ```
/// // equivalent to `let x = Point {x: 3, y: 4}`, but uses pinned construction.
/// emplace! {
/// let x = ctor!(Point {x: 3, y: 4});
/// }
/// ```
///
/// Recursively pinned types cannot be created directly in safe code, as they
/// are pinned from the very moment of their creation.
///
/// This is prevented either using `#[non_exhaustive]` or using a private field,
/// depending on the type in question. For example, enums use
/// `#[non_exhaustive]`, and structs with named fields use a private field named
/// `__must_use_ctor_to_initialize`. This can lead to confusing error messages,
/// so watch out!
///
/// ## Supported types
///
/// Structs, enums, and unions are all supported. However, unions do not receive
/// a `pin_project` method, as there is no way to implement pin projection for
/// unions. (One cannot know which field is active.)
#[proc_macro_attribute]
pub fn recursively_pinned(args: TokenStream, item: TokenStream) -> TokenStream {
match recursively_pinned_impl(args.into(), item.into()) {
Ok(t) => t.into(),
Err(e) => e.into_compile_error().into(),
}
}
/// Implementation of `#[recursively_pinned]` on `proc_macro2` token streams.
///
/// A separate function for calling from tests.
///
/// See e.g. https://users.rust-lang.org/t/procedural-macro-api-is-used-outside-of-a-procedural-macro/30841
///
/// The output consists of:
///
/// * the input type, modified by `forbid_initialization`;
/// * the pin-projection type and `project_pin` method from `project_pin_impl`;
/// * either a `Drop` impl forwarding to `::ctor::PinnedDrop::pinned_drop` (when `PinnedDrop`
///   was passed), or a `DoNotImplDrop` marker impl plus a no-op `PinnedDrop` impl;
/// * an `impl !Unpin` for the type;
/// * an anonymous-const scope holding an unmodified copy of the type, named
///   `__CrubitCtor<Name>`, which is exposed as `RecursivelyPinned::CtorInitializedFields`.
///
/// Returns an error if `args` or `item` fail to parse.
fn recursively_pinned_impl(
    args: proc_macro2::TokenStream,
    item: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
    let args = syn::parse2::<RecursivelyPinnedArgs>(args)?;
    let mut input = syn::parse2::<syn::DeriveInput>(item)?;

    // Must run before `forbid_initialization`, so that the projection is derived from the
    // type as the user wrote it.
    let project_pin_impl = project_pin_impl(&input)?;
    let name = input.ident.clone();

    // Create two copies of input: one (public) has a private field that can't be
    // instantiated. The other (only visible via
    // RecursivelyPinned::CtorInitializedFields) doesn't have this field.
    // This causes `ctor!(Foo {})` to work, but `Foo{}` to complain of a missing
    // field.
    let mut ctor_initialized_input = input.clone();
    // Removing repr(C) triggers dead-code detection.
    ctor_initialized_input.attrs = vec![syn::parse_quote!(#[allow(dead_code)])];
    // TODO(jeanpierreda): This should really check for name collisions with any types
    // used in the fields. Collisions with other names don't matter, because the
    // type is locally defined within a narrow scope.
    //
    // `Ident::new` panics if the string contains the raw-identifier marker, so strip `r#`
    // from the type name before embedding it in the generated name.
    let unraw_name = syn::ext::IdentExt::unraw(&name);
    ctor_initialized_input.ident =
        syn::Ident::new(&format!("__CrubitCtor{unraw_name}"), name.span());
    let ctor_initialized_name = &ctor_initialized_input.ident;

    forbid_initialization(&mut input);

    let (input_impl_generics, input_ty_generics, input_where_clause) =
        input.generics.split_for_impl();

    let drop_impl = if args.is_pinned_drop {
        quote! {
            impl #input_impl_generics Drop for #name #input_ty_generics #input_where_clause {
                fn drop(&mut self) {
                    unsafe {::ctor::PinnedDrop::pinned_drop(::core::pin::Pin::new_unchecked(self))}
                }
            }
        }
    } else {
        quote! {
            impl #input_impl_generics ::ctor::macro_internal::DoNotImplDrop for #name #input_ty_generics #input_where_clause {}
            /// A no-op PinnedDrop that will cause an error if the user also defines PinnedDrop,
            /// due to forgetting to pass `PinnedDrop` to `#[recursively_pinned(PinnedDrop)]`.
            impl #input_impl_generics ::ctor::PinnedDrop for #name #input_ty_generics #input_where_clause {
                unsafe fn pinned_drop(self: ::core::pin::Pin<&mut Self>) {}
            }
        }
    };

    Ok(quote! {
        #input
        #project_pin_impl

        #drop_impl
        impl #input_impl_generics !Unpin for #name #input_ty_generics #input_where_clause {}

        // Introduce a new scope to limit the blast radius of the CtorInitializedFields type.
        // This lets us use relatively readable names: while the impl is visible outside the scope,
        // type is otherwise not visible.
        const _ : () = {
            #ctor_initialized_input

            unsafe impl #input_impl_generics ::ctor::RecursivelyPinned for #name #input_ty_generics #input_where_clause {
                type CtorInitializedFields = #ctor_initialized_name #input_ty_generics;
            }
        };
    })
}
// Unit tests for `recursively_pinned_impl`. They call the `proc_macro2`-based implementation
// directly and match fragments of the generated token stream.
#[cfg(test)]
mod test {
use super::*;
use token_stream_matchers::assert_rs_matches;
/// Essentially a change detector, but handy for debugging.
///
/// At time of writing, we can't write negative compilation tests, so
/// asserting on the output is as close as we can get. Once negative
/// compilation tests are added, it would be better to test various
/// safety features that way.
#[test]
fn test_recursively_pinned_struct() {
// Empty macro arguments: `PinnedDrop` is not requested.
let definition =
recursively_pinned_impl(quote! {}, quote! {#[repr(C)] struct S {x: i32}}).unwrap();
// The struct can't be directly created, but can be created via
// CtorInitializedFields:
assert_rs_matches!(
definition,
quote! {
#[repr(C)]
struct S {
x: i32,
__must_use_ctor_to_initialize: [u8; 0]
}
}
);
// The hidden copy keeps the original fields, without the marker field and with the
// original attributes replaced by `#[allow(dead_code)]`.
assert_rs_matches!(
definition,
quote! {
const _: () = {
#[allow(dead_code)]
struct __CrubitCtorS {x: i32}
unsafe impl ::ctor::RecursivelyPinned for S {
type CtorInitializedFields = __CrubitCtorS;
}
};
}
);
// The type is non-Unpin:
assert_rs_matches!(
definition,
quote! {
impl !Unpin for S {}
}
);
// The remaining features of the generated output are better tested via
// real tests that exercise the code.
}
/// The enum version of `test_recursively_pinned_struct`.
#[test]
fn test_recursively_pinned_enum() {
// Empty macro arguments: `PinnedDrop` is not requested.
let definition = recursively_pinned_impl(
quote! {},
quote! {
#[repr(C)]
enum E {
A,
B(i32),
}
},
)
.unwrap();
// The enum variants can't be directly created, but can be created via
// CtorInitializedFields:
assert_rs_matches!(
definition,
quote! {
#[repr(C)]
enum E {
#[non_exhaustive]
A,
#[non_exhaustive]
B(i32),
}
}
);
// The hidden copy keeps the original variants, without `#[non_exhaustive]`.
assert_rs_matches!(
definition,
quote! {
const _: () = {
#[allow(dead_code)]
enum __CrubitCtorE {
A,
B(i32),
}
unsafe impl ::ctor::RecursivelyPinned for E {
type CtorInitializedFields = __CrubitCtorE;
}
};
}
);
// The type is non-Unpin:
assert_rs_matches!(
definition,
quote! {
impl !Unpin for E {}
}
);
// The remaining features of the generated output are better tested via
// real tests that exercise the code.
}
}