use quote::{format_ident, quote, ToTokens}; use crate::reflection::field::{generate as generate_field, ReflectionFieldGenOptions}; use crate::util::{find_engine_crate_path, syn_path, SynPathExt}; const PRIMITIVE_REPRS: &[&'static str] = &[ "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", "isize", ]; pub fn generate(input: syn::ItemEnum) -> proc_macro2::TokenStream { let engine_crate_path = find_engine_crate_path().unwrap(); let input_ident = &input.ident; let mut generics = input.generics; for type_param in generics.type_params_mut() { type_param .bounds .push(syn::TypeParamBound::Trait(syn::TraitBound { paren_token: None, modifier: syn::TraitBoundModifier::None, lifetimes: None, path: engine_crate_path.join(syn_path!(reflection::Reflection)), })); } let (impl_generics, type_generics, where_clause) = generics.split_for_impl(); let variant_lookup_match_arms = input.variants.iter().enumerate().map(|(index, variant)| { let variant_ident = &variant.ident; let pattern = match variant.fields { syn::Fields::Unit => quote! { Self::#variant_ident }, syn::Fields::Named(_) => quote! { Self::#variant_ident { .. } }, syn::Fields::Unnamed(_) => quote! { Self::#variant_ident(..) }, }; quote! { #pattern => &enum_reflection.variants[#index] } }); let is_unit_only = input .variants .iter() .all(|variant| matches!(variant.fields, syn::Fields::Unit)); let mod_name = format_ident!("__engine_private_{input_ident}"); let mut has_c_repr = false; let mut primitive_repr: Option = None; for attr in &input.attrs { let syn::Meta::List(attr_meta) = &attr.meta else { continue; }; if !attr_meta.path.is_ident("repr") { continue; } attr_meta .parse_nested_meta(|nested_meta| { let Some(meta_ident) = nested_meta.path.get_ident() else { return Ok(()); }; if meta_ident == "C" { has_c_repr = true; } else if PRIMITIVE_REPRS.contains(&meta_ident.to_string().as_str()) { primitive_repr = Some(meta_ident.clone()); } else { return Err(nested_meta.error(concat!( "Unsupported representation. ", "Only C and primitive representations are allowed" ))); } Ok(()) }) .unwrap(); } if !has_c_repr && primitive_repr.is_none() { panic!("Enums must have a C or primitive representation to derive Reflection"); } let mod_content = if is_unit_only { quote! {} } else { let variant_fields_structs = input.variants.iter().map(|variant| { let fields = variant.fields.iter().map(|field| { let field_type = &field.ty; if let Some(field_ident) = &field.ident { quote! { pub #field_ident: #field_type } } else { quote! { pub #field_type } } }); let ident = format_ident!("VariantFields{}", variant.ident); let generics_phantom_data_elems = generics.params.iter().filter_map( |generic_param| match generic_param { syn::GenericParam::Type(type_param) => { Some(type_param.ident.clone().into_token_stream()) } syn::GenericParam::Lifetime(_) => { unimplemented!(); } syn::GenericParam::Const(_const_param) => None, }, ); let generics_phantom_data = quote! { std::marker::PhantomData<(#(#generics_phantom_data_elems),*)> }; match variant.fields { syn::Fields::Named(_) => quote! { #[repr(C)] pub struct #ident #impl_generics #where_clause { #(#fields,)* _pd: #generics_phantom_data } }, syn::Fields::Unnamed(_) => quote! { #[repr(C)] pub struct #ident #impl_generics ( #(#fields,)* #generics_phantom_data ) #where_clause; }, syn::Fields::Unit => quote! { #[repr(C)] pub struct #ident #impl_generics (#generics_phantom_data) #where_clause; }, } }); let fields_union_fields = input.variants.iter().map(|variant| { let variant_ident = &variant.ident; let variant_fields_struct_ident = format_ident!("VariantFields{}", variant.ident); quote! { pub #variant_ident: std::mem::ManuallyDrop<#variant_fields_struct_ident #type_generics> } }); let discriminant_enum_variants = input.variants.iter().map(|variant| { let variant_ident = &variant.ident; if let Some((_, discriminant)) = &variant.discriminant { quote! { #variant_ident = #discriminant } } else { quote! { #variant_ident } } }); // If the enum has both a C & primitive repr, the primitive repr must be used let discriminant_enum_repr = if let Some(primitive_repr) = primitive_repr { quote! { #primitive_repr } } else if has_c_repr { quote! { C } } else { unreachable!(); }; quote! { #![allow(non_snake_case, dead_code)] use super::*; #[repr(C)] pub struct Equivalent #impl_generics #where_clause { pub tag: Discriminant, pub payload: Fields #type_generics } #[repr(#discriminant_enum_repr)] pub enum Discriminant { #(#discriminant_enum_variants),* } #[repr(C)] pub union Fields #impl_generics #where_clause { #(#fields_union_fields),* } #(#variant_fields_structs)* } }; let variants = generate_variants( input_ident, &input.variants, &generics, is_unit_only, &engine_crate_path, ); quote! { #[doc(hidden)] mod #mod_name { #mod_content } unsafe impl #impl_generics #engine_crate_path::reflection::Reflection for #input_ident #type_generics #where_clause { const TYPE_REFLECTION: &#engine_crate_path::reflection::Type = &const { #engine_crate_path::reflection::Type::Enum( #engine_crate_path::reflection::Enum { variants: &[#(#variants),*], is_unit_only: #is_unit_only } ) }; } unsafe impl #impl_generics #engine_crate_path::reflection::EnumReflectionExt for #input_ident #type_generics #where_clause { fn get_variant_reflection(&self) -> &'static #engine_crate_path::reflection::EnumVariant { let enum_reflection = unsafe { ::TYPE_REFLECTION .as_enum() .unwrap_unchecked() }; match self { #(#variant_lookup_match_arms),* } } } } } fn generate_variants<'a>( input_ident: &proc_macro2::Ident, input_variants: &'a syn::punctuated::Punctuated, input_generics: &'a syn::Generics, is_unit_only: bool, engine_crate_path: &'a syn::Path, ) -> impl Iterator + use<'a> { let mod_name = format_ident!("__engine_private_{input_ident}"); let (_, type_generics, _) = input_generics.split_for_impl(); input_variants.iter().map(move |variant| { let variant_name = syn::LitStr::new(&variant.ident.to_string(), variant.ident.span()); let variant_ident = &variant.ident; let variant_fields_struct_ident = format_ident!("VariantFields{}", variant.ident); let fields = match &variant.fields { syn::Fields::Unit => quote! { None }, syn::Fields::Named(named_fields) => { let fields = named_fields.named.iter().enumerate().map( |(variant_field_index, variant_field)| { generate_field( variant_field, variant_field_index, &engine_crate_path, ReflectionFieldGenOptions { // enum variant fields are always public field_vis_override: Some(syn::Visibility::Public( ::default(), )), gen_get_byte_offset: &|field| { if let Some(field_ident) = &field.ident { quote! { std::mem::offset_of!( #mod_name::Equivalent #type_generics, payload.#variant_ident ) + std::mem::offset_of!( #mod_name::#variant_fields_struct_ident #type_generics, #field_ident ) } } else { unreachable!(); } }, type_reflection_optional: input_generics .params .is_empty(), }, ) }, ); quote! { Some(#engine_crate_path::reflection::EnumVariantFields::Named { fields: &[#(#fields),*] }) } } syn::Fields::Unnamed(unnamed_fields) => { let fields = unnamed_fields.unnamed.iter().enumerate().map( |(variant_field_index, variant_field)| { generate_field( variant_field, variant_field_index, &engine_crate_path, ReflectionFieldGenOptions { // enum variant fields are always public field_vis_override: Some(syn::Visibility::Public( ::default(), )), gen_get_byte_offset: &|field| { if let Some(_) = &field.ident { unreachable!() } else { let field_index = proc_macro2::Literal::usize_unsuffixed( variant_field_index, ); quote! { std::mem::offset_of!( #mod_name::Equivalent #type_generics, payload.#variant_ident ) + std::mem::offset_of!( #mod_name::#variant_fields_struct_ident #type_generics, #field_index ) } } }, type_reflection_optional: input_generics .params .is_empty(), }, ) }, ); quote! { Some(#engine_crate_path::reflection::EnumVariantFields::Unnamed { fields: &[#(#fields),*] }) } } }; let discriminant = match variant.fields { syn::Fields::Unit if is_unit_only => { quote! { let mut buf = [0u8; std::mem::size_of::()]; // Self or any of it's type parameters might have a Drop impl and // since that can not be const evaluated, the discriminant value has // to be wrapped in ManuallyDrop let discriminant = std::mem::ManuallyDrop::new(Self::#variant_ident); unsafe { std::ptr::copy_nonoverlapping( (&raw const discriminant).cast::(), buf.as_mut_ptr(), std::mem::size_of::(), ); } buf } } _ => { quote! { let mut buf = [0u8; std::mem::size_of::()]; let discriminant = #mod_name::Discriminant::#variant_ident; unsafe { std::ptr::copy_nonoverlapping( (&raw const discriminant).cast::(), buf.as_mut_ptr(), std::mem::size_of::<#mod_name::Discriminant>(), ); } buf } } }; quote! { #engine_crate_path::reflection::EnumVariant { name: #variant_name, fields: #fields, discriminant: #engine_crate_path::reflection::EnumDiscriminant { buf: { #discriminant } } } } }) }