use quote::{format_ident, quote}; use crate::reflection::field::{generate as generate_field, ReflectionFieldGenOptions}; use crate::util::find_engine_crate_path; const PRIMITIVE_REPRS: &[&'static str] = &[ "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128", "isize", ]; pub fn generate(input: syn::ItemEnum) -> proc_macro2::TokenStream { let engine_crate_path = find_engine_crate_path().unwrap(); let input_ident = &input.ident; let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); let variant_lookup_match_arms = input.variants.iter().enumerate().map(|(index, variant)| { let variant_ident = &variant.ident; let pattern = match variant.fields { syn::Fields::Unit => quote! { Self::#variant_ident }, syn::Fields::Named(_) => quote! { Self::#variant_ident { .. } }, syn::Fields::Unnamed(_) => quote! { Self::#variant_ident(..) }, }; quote! { #pattern => &enum_reflection.variants[#index] } }); let is_unit_only = input .variants .iter() .all(|variant| matches!(variant.fields, syn::Fields::Unit)); let mod_name = format_ident!("__engine_private_{input_ident}"); let mut has_c_repr = false; let mut primitive_repr: Option = None; for attr in &input.attrs { let syn::Meta::List(attr_meta) = &attr.meta else { continue; }; if !attr_meta.path.is_ident("repr") { continue; } attr_meta .parse_nested_meta(|nested_meta| { let Some(meta_ident) = nested_meta.path.get_ident() else { return Ok(()); }; if meta_ident == "C" { has_c_repr = true; } else if PRIMITIVE_REPRS.contains(&meta_ident.to_string().as_str()) { primitive_repr = Some(meta_ident.clone()); } else { return Err(nested_meta.error(concat!( "Unsupported representation. ", "Only C and primitive representations are allowed" ))); } Ok(()) }) .unwrap(); } if !has_c_repr && primitive_repr.is_none() { panic!("Enums must have a C or primitive representation to derive Reflection"); } let mod_content = if is_unit_only { quote! {} } else { let variant_fields_structs = input.variants.iter().map(|variant| { let fields = variant.fields.iter().map(|field| { let field_type = &field.ty; if let Some(field_ident) = &field.ident { quote! { pub #field_ident: #field_type } } else { quote! { pub #field_type } } }); let ident = format_ident!("VariantFields{}", variant.ident); match variant.fields { syn::Fields::Named(_) => quote! { #[repr(C)] pub struct #ident { #(#fields),* } }, syn::Fields::Unnamed(_) => quote! { #[repr(C)] pub struct #ident ( #(#fields),* ); }, syn::Fields::Unit => quote! { #[repr(C)] pub struct #ident; }, } }); let fields_union_fields = input.variants.iter().map(|variant| { let variant_ident = &variant.ident; let variant_fields_struct_ident = format_ident!("VariantFields{}", variant.ident); quote! { pub #variant_ident: std::mem::ManuallyDrop<#variant_fields_struct_ident> } }); let discriminant_enum_variants = input.variants.iter().map(|variant| { let variant_ident = &variant.ident; if let Some((_, discriminant)) = &variant.discriminant { quote! { #variant_ident = #discriminant } } else { quote! { #variant_ident } } }); // If the enum has both a C & primitive repr, the primitive repr must be used let discriminant_enum_repr = if let Some(primitive_repr) = primitive_repr { quote! { #primitive_repr } } else if has_c_repr { quote! { C } } else { unreachable!(); }; quote! { #![allow(non_snake_case, dead_code)] use super::*; #[repr(C)] pub struct Equivalent { pub tag: Discriminant, pub payload: Fields } #[repr(#discriminant_enum_repr)] pub enum Discriminant { #(#discriminant_enum_variants),* } #[repr(C)] pub union Fields { #(#fields_union_fields),* } #(#variant_fields_structs)* } }; let variants = generate_variants(&input, &engine_crate_path); quote! { #[doc(hidden)] mod #mod_name { #mod_content } unsafe impl #impl_generics #engine_crate_path::reflection::Reflection for #input_ident #type_generics #where_clause { const TYPE_REFLECTION: &#engine_crate_path::reflection::Type = &const { #engine_crate_path::reflection::Type::Enum( #engine_crate_path::reflection::Enum { variants: &[#(#variants),*], is_unit_only: #is_unit_only } ) }; } unsafe impl #impl_generics #engine_crate_path::reflection::EnumReflectionExt for #input_ident #type_generics #where_clause { fn get_variant_reflection(&self) -> &'static #engine_crate_path::reflection::EnumVariant { let enum_reflection = unsafe { ::TYPE_REFLECTION .as_enum() .unwrap_unchecked() }; match self { #(#variant_lookup_match_arms),* } } } } } fn generate_variants<'a>( input: &'a syn::ItemEnum, engine_crate_path: &'a syn::Path, ) -> impl Iterator + use<'a> { let mod_name = format_ident!("__engine_private_{}", input.ident); input.variants.iter().map(move |variant| { let variant_name = syn::LitStr::new(&variant.ident.to_string(), variant.ident.span()); let variant_ident = &variant.ident; let variant_fields_struct_ident = format_ident!("VariantFields{}", variant.ident); let fields = match &variant.fields { syn::Fields::Unit => quote! { None }, syn::Fields::Named(named_fields) => { let fields = named_fields.named.iter().enumerate().map( |(variant_field_index, variant_field)| { generate_field( variant_field, variant_field_index, &engine_crate_path, ReflectionFieldGenOptions { // enum variant fields are always public field_vis_override: Some(syn::Visibility::Public( ::default(), )), gen_get_byte_offset: &|field| { if let Some(field_ident) = &field.ident { quote! { std::mem::offset_of!( #mod_name::Equivalent, payload.#variant_ident ) + std::mem::offset_of!( #mod_name::#variant_fields_struct_ident, #field_ident ) } } else { unreachable!(); } }, }, ) }, ); quote! { Some(#engine_crate_path::reflection::EnumVariantFields::Named { fields: &[#(#fields),*] }) } } syn::Fields::Unnamed(unnamed_fields) => { let fields = unnamed_fields.unnamed.iter().enumerate().map( |(variant_field_index, variant_field)| { generate_field( variant_field, variant_field_index, &engine_crate_path, ReflectionFieldGenOptions { // enum variant fields are always public field_vis_override: Some(syn::Visibility::Public( ::default(), )), gen_get_byte_offset: &|field| { if let Some(_) = &field.ident { unreachable!() } else { let field_index = proc_macro2::Literal::usize_unsuffixed( variant_field_index, ); quote! { std::mem::offset_of!( #mod_name::Equivalent, payload.#variant_ident ) + std::mem::offset_of!( #mod_name::#variant_fields_struct_ident, #field_index ) } } }, }, ) }, ); quote! { Some(#engine_crate_path::reflection::EnumVariantFields::Unnamed { fields: &[#(#fields),*] }) } } }; let enum_ident = &input.ident; let discriminant = match variant.fields { syn::Fields::Unit => { quote! { let mut buf = [0u8; std::mem::size_of::()]; let discriminant = #enum_ident::#variant_ident; unsafe { std::ptr::copy_nonoverlapping( (&raw const discriminant).cast::(), buf.as_mut_ptr(), std::mem::size_of::<#enum_ident>(), ); } buf } } _ => { quote! { let mut buf = [0u8; std::mem::size_of::()]; let discriminant = #mod_name::Discriminant::#variant_ident; unsafe { std::ptr::copy_nonoverlapping( (&raw const discriminant).cast::(), buf.as_mut_ptr(), std::mem::size_of::<#mod_name::Discriminant>(), ); } buf } } }; quote! { #engine_crate_path::reflection::EnumVariant { name: #variant_name, fields: #fields, discriminant: #engine_crate_path::reflection::EnumDiscriminant { buf: { #discriminant } } } } }) }