// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#![cfg_attr(not(test), no_std)]

extern crate alloc;

use alloc::borrow::Cow;
use alloc::collections::BTreeSet;
use alloc::format;
use alloc::vec;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, quote_spanned, ToTokens as _};
use syn::parse::Parse;
use syn::spanned::Spanned as _;
use syn::Token;

// TODO(jeanpierreda): derive constructors and assignment for copy and move.

const FIELD_FOR_MUST_USE_CTOR: &'static str = "__must_use_ctor_to_initialize";

#[proc_macro_derive(CtorFrom_Default)]
pub fn derive_default(item: TokenStream) -> TokenStream {
    let input = syn::parse_macro_input!(item as syn::DeriveInput);

    let struct_name = input.ident;
    let struct_ctor_name =
        Ident::new(&format!("_ctor_derive_{}_CtorType_Default", struct_name), Span::call_site());
    let fields: proc_macro2::TokenStream = match &input.data {
        syn::Data::Struct(data) => {
            if let syn::Fields::Unit = data.fields {
                quote! {}
            } else {
                let filled_fields = data.fields.iter().enumerate().filter_map(|(i, field)| {
                    let field_i = syn::Index::from(i);
                    let field_name;
                    // This logic is here in case you derive default on the output of
                    // `#[recursively_pinned]`, but it's obviously not very flexible. For example,
                    // maybe we want to compute a non-colliding field name, and maybe there are
                    // other ordering problems.
                    match &field.ident {
                        Some(name) if name == FIELD_FOR_MUST_USE_CTOR => return None,
                        Some(name) => field_name = quote! {#name},
                        None => field_name = quote! {#field_i},
                    };

                    let field_type = &field.ty;
                    Some(quote_spanned! {field.span() =>
                        #field_name: <#field_type as ::ctor::CtorNew<()>>::ctor_new(())
                    })
                });
                quote! {{ #(#filled_fields),* }}
            }
        }
        syn::Data::Enum(e) => {
            return syn::Error::new(e.enum_token.span, "Enums are not supported")
                .into_compile_error()
                .into();
        }
        syn::Data::Union(u) => {
            return syn::Error::new(u.union_token.span, "Unions are not supported")
                .into_compile_error()
                .into();
        }
    };

    let expanded = quote! {
        struct #struct_ctor_name();

        impl ::ctor::Ctor for #struct_ctor_name {
            type Output = #struct_name;
            unsafe fn ctor(self, dest: ::core::pin::Pin<&mut ::core::mem::MaybeUninit<Self::Output>>) {
                ::ctor::ctor!(
                    #struct_name #fields
                ).ctor(dest)
            }
        }

        impl !::core::marker::Unpin for #struct_ctor_name {}

        impl ::ctor::CtorNew<()> for #struct_name {
            type CtorType = #struct_ctor_name;

            fn ctor_new(_args: ()) -> #struct_ctor_name { #struct_ctor_name() }
        }
    };
    TokenStream::from(expanded)
}

/// `project_pin_type!(foo::T)` is the name of the type returned by
/// `foo::T::project_pin()`.
///
/// If `foo::T` is not `#[recursively_pinned]`, then this returns the name it
/// would have used, but is essentially useless.
#[proc_macro]
pub fn project_pin_type(name: TokenStream) -> TokenStream {
    let mut name = syn::parse_macro_input!(name as syn::Path);
    match name.segments.last_mut() {
        None => {
            return syn::Error::new(name.span(), "Path must have at least one element")
                .into_compile_error()
                .into();
        }
        Some(last) => {
            if let syn::PathArguments::Parenthesized(p) = &last.arguments {
                return syn::Error::new(
                    p.span(),
                    "Parenthesized paths (e.g. fn, Fn) do not have projected equivalents.",
                )
                .into_compile_error()
                .into();
            }
            last.ident = project_pin_ident(&last.ident);
        }
    }
    TokenStream::from(quote! { #name })
}

fn project_pin_ident(ident: &Ident) -> Ident {
    Ident::new(&format!("__CrubitProjectPin{}", ident), Span::call_site())
}

/// Defines the `project_pin` function, and its return value.
///
/// If the input is a union, this returns nothing, and pin-projection is not
/// implemented.
fn project_pin_impl(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let is_fieldless = match &input.data {
        syn::Data::Struct(data) => data.fields.is_empty(),
        syn::Data::Enum(e) => e.variants.iter().all(|variant| variant.fields.is_empty()),
        syn::Data::Union(_) => {
            return Ok(quote! {});
        }
    };

    let mut projected = input.clone();
    // TODO(jeanpierreda): check attributes for repr(packed)
    projected.attrs.clear();
    projected.ident = project_pin_ident(&projected.ident);

    let lifetime = if is_fieldless {
        quote! {}
    } else {
        add_lifetime(&mut projected.generics, "'proj")
    };

    let project_field = |field: &mut syn::Field| {
        field.attrs.clear();
        let field_ty = &field.ty;
        let pin_ty = syn::parse_quote!(::core::pin::Pin<& #lifetime mut #field_ty>);
        field.ty = syn::Type::Path(pin_ty);
    };
    // returns the braced parts of a projection pattern and return value.
    // e.g. {foo, bar, ..}, {foo: Pin::new_unchecked(foo), bar:
    // Pin::new_unchecked(bar)}
    let pat_project = |fields: &mut syn::Fields| {
        let mut pat = quote! {};
        let mut project = quote! {};
        for (i, field) in fields.iter_mut().enumerate() {
            // TODO(jeanpierreda): check attributes for e.g. #[unpin]
            field.attrs.clear();
            let lhs;
            let rhs;
            if let Some(ident) = &field.ident {
                lhs = quote! {#ident};
                rhs = ident.clone();
                pat.extend(quote! {#lhs,});
            } else {
                lhs = proc_macro2::Literal::usize_unsuffixed(i).into_token_stream();
                rhs = Ident::new(&format!("item_{i}"), Span::call_site());
                pat.extend(quote! {#lhs: #rhs,});
            }
            project.extend(quote! {#lhs: ::core::pin::Pin::new_unchecked(#rhs),});
        }
        // Also ignore the __must_use_ctor_to_initialize field, if present.
        pat.extend(quote! {..});
        (quote! {{#pat}}, quote! {{#project}})
    };
    let project_body;
    let input_ident = &input.ident;
    let projected_ident = &projected.ident;
    match &mut projected.data {
        syn::Data::Struct(data) => {
            for field in &mut data.fields {
                project_field(field);
            }
            let (pat, project) = pat_project(&mut data.fields);
            project_body = quote! {
                let #input_ident #pat = from;
                #projected_ident #project
            };
        }
        syn::Data::Enum(e) => {
            let mut match_body = quote! {};
            for variant in &mut e.variants {
                for field in &mut variant.fields {
                    project_field(field);
                }
                let (pat, project) = pat_project(&mut variant.fields);
                let variant_ident = &variant.ident;
                match_body.extend(quote! {
                    #input_ident::#variant_ident #pat => #projected_ident::#variant_ident #project,
                });
            }
            project_body = quote! {
                match from  {
                    #match_body
                }
            };
        }
        syn::Data::Union(_) => {
            unreachable!("project_pin_impl should early return when it finds a union")
        }
    }

    let (input_impl_generics, input_ty_generics, input_where_clause) =
        input.generics.split_for_impl();
    let (_, projected_generics, _) = projected.generics.split_for_impl();

    Ok(quote! {
        #projected

        impl #input_impl_generics #input_ident #input_ty_generics #input_where_clause {
            #[must_use]
            pub fn project_pin<#lifetime>(self: ::core::pin::Pin<& #lifetime mut Self>) -> #projected_ident #projected_generics {
                unsafe {
                    let from = ::core::pin::Pin::into_inner_unchecked(self);
                    #project_body
                }
            }
        }
    })
}

/// Adds a new lifetime to `generics`, returning the quoted lifetime name.
fn add_lifetime(generics: &mut syn::Generics, prefix: &str) -> proc_macro2::TokenStream {
    let taken_lifetimes: BTreeSet<&syn::Lifetime> =
        generics.lifetimes().map(|def| &def.lifetime).collect();
    let mut name = Cow::Borrowed(prefix);
    let mut i = 1;
    let lifetime = loop {
        let lifetime = syn::Lifetime::new(&name, Span::call_site());
        if !taken_lifetimes.contains(&lifetime) {
            break lifetime;
        }

        i += 1;
        name = Cow::Owned(format!("{prefix}_{i}"));
    };
    let quoted_lifetime = quote! {#lifetime};
    generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeDef::new(lifetime)));
    quoted_lifetime
}

#[derive(Default)]
struct RecursivelyPinnedArgs {
    is_pinned_drop: bool,
}

impl Parse for RecursivelyPinnedArgs {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let args = <syn::punctuated::Punctuated<Ident, Token![,]>>::parse_terminated(input)?;
        if args.len() > 1 {
            return Err(syn::Error::new(
                input.span(), // not args.span(), as that is only for the first argument.
                &format!("expected at most 1 argument, got: {}", args.len()),
            ));
        }
        let is_pinned_drop = if let Some(arg) = args.first() {
            if arg != "PinnedDrop" {
                return Err(syn::Error::new(
                    arg.span(),
                    "unexpected argument (wasn't `PinnedDrop`)",
                ));
            }
            true
        } else {
            false
        };
        Ok(RecursivelyPinnedArgs { is_pinned_drop })
    }
}

/// Prevents this type from being directly created outside of this crate in safe
/// code.
///
/// For enums and unit structs, this uses the `#[non_exhaustive]` attribute.
/// This leads to unfortunate error messages, but there is no other way to
/// prevent creation of an enum or a unit struct at this time.
///
/// For tuple structs, we also use `#[non_exhaustive]`, as it's no worse than
/// the alternative. Both adding a private field and adding `#[non_exhaustive]`
/// lead to indirect error messages, but `#[non_exhaustive]` is the more likely
/// of the two to ever get custom error message support.
///
/// Finally, for structs with named fields, we actually *cannot* use
/// `#[non_exhaustive]`, because it would make the struct not FFI-safe, and
/// structs with named fields are specifically supported for C++ interop.
/// Instead, we use a private field with a name that indicates the error.
/// (`__must_use_ctor_to_initialize`).
///
/// Unions are not yet implemented properly.
///
/// ---
///
/// Note that the use of `#[non_exhaustive]` also has other effects. At the
/// least: tuple variants and tuple structs marked with `#[non_exhaustive]`
/// cannot be pattern matched using the "normal" syntax. Instead, one must use
/// curly braces. (Broken: `T(x, ..)`; woken: `T{0: x, ..}`).
///
/// (This does not seem very intentional, and with all luck will be fixed before
/// too long.)
fn forbid_initialization(s: &mut syn::DeriveInput) {
    let non_exhaustive_attr = syn::parse_quote!(#[non_exhaustive]);
    match &mut s.data {
        // TODO(b/232969667): prevent creation of unions from safe code.
        // (E.g. hide inside a struct.)
        syn::Data::Union(_) => return,
        syn::Data::Struct(data) => {
            match &mut data.fields {
                syn::Fields::Unit | syn::Fields::Unnamed(_) => {
                    s.attrs.insert(0, non_exhaustive_attr);
                }
                syn::Fields::Named(fields) => {
                    fields.named.push(syn::Field {
                        attrs: vec![],
                        vis: syn::Visibility::Inherited,
                        // TODO(jeanpierreda): better hygiene: work even if a field has the same name.
                        ident: Some(Ident::new(FIELD_FOR_MUST_USE_CTOR, Span::call_site())),
                        colon_token: Some(<syn::Token![:]>::default()),
                        ty: syn::parse_quote!([u8; 0]),
                    });
                }
            }
        }
        syn::Data::Enum(e) => {
            // Enums can't have private fields. Instead, we need to add #[non_exhaustive] to
            // every variant -- this makes it impossible to construct the
            // variants.
            for variant in &mut e.variants {
                variant.attrs.insert(0, non_exhaustive_attr.clone());
            }
        }
    }
}

/// `#[recursively_pinned]` pins every field, similar to `#[pin_project]`, and
/// marks the struct `!Unpin`.
///
/// Example:
///
/// ```
/// #[recursively_pinned]
/// struct S {
///   field: i32,
/// }
/// ```
///
/// This is analogous to using pin_project, pinning every field, as so:
///
/// ```
/// #[pin_project(!Unpin)]
/// struct S {
///   #[pin]
///   field: i32,
/// }
/// ```
///
/// ## Arguments
///
/// ### `PinnedDrop`
///
/// To define a destructor for a recursively-pinned struct, pass `PinnedDrop`
/// and implement the `PinnedDrop` trait.
///
/// `#[recursively_pinned]` prohibits implementing `Drop`, as that would make it
/// easy to violate the `Pin` guarantee. Instead, to define a destructor, one
/// must define a `PinnedDrop` impl, as so:
///
/// ```
/// #[recursively_pinned(PinnedDrop)]
/// struct S {
///   field: i32,
/// }
///
/// impl PinnedDrop for S {
///   unsafe fn pinned_drop(self: Pin<&mut Self>) {
///     println!("I am being destroyed!");
///   }
/// }
/// ```
///
/// (This is analogous to `#[pin_project(PinnedDrop)]`.)
///
/// ## Direct initialization
///
/// Use the `ctor!` macro to instantiate recursively pinned types. For example:
///
/// ```
/// // equivalent to `let x = Point {x: 3, y: 4}`, but uses pinned construction.
/// emplace! {
///   let x = ctor!(Point {x: 3, y: 4});
/// }
/// ```
///
/// Recursively pinned types cannot be created directly in safe code, as they
/// are pinned from the very moment of their creation.
///
/// This is prevented either using `#[non_exhaustive]` or using a private field,
/// depending on the type in question. For example, enums use
/// `#[non_exhaustive]`, and structs with named fields use a private field named
/// `__must_use_ctor_to_initialize`. This can lead to confusing error messages,
/// so watch out!
///
/// ## Supported types
///
/// Structs, enums, and unions are all supported. However, unions do not receive
/// a `pin_project` method, as there is no way to implement pin projection for
/// unions. (One cannot know which field is active.)
#[proc_macro_attribute]
pub fn recursively_pinned(args: TokenStream, item: TokenStream) -> TokenStream {
    match recursively_pinned_impl(args.into(), item.into()) {
        Ok(t) => t.into(),
        Err(e) => e.into_compile_error().into(),
    }
}

/// A separate function for calling from tests.
///
/// See e.g. https://users.rust-lang.org/t/procedural-macro-api-is-used-outside-of-a-procedural-macro/30841
fn recursively_pinned_impl(
    args: proc_macro2::TokenStream,
    item: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
    let args = syn::parse2::<RecursivelyPinnedArgs>(args)?;
    let mut input = syn::parse2::<syn::DeriveInput>(item)?;

    let project_pin_impl = project_pin_impl(&input)?;
    let name = input.ident.clone();

    // Create two copies of input: one (public) has a private field that can't be
    // instantiated. The other (only visible via
    // RecursivelyPinned::CtorInitializedFields) doesn't have this field.
    // This causes `ctor!(Foo {})` to work, but `Foo{}` to complain of a missing
    // field.
    let mut ctor_initialized_input = input.clone();
    // Removing repr(C) triggers dead-code detection.
    ctor_initialized_input.attrs = vec![syn::parse_quote!(#[allow(dead_code)])];
    // TODO(jeanpierreda): This should really check for name collisions with any types
    // used in the fields. Collisions with other names don't matter, because the
    // type is locally defined within a narrow scope.
    ctor_initialized_input.ident = syn::Ident::new(&format!("__CrubitCtor{name}"), name.span());
    let ctor_initialized_name = &ctor_initialized_input.ident;
    forbid_initialization(&mut input);

    let (input_impl_generics, input_ty_generics, input_where_clause) =
        input.generics.split_for_impl();

    let drop_impl = if args.is_pinned_drop {
        quote! {
            impl #input_impl_generics Drop for #name #input_ty_generics #input_where_clause {
                fn drop(&mut self) {
                    unsafe {::ctor::PinnedDrop::pinned_drop(::core::pin::Pin::new_unchecked(self))}
                }
            }
        }
    } else {
        quote! {
            impl #input_impl_generics ::ctor::macro_internal::DoNotImplDrop for #name #input_ty_generics #input_where_clause {}
            /// A no-op PinnedDrop that will cause an error if the user also defines PinnedDrop,
            /// due to forgetting to pass `PinnedDrop` to #[recursively_pinned(PinnedDrop)]`.
            impl #input_impl_generics ::ctor::PinnedDrop for #name #input_ty_generics #input_where_clause {
                unsafe fn pinned_drop(self: ::core::pin::Pin<&mut Self>) {}
            }
        }
    };

    Ok(quote! {
        #input
        #project_pin_impl

        #drop_impl
        impl #input_impl_generics !Unpin for #name #input_ty_generics #input_where_clause {}

        // Introduce a new scope to limit the blast radius of the CtorInitializedFields type.
        // This lets us use relatively readable names: while the impl is visible outside the scope,
        // type is otherwise not visible.
        const _ : () = {
            #ctor_initialized_input

            unsafe impl #input_impl_generics ::ctor::RecursivelyPinned for #name #input_ty_generics #input_where_clause {
                type CtorInitializedFields = #ctor_initialized_name #input_ty_generics;
            }
        };
    })
}

#[cfg(test)]
mod test {
    use super::*;
    use googletest::prelude::*;
    use token_stream_matchers::assert_rs_matches;

    /// Essentially a change detector, but handy for debugging.
    ///
    /// At time of writing, we can't write negative compilation tests, so
    /// asserting on the output is as close as we can get. Once negative
    /// compilation tests are added, it would be better to test various
    /// safety features that way.
    #[gtest]
    fn test_recursively_pinned_struct() {
        let definition =
            recursively_pinned_impl(quote! {}, quote! {#[repr(C)] struct S {x: i32}}).unwrap();

        // The struct can't be directly created, but can be created via
        // CtorInitializedFields:
        assert_rs_matches!(
            definition,
            quote! {
                #[repr(C)]
                struct S {
                    x: i32,
                    __must_use_ctor_to_initialize: [u8; 0]
                }
            }
        );
        assert_rs_matches!(
            definition,
            quote! {
                const _: () = {
                   #[allow(dead_code)]
                   struct __CrubitCtorS {x: i32}
                   unsafe impl ::ctor::RecursivelyPinned for S {
                       type CtorInitializedFields = __CrubitCtorS;
                   }
                };
            }
        );

        // The type is non-Unpin:
        assert_rs_matches!(
            definition,
            quote! {
                impl !Unpin for S {}
            }
        );

        // The remaining features of the generated output are better tested via
        // real tests that exercise the code.
    }

    /// The enum version of `test_recursively_pinned_struct`.
    #[gtest]
    fn test_recursively_pinned_enum() {
        let definition = recursively_pinned_impl(
            quote! {},
            quote! {
                #[repr(C)]
                enum E {
                    A,
                    B(i32),
                }
            },
        )
        .unwrap();

        // The enum variants can't be directly created, but can be created via
        // CtorInitializedFields:
        assert_rs_matches!(
            definition,
            quote! {
                #[repr(C)]
                enum E {
                    #[non_exhaustive]
                    A,
                    #[non_exhaustive]
                    B(i32),
                }
            }
        );
        assert_rs_matches!(
            definition,
            quote! {
                const _: () = {
                    #[allow(dead_code)]
                    enum __CrubitCtorE {
                        A,
                        B(i32),
                    }
                    unsafe impl ::ctor::RecursivelyPinned for E {
                        type CtorInitializedFields = __CrubitCtorE;
                    }
                };
            }
        );

        // The type is non-Unpin:
        assert_rs_matches!(
            definition,
            quote! {
                impl !Unpin for E {}
            }
        );

        // The remaining features of the generated output are better tested via
        // real tests that exercise the code.
    }
}
