diff options
Diffstat (limited to 'engine-macros/src/reflection/enum_impl.rs')
| -rw-r--r-- | engine-macros/src/reflection/enum_impl.rs | 448 |
1 files changed, 448 insertions, 0 deletions
diff --git a/engine-macros/src/reflection/enum_impl.rs b/engine-macros/src/reflection/enum_impl.rs new file mode 100644 index 0000000..14142bf --- /dev/null +++ b/engine-macros/src/reflection/enum_impl.rs @@ -0,0 +1,448 @@ +use quote::{format_ident, quote, ToTokens}; + +use crate::reflection::field::{generate as generate_field, ReflectionFieldGenOptions}; +use crate::util::{find_engine_crate_path, syn_path, SynPathExt}; + +const PRIMITIVE_REPRS: &[&'static str] = &[ + "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", + "isize", +]; + +pub fn generate(input: syn::ItemEnum) -> proc_macro2::TokenStream +{ + let engine_crate_path = find_engine_crate_path().unwrap(); + + let input_ident = &input.ident; + + let mut generics = input.generics; + + for type_param in generics.type_params_mut() { + type_param + .bounds + .push(syn::TypeParamBound::Trait(syn::TraitBound { + paren_token: None, + modifier: syn::TraitBoundModifier::None, + lifetimes: None, + path: engine_crate_path.join(syn_path!(reflection::Reflection)), + })); + } + + let (impl_generics, type_generics, where_clause) = generics.split_for_impl(); + + let variant_lookup_match_arms = + input.variants.iter().enumerate().map(|(index, variant)| { + let variant_ident = &variant.ident; + + let pattern = match variant.fields { + syn::Fields::Unit => quote! { Self::#variant_ident }, + syn::Fields::Named(_) => quote! { Self::#variant_ident { .. } }, + syn::Fields::Unnamed(_) => quote! { Self::#variant_ident(..) }, + }; + + quote! { + #pattern => &enum_reflection.variants[#index] + } + }); + + let is_unit_only = input + .variants + .iter() + .all(|variant| matches!(variant.fields, syn::Fields::Unit)); + + let mod_name = format_ident!("__engine_private_{input_ident}"); + + let mut has_c_repr = false; + + let mut primitive_repr: Option<proc_macro2::Ident> = None; + + for attr in &input.attrs { + let syn::Meta::List(attr_meta) = &attr.meta else { + continue; + }; + + if !attr_meta.path.is_ident("repr") { + continue; + } + + attr_meta + .parse_nested_meta(|nested_meta| { + let Some(meta_ident) = nested_meta.path.get_ident() else { + return Ok(()); + }; + + if meta_ident == "C" { + has_c_repr = true; + } else if PRIMITIVE_REPRS.contains(&meta_ident.to_string().as_str()) { + primitive_repr = Some(meta_ident.clone()); + } else { + return Err(nested_meta.error(concat!( + "Unsupported representation. ", + "Only C and primitive representations are allowed" + ))); + } + + Ok(()) + }) + .unwrap(); + } + + if !has_c_repr && primitive_repr.is_none() { + panic!("Enums must have a C or primitive representation to derive Reflection"); + } + + let mod_content = if is_unit_only { + quote! {} + } else { + let variant_fields_structs = + input.variants.iter().map(|variant| { + let fields = variant.fields.iter().map(|field| { + let field_type = &field.ty; + + if let Some(field_ident) = &field.ident { + quote! { pub #field_ident: #field_type } + } else { + quote! { pub #field_type } + } + }); + + let ident = format_ident!("VariantFields{}", variant.ident); + + let generics_phantom_data_elems = generics.params.iter().filter_map( + |generic_param| match generic_param { + syn::GenericParam::Type(type_param) => { + Some(type_param.ident.clone().into_token_stream()) + } + syn::GenericParam::Lifetime(_) => { + unimplemented!(); + } + syn::GenericParam::Const(_const_param) => None, + }, + ); + + let generics_phantom_data = quote! { + std::marker::PhantomData<(#(#generics_phantom_data_elems),*)> + }; + + match variant.fields { + syn::Fields::Named(_) => quote! { + #[repr(C)] + pub struct #ident #impl_generics + #where_clause + { + #(#fields,)* + _pd: #generics_phantom_data + } + }, + syn::Fields::Unnamed(_) => quote! { + #[repr(C)] + pub struct #ident #impl_generics ( + #(#fields,)* + #generics_phantom_data + ) + #where_clause; + }, + syn::Fields::Unit => quote! { + #[repr(C)] + pub struct #ident #impl_generics (#generics_phantom_data) + #where_clause; + }, + } + }); + + let fields_union_fields = input.variants.iter().map(|variant| { + let variant_ident = &variant.ident; + + let variant_fields_struct_ident = + format_ident!("VariantFields{}", variant.ident); + + quote! { + pub #variant_ident: + std::mem::ManuallyDrop<#variant_fields_struct_ident #type_generics> + } + }); + + let discriminant_enum_variants = input.variants.iter().map(|variant| { + let variant_ident = &variant.ident; + + if let Some((_, discriminant)) = &variant.discriminant { + quote! { #variant_ident = #discriminant } + } else { + quote! { #variant_ident } + } + }); + + // If the enum has both a C & primitive repr, the primitive repr must be used + let discriminant_enum_repr = if let Some(primitive_repr) = primitive_repr { + quote! { #primitive_repr } + } else if has_c_repr { + quote! { C } + } else { + unreachable!(); + }; + + quote! { + #![allow(non_snake_case, dead_code)] + + use super::*; + + #[repr(C)] + pub struct Equivalent #impl_generics + #where_clause + { + pub tag: Discriminant, + pub payload: Fields #type_generics + } + + #[repr(#discriminant_enum_repr)] + pub enum Discriminant + { + #(#discriminant_enum_variants),* + } + + #[repr(C)] + pub union Fields #impl_generics + #where_clause + { + #(#fields_union_fields),* + } + + #(#variant_fields_structs)* + } + }; + + let variants = generate_variants( + input_ident, + &input.variants, + &generics, + is_unit_only, + &engine_crate_path, + ); + + let discriminant_layout = if is_unit_only { + quote! { + std::alloc::Layout::new::<Self>() + } + } else { + quote! { + std::alloc::Layout::new::<#mod_name::Discriminant>() + } + }; + + let fields_layout = if is_unit_only { + quote! { + None + } + } else { + quote! { + Some(std::alloc::Layout::new::<#mod_name::Fields>()) + } + }; + + quote! { + #[doc(hidden)] + mod #mod_name { + #mod_content + } + + unsafe impl #impl_generics #engine_crate_path::reflection::Reflection for + #input_ident #type_generics #where_clause + { + const TYPE_REFLECTION: &#engine_crate_path::reflection::Type = + &const { + #engine_crate_path::reflection::Type::Enum( + #engine_crate_path::reflection::Enum { + variants: &[#(#variants),*], + is_unit_only: #is_unit_only, + discriminant_layout: #discriminant_layout, + fields_layout: #fields_layout + } + ) + }; + } + + unsafe impl #impl_generics #engine_crate_path::reflection::EnumReflectionExt for + #input_ident #type_generics #where_clause + { + fn get_variant_reflection(&self) + -> &'static #engine_crate_path::reflection::EnumVariant + { + let enum_reflection = unsafe { + <Self as #engine_crate_path::reflection::Reflection>::TYPE_REFLECTION + .as_enum() + .unwrap_unchecked() + }; + + match self { + #(#variant_lookup_match_arms),* + } + } + } + } +} + +fn generate_variants<'a>( + input_ident: &proc_macro2::Ident, + input_variants: &'a syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>, + input_generics: &'a syn::Generics, + is_unit_only: bool, + engine_crate_path: &'a syn::Path, +) -> impl Iterator<Item = proc_macro2::TokenStream> + use<'a> +{ + let mod_name = format_ident!("__engine_private_{input_ident}"); + + let (_, type_generics, _) = input_generics.split_for_impl(); + + input_variants.iter().map(move |variant| { + let variant_name = + syn::LitStr::new(&variant.ident.to_string(), variant.ident.span()); + + let variant_ident = &variant.ident; + + let variant_fields_struct_ident = format_ident!("VariantFields{}", variant.ident); + + let fields = match &variant.fields { + syn::Fields::Unit => quote! { None }, + syn::Fields::Named(named_fields) => { + let fields = named_fields.named.iter().enumerate().map( + |(variant_field_index, variant_field)| { + generate_field( + variant_field, + variant_field_index, + &engine_crate_path, + ReflectionFieldGenOptions { + // enum variant fields are always public + field_vis_override: Some(syn::Visibility::Public( + <syn::Token![pub]>::default(), + )), + gen_get_byte_offset: &|field| { + if let Some(field_ident) = &field.ident { + quote! { + std::mem::offset_of!( + #mod_name::Equivalent #type_generics, + payload.#variant_ident + ) + + std::mem::offset_of!( + #mod_name::#variant_fields_struct_ident + #type_generics, + #field_ident + ) + } + } else { + unreachable!(); + } + }, + type_reflection_optional: input_generics + .params + .is_empty(), + }, + ) + }, + ); + + quote! { + Some(#engine_crate_path::reflection::EnumVariantFields::Named { + fields: &[#(#fields),*] + }) + } + } + syn::Fields::Unnamed(unnamed_fields) => { + let fields = unnamed_fields.unnamed.iter().enumerate().map( + |(variant_field_index, variant_field)| { + generate_field( + variant_field, + variant_field_index, + &engine_crate_path, + ReflectionFieldGenOptions { + // enum variant fields are always public + field_vis_override: Some(syn::Visibility::Public( + <syn::Token![pub]>::default(), + )), + gen_get_byte_offset: &|field| { + if let Some(_) = &field.ident { + unreachable!() + } else { + let field_index = + proc_macro2::Literal::usize_unsuffixed( + variant_field_index, + ); + + quote! { + std::mem::offset_of!( + #mod_name::Equivalent #type_generics, + payload.#variant_ident + ) + + std::mem::offset_of!( + #mod_name::#variant_fields_struct_ident + #type_generics, + #field_index + ) + } + } + }, + type_reflection_optional: input_generics + .params + .is_empty(), + }, + ) + }, + ); + + quote! { + Some(#engine_crate_path::reflection::EnumVariantFields::Unnamed { + fields: &[#(#fields),*] + }) + } + } + }; + + let discriminant = match variant.fields { + syn::Fields::Unit if is_unit_only => { + quote! { + let mut buf = [0u8; std::mem::size_of::<i128>()]; + + // Self or any of it's type parameters might have a Drop impl and + // since that can not be const evaluated, the discriminant value has + // to be wrapped in ManuallyDrop + let discriminant = std::mem::ManuallyDrop::new(Self::#variant_ident); + + unsafe { + std::ptr::copy_nonoverlapping( + (&raw const discriminant).cast::<u8>(), + buf.as_mut_ptr(), + std::mem::size_of::<Self>(), + ); + } + + buf + } + } + _ => { + quote! { + let mut buf = [0u8; std::mem::size_of::<i128>()]; + + let discriminant = #mod_name::Discriminant::#variant_ident; + + unsafe { + std::ptr::copy_nonoverlapping( + (&raw const discriminant).cast::<u8>(), + buf.as_mut_ptr(), + std::mem::size_of::<#mod_name::Discriminant>(), + ); + } + + buf + } + } + }; + + quote! { + #engine_crate_path::reflection::EnumVariant { + name: #variant_name, + fields: #fields, + discriminant: #engine_crate_path::reflection::EnumDiscriminant { + buf: { #discriminant } + } + } + } + }) +} |
