diff options
Diffstat (limited to 'engine-macros/src/reflection/enum_impl.rs')
| -rw-r--r-- | engine-macros/src/reflection/enum_impl.rs | 518 |
1 files changed, 518 insertions, 0 deletions
diff --git a/engine-macros/src/reflection/enum_impl.rs b/engine-macros/src/reflection/enum_impl.rs new file mode 100644 index 0000000..0ce2562 --- /dev/null +++ b/engine-macros/src/reflection/enum_impl.rs @@ -0,0 +1,518 @@ +use quote::{format_ident, quote, ToTokens}; +use syn::ItemEnum; + +use crate::reflection::default_value::gen_get_default_value_fn; +use crate::reflection::field::{generate as generate_field, ReflectionFieldGenOptions}; +use crate::reflection::options_attr::OptionsAttr; +use crate::util::find_engine_crate_path; + +pub fn generate(input: syn::ItemEnum, options: OptionsAttr) -> proc_macro2::TokenStream +{ + let engine_crate_path = find_engine_crate_path().unwrap(); + + let variant_lookup_match_arms = input + .variants + .iter() + .enumerate() + .map(|(index, variant)| { + let variant_ident = &variant.ident; + + let pattern = match variant.fields { + syn::Fields::Unit => quote! { Self::#variant_ident }, + syn::Fields::Named(_) => quote! { Self::#variant_ident { .. } }, + syn::Fields::Unnamed(_) => quote! { Self::#variant_ident(..) }, + }; + + quote! { + #pattern => &enum_reflection.variants[#index] + } + }) + .collect::<Vec<_>>(); + + let is_unit_only = input + .variants + .iter() + .all(|variant| matches!(variant.fields, syn::Fields::Unit)); + + let mod_name = format_ident!("__engine_private_{}", input.ident); + + let reprs = get_reprs(&input); + + if !reprs.has_c_repr && reprs.primitive_repr.is_none() { + panic!("Enums must have a C or primitive representation to derive Reflection"); + } + + let mod_content = generate_mod_content(&input, is_unit_only, &reprs); + + let impls: &mut dyn Iterator<Item = proc_macro2::TokenStream> = + if input.generics.params.is_empty() { + &mut [generate_impls( + &input, + None, + is_unit_only, + &variant_lookup_match_arms, + &mod_name, + &engine_crate_path, + )] + .into_iter() + } else { + &mut options.impl_with_generics.iter().map(|generic_args| { + generate_impls( + &input, + Some(generic_args), + is_unit_only, + &variant_lookup_match_arms, + &mod_name, + &engine_crate_path, + ) + }) + }; + + quote! { + #[doc(hidden)] + #[allow(non_snake_case)] + mod #mod_name { + #mod_content + } + + #(#impls)* + } +} + +fn generate_impls( + input: &syn::ItemEnum, + generic_args: Option<&syn::AngleBracketedGenericArguments>, + is_unit_only: bool, + variant_lookup_match_arms: &[proc_macro2::TokenStream], + mod_name: &proc_macro2::Ident, + engine_crate_path: &syn::Path, +) -> proc_macro2::TokenStream +{ + let get_default_value_fn = gen_get_default_value_fn(&input.ident, generic_args); + + let variants = generate_variants( + &input.ident, + &input.variants, + generic_args, + is_unit_only, + &engine_crate_path, + ); + + let tagged_union = if is_unit_only { + quote! { + None + } + } else { + quote! { + Some(#engine_crate_path::reflection::EnumTaggedUnion { + discriminant_layout: std::alloc::Layout::new::<#mod_name::Discriminant>(), + discriminant_byte_offset: + std::mem::offset_of!(#mod_name::Equivalent #generic_args, tag), + fields_layout: + std::alloc::Layout::new::<#mod_name::Fields #generic_args>(), + fields_byte_offset: + std::mem::offset_of!(#mod_name::Equivalent #generic_args, payload), + }) + } + }; + + let generics_type_aliases = input + .generics + .type_params() + .zip( + generic_args + .iter() + .map(|generic_args| &generic_args.args) + .flatten() + .flat_map(|generic_arg| match generic_arg { + syn::GenericArgument::Type(ty) => Some(ty), + _ => None, + }), + ) + .map(|(type_param, generic_arg_type)| { + let type_param_ident = &type_param.ident; + + quote! { + type #type_param_ident = #generic_arg_type; + } + }); + + let input_ident = &input.ident; + + quote! { + unsafe impl #engine_crate_path::reflection::Reflection for + #input_ident #generic_args + { + const TYPE_REFLECTION: &#engine_crate_path::reflection::Type = + &const { + #(#generics_type_aliases)* + + #engine_crate_path::reflection::Type::Enum( + #engine_crate_path::reflection::Enum { + variants: &[#(#variants),*], + is_unit_only: #is_unit_only, + tagged_union: #tagged_union, + get_default_value: || { + #get_default_value_fn + } + } + ) + }; + } + + unsafe impl #engine_crate_path::reflection::EnumReflectionExt for + #input_ident #generic_args + { + fn get_variant_reflection(&self) + -> &'static #engine_crate_path::reflection::EnumVariant + { + let enum_reflection = unsafe { + <Self as #engine_crate_path::reflection::Reflection>::TYPE_REFLECTION + .as_enum() + .unwrap_unchecked() + }; + + match self { + #(#variant_lookup_match_arms),* + } + } + } + } +} + +fn generate_variants<'a>( + input_ident: &proc_macro2::Ident, + input_variants: &'a syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>, + generic_args: Option<&'a syn::AngleBracketedGenericArguments>, + is_unit_only: bool, + engine_crate_path: &'a syn::Path, +) -> impl Iterator<Item = proc_macro2::TokenStream> + use<'a> +{ + let mod_name = format_ident!("__engine_private_{input_ident}"); + + input_variants.iter().map(move |variant| { + let variant_name = + syn::LitStr::new(&variant.ident.to_string(), variant.ident.span()); + + let variant_ident = &variant.ident; + + let variant_fields_struct_ident = format_ident!("VariantFields{}", variant.ident); + + let fields = match &variant.fields { + syn::Fields::Unit => quote! { None }, + syn::Fields::Named(named_fields) => { + let fields = named_fields.named.iter().enumerate().map( + |(variant_field_index, variant_field)| { + generate_field( + variant_field, + variant_field_index, + &engine_crate_path, + ReflectionFieldGenOptions { + // enum variant fields are always public + field_vis_override: Some(syn::Visibility::Public( + <syn::Token![pub]>::default(), + )), + gen_get_byte_offset: &|field| { + if let Some(field_ident) = &field.ident { + quote! { + std::mem::offset_of!( + #mod_name::Equivalent #generic_args, + payload.#variant_ident + ) + + std::mem::offset_of!( + #mod_name::#variant_fields_struct_ident + #generic_args, + #field_ident + ) + } + } else { + unreachable!(); + } + }, + }, + ) + }, + ); + + quote! { + Some(#engine_crate_path::reflection::EnumVariantFields::Named { + fields: &[#(#fields),*] + }) + } + } + syn::Fields::Unnamed(unnamed_fields) => { + let fields = unnamed_fields.unnamed.iter().enumerate().map( + |(variant_field_index, variant_field)| { + generate_field( + variant_field, + variant_field_index, + &engine_crate_path, + ReflectionFieldGenOptions { + // enum variant fields are always public + field_vis_override: Some(syn::Visibility::Public( + <syn::Token![pub]>::default(), + )), + gen_get_byte_offset: &|field| { + if let Some(_) = &field.ident { + unreachable!() + } else { + let field_index = + proc_macro2::Literal::usize_unsuffixed( + variant_field_index, + ); + + quote! { + std::mem::offset_of!( + #mod_name::Equivalent #generic_args, + payload.#variant_ident + ) + + std::mem::offset_of!( + #mod_name::#variant_fields_struct_ident + #generic_args, + #field_index + ) + } + } + }, + }, + ) + }, + ); + + quote! { + Some(#engine_crate_path::reflection::EnumVariantFields::Unnamed { + fields: &[#(#fields),*] + }) + } + } + }; + + let discriminant = match variant.fields { + syn::Fields::Unit if is_unit_only => { + quote! { + let mut buf = [0u8; std::mem::size_of::<i128>()]; + + // Self or any of it's type parameters might have a Drop impl and + // since that can not be const evaluated, the discriminant value has + // to be wrapped in ManuallyDrop + let discriminant = std::mem::ManuallyDrop::new(Self::#variant_ident); + + unsafe { + std::ptr::copy_nonoverlapping( + (&raw const discriminant).cast::<u8>(), + buf.as_mut_ptr(), + std::mem::size_of::<Self>(), + ); + } + + buf + } + } + _ => { + quote! { + let mut buf = [0u8; std::mem::size_of::<i128>()]; + + let discriminant = #mod_name::Discriminant::#variant_ident; + + unsafe { + std::ptr::copy_nonoverlapping( + (&raw const discriminant).cast::<u8>(), + buf.as_mut_ptr(), + std::mem::size_of::<#mod_name::Discriminant>(), + ); + } + + buf + } + } + }; + + quote! { + #engine_crate_path::reflection::EnumVariant { + name: #variant_name, + fields: #fields, + discriminant: #engine_crate_path::reflection::EnumDiscriminant { + buf: { #discriminant } + } + } + } + }) +} + +fn generate_mod_content( + input: &syn::ItemEnum, + is_unit_only: bool, + reprs: &Reprs, +) -> proc_macro2::TokenStream +{ + if is_unit_only { + return quote! {}; + } + + let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); + + let variant_fields_structs = input.variants.iter().map(|variant| { + let fields = variant.fields.iter().map(|field| { + let field_type = &field.ty; + + if let Some(field_ident) = &field.ident { + quote! { pub #field_ident: #field_type } + } else { + quote! { pub #field_type } + } + }); + + let ident = format_ident!("VariantFields{}", variant.ident); + + let generics_phantom_data_elems = input.generics.params.iter().filter_map( + |generic_param| match generic_param { + syn::GenericParam::Type(type_param) => { + Some(type_param.ident.clone().into_token_stream()) + } + syn::GenericParam::Lifetime(_) => { + unimplemented!(); + } + syn::GenericParam::Const(_const_param) => None, + }, + ); + + let generics_phantom_data = quote! { + std::marker::PhantomData<(#(#generics_phantom_data_elems),*)> + }; + + match variant.fields { + syn::Fields::Named(_) => quote! { + #[repr(C)] + pub struct #ident #impl_generics + #where_clause + { + #(#fields,)* + _pd: #generics_phantom_data + } + }, + syn::Fields::Unnamed(_) => quote! { + #[repr(C)] + pub struct #ident #impl_generics ( + #(#fields,)* + #generics_phantom_data + ) + #where_clause; + }, + syn::Fields::Unit => quote! { + #[repr(C)] + pub struct #ident #impl_generics (#generics_phantom_data) + #where_clause; + }, + } + }); + + let fields_union_fields = input.variants.iter().map(|variant| { + let variant_ident = &variant.ident; + + let variant_fields_struct_ident = format_ident!("VariantFields{}", variant.ident); + + quote! { + pub #variant_ident: + std::mem::ManuallyDrop<#variant_fields_struct_ident #type_generics> + } + }); + + let discriminant_enum_variants = input.variants.iter().map(|variant| { + let variant_ident = &variant.ident; + + if let Some((_, discriminant)) = &variant.discriminant { + quote! { #variant_ident = #discriminant } + } else { + quote! { #variant_ident } + } + }); + + // If the enum has both a C & primitive repr, the primitive repr must be used + let discriminant_enum_repr = if let Some(primitive_repr) = &reprs.primitive_repr { + quote! { #primitive_repr } + } else if reprs.has_c_repr { + quote! { C } + } else { + unreachable!(); + }; + + quote! { + #![allow(non_snake_case, dead_code)] + + use super::*; + + #[repr(C)] + pub struct Equivalent #impl_generics + #where_clause + { + pub tag: Discriminant, + pub payload: Fields #type_generics + } + + #[repr(#discriminant_enum_repr)] + pub enum Discriminant + { + #(#discriminant_enum_variants),* + } + + #[repr(C)] + pub union Fields #impl_generics + #where_clause + { + #(#fields_union_fields),* + } + + #(#variant_fields_structs)* + } +} + +fn get_reprs(input: &ItemEnum) -> Reprs +{ + let mut has_c_repr = false; + let mut primitive_repr: Option<proc_macro2::Ident> = None; + + for attr in &input.attrs { + let syn::Meta::List(attr_meta) = &attr.meta else { + continue; + }; + + if !attr_meta.path.is_ident("repr") { + continue; + } + + attr_meta + .parse_nested_meta(|nested_meta| { + let Some(meta_ident) = nested_meta.path.get_ident() else { + return Ok(()); + }; + + if meta_ident == "C" { + has_c_repr = true; + } else if PRIMITIVE_REPRS.contains(&meta_ident.to_string().as_str()) { + primitive_repr = Some(meta_ident.clone()); + } else { + return Err(nested_meta.error(concat!( + "Unsupported representation. ", + "Only C and primitive representations are allowed" + ))); + } + + Ok(()) + }) + .unwrap(); + } + + Reprs { has_c_repr, primitive_repr } +} + +#[derive(Debug)] +struct Reprs +{ + has_c_repr: bool, + primitive_repr: Option<proc_macro2::Ident>, +} + +const PRIMITIVE_REPRS: &[&'static str] = &[ + "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", + "isize", +]; |
