use proc_macro::{TokenStream, TokenTree};
use quote::quote;
use syn::{parse, Ident, ItemEnum, Token};

/// Subtracts two numbers and calls a given callback macro with the result. Optionally, a
/// additional argument (delimeted) can be given which will also be passed to the
/// callback.
///
/// # Input
/// `$num_a - $num_b, $callback $(, $user_data)?`
///
/// # Examples
/// ```
/// # use std::any::TypeId;
/// use util_macros::sub;
///
/// macro_rules! sub_cb {
///     ($num: literal) => {
///         $num
///     };
/// }
///
/// type Foo = [u8; sub!(5 - 2, sub_cb)];
///
/// assert_eq!(TypeId::of::<Foo>(), TypeId::of::<[u8; 3]>());
/// ```
/// <br>
///
/// The callback macro can be called with extra arguments.
/// ```
/// # use std::any::TypeId;
/// use util_macros::sub;
///
/// macro_rules! sub_cb {
///     ($num: literal, $to_multiply: literal) => {
///         $num * $to_multiply
///     };
/// }
///
/// type Foo = [u8; sub!(5 - 2, sub_cb, (20))];
///
/// assert_eq!(TypeId::of::<Foo>(), TypeId::of::<[u8; 60]>());
/// ```
/// <br>
///
/// The callback is called with the identifier `overflow` if a overflow occurs.
/// ```
/// # use std::any::TypeId;
/// use util_macros::sub;
///
/// macro_rules! sub_cb {
///     ($num: literal) => {
///         $num
///     };
///
///     (overflow) => {
///         128
///     };
/// }
///
/// type Foo = [u8; sub!(3 - 10, sub_cb)];
///
/// assert_eq!(TypeId::of::<Foo>(), TypeId::of::<[u8; 128]>());
/// ```
#[proc_macro]
pub fn sub(input: TokenStream) -> TokenStream
{
    let mut input_tt_iter = input.into_iter();

    let num_a = match input_tt_iter.next().unwrap() {
        TokenTree::Literal(lit) => lit.to_string().parse::<u32>().unwrap(),
        _ => {
            panic!("Expected a number literal");
        }
    };

    match input_tt_iter.next().unwrap() {
        TokenTree::Punct(punct) if punct.as_char() == '-' => {}
        _ => {
            panic!("Expected a '-' token");
        }
    };

    let num_b = match input_tt_iter.next().unwrap() {
        TokenTree::Literal(lit) => lit.to_string().parse::<u32>().unwrap(),
        _ => {
            panic!("Expected a number literal");
        }
    };

    match input_tt_iter.next().unwrap() {
        TokenTree::Punct(punct) if punct.as_char() == ',' => {}
        _ => {
            panic!("Expected a ',' token");
        }
    };

    let cb_ident = match input_tt_iter.next().unwrap() {
        TokenTree::Ident(cb_ident) => {
            proc_macro2::Ident::new(&cb_ident.to_string(), cb_ident.span().into())
        }
        _ => {
            panic!("Expected a identifier");
        }
    };

    let opt_user_data = input_tt_iter
        .next()
        .map(|comma_tt| {
            match comma_tt {
                TokenTree::Punct(punct) if punct.as_char() == ',' => {}
                _ => {
                    panic!("Expected a ',' token");
                }
            };

            let user_data_tt = input_tt_iter.next().unwrap();

            let TokenTree::Group(group) = user_data_tt else {
                panic!("User data must be a delimeted")
            };

            let inside: proc_macro2::TokenStream = group.stream().into();

            quote! {, #inside }
        })
        .unwrap_or_default();

    let Some(subtracted) = num_a.checked_sub(num_b) else {
        return quote! {
            #cb_ident!(overflow)
        }
        .into();
    };

    let subtracted_lit = proc_macro2::Literal::u32_unsuffixed(subtracted);

    quote! {
        #cb_ident!(#subtracted_lit #opt_user_data)
    }
    .into()
}

#[proc_macro_derive(FromRepr)]
pub fn from_repr(input: TokenStream) -> TokenStream
{
    let enum_item = parse::<ItemEnum>(input).unwrap();

    let repr_attr = enum_item
        .attrs
        .iter()
        .find(|attr| {
            attr.path()
                .get_ident()
                .is_some_and(|attr_ident| attr_ident == "repr")
        })
        .unwrap();

    let repr = repr_attr.parse_args::<Ident>().unwrap();

    let repr_str = repr.to_string();

    if !((repr_str.starts_with('u') || repr_str.starts_with('i'))
        && repr_str
            .chars()
            .skip(1)
            .all(|character| character.is_ascii_digit()))
    {
        panic!("Invalid repr. Must be u* or i* where * is a number");
    }

    let variants = enum_item.variants.iter().map(|variant| {
        let Some((_, discriminant)) = &variant.discriminant else {
            panic!("All variants must have discriminants");
        };

        (variant.ident.clone(), discriminant.clone())
    });

    let match_arms = variants.map(|(variant, discriminant)| {
        quote! {
            #discriminant => Some(Self::#variant),
        }
    });

    let enum_ident = enum_item.ident.clone();

    quote! {
        impl #enum_ident
        {
            pub fn from_repr(repr: #repr) -> Option<Self>
            {
                match repr {
                    #(#match_arms)*
                    _ => None
                }
            }
        }
    }
    .into()
}

#[proc_macro_derive(VariantArr, attributes(variant_arr))]
pub fn variant_arr(input: TokenStream) -> TokenStream
{
    let enum_item = parse::<ItemEnum>(input).unwrap();

    let arr_ident_attr = enum_item
        .attrs
        .iter()
        .find(|attr| {
            attr.path()
                .get_ident()
                .is_some_and(|attr_ident| attr_ident == "variant_arr")
        })
        .expect("No variant_arr attribute found");

    let mut arr_name = None;

    arr_ident_attr
        .parse_nested_meta(|meta| {
            if meta.path.is_ident("name") {
                meta.input.parse::<Token![=]>()?;

                arr_name = Some(meta.input.parse::<Ident>()?);

                return Ok(());
            }

            Err(meta.error("Unknown field in variant_arr attribute"))
        })
        .unwrap();

    let arr_name = arr_name.expect("Missing field 'name' in variant_arr attribute");

    let variants = enum_item
        .variants
        .iter()
        .map(|variant| variant.ident.clone());

    let enum_ident = enum_item.ident.clone();

    quote! {
        impl #enum_ident
        {
            pub const #arr_name: &[Self] = &[#(Self::#variants,)*];
        }
    }
    .into()
}