From bcfb025d8996b9e2d7d7386a3e372331ad11985f Mon Sep 17 00:00:00 2001 From: HampusM Date: Thu, 11 Jun 2026 19:32:17 +0200 Subject: fix(engine-macros): make generic enums able to derive Reflection --- engine-macros/src/reflection/enum_impl.rs | 67 +++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 16 deletions(-) (limited to 'engine-macros/src') diff --git a/engine-macros/src/reflection/enum_impl.rs b/engine-macros/src/reflection/enum_impl.rs index 7f7d25e..d124db2 100644 --- a/engine-macros/src/reflection/enum_impl.rs +++ b/engine-macros/src/reflection/enum_impl.rs @@ -1,4 +1,4 @@ -use quote::{format_ident, quote}; +use quote::{format_ident, quote, ToTokens}; use crate::reflection::field::{generate as generate_field, ReflectionFieldGenOptions}; use crate::util::find_engine_crate_path; @@ -93,18 +93,45 @@ pub fn generate(input: syn::ItemEnum) -> proc_macro2::TokenStream let ident = format_ident!("VariantFields{}", variant.ident); + let generics_phantom_data_elems = + input.generics.params.iter().filter_map(|generic_param| { + match generic_param { + syn::GenericParam::Type(type_param) => { + Some(type_param.ident.clone().into_token_stream()) + } + syn::GenericParam::Lifetime(_) => { + unimplemented!(); + } + syn::GenericParam::Const(_const_param) => None, + } + }); + + let generics_phantom_data = quote! { + std::marker::PhantomData<(#(#generics_phantom_data_elems),*)> + }; + match variant.fields { syn::Fields::Named(_) => quote! { #[repr(C)] - pub struct #ident { #(#fields),* } + pub struct #ident #impl_generics + #where_clause + { + #(#fields,)* + _pd: #generics_phantom_data + } }, syn::Fields::Unnamed(_) => quote! { #[repr(C)] - pub struct #ident ( #(#fields),* ); + pub struct #ident #impl_generics ( + #(#fields,)* + #generics_phantom_data + ) + #where_clause; }, syn::Fields::Unit => quote! { #[repr(C)] - pub struct #ident; + pub struct #ident #impl_generics (#generics_phantom_data) + #where_clause; }, } }); @@ -116,7 +143,8 @@ pub fn generate(input: syn::ItemEnum) -> proc_macro2::TokenStream format_ident!("VariantFields{}", variant.ident); quote! { - pub #variant_ident: std::mem::ManuallyDrop<#variant_fields_struct_ident> + pub #variant_ident: + std::mem::ManuallyDrop<#variant_fields_struct_ident #type_generics> } }); @@ -145,10 +173,11 @@ pub fn generate(input: syn::ItemEnum) -> proc_macro2::TokenStream use super::*; #[repr(C)] - pub struct Equivalent + pub struct Equivalent #impl_generics + #where_clause { pub tag: Discriminant, - pub payload: Fields + pub payload: Fields #type_generics } #[repr(#discriminant_enum_repr)] @@ -158,7 +187,8 @@ pub fn generate(input: syn::ItemEnum) -> proc_macro2::TokenStream } #[repr(C)] - pub union Fields + pub union Fields #impl_generics + #where_clause { #(#fields_union_fields),* } @@ -216,6 +246,8 @@ fn generate_variants<'a>( { let mod_name = format_ident!("__engine_private_{}", input.ident); + let (_, type_generics, _) = input.generics.split_for_impl(); + input.variants.iter().map(move |variant| { let variant_name = syn::LitStr::new(&variant.ident.to_string(), variant.ident.span()); @@ -242,11 +274,12 @@ fn generate_variants<'a>( if let Some(field_ident) = &field.ident { quote! { std::mem::offset_of!( - #mod_name::Equivalent, + #mod_name::Equivalent #type_generics, payload.#variant_ident ) + std::mem::offset_of!( - #mod_name::#variant_fields_struct_ident, + #mod_name::#variant_fields_struct_ident + #type_generics, #field_ident ) } @@ -288,11 +321,12 @@ fn generate_variants<'a>( quote! { std::mem::offset_of!( - #mod_name::Equivalent, + #mod_name::Equivalent #type_generics, payload.#variant_ident ) + std::mem::offset_of!( - #mod_name::#variant_fields_struct_ident, + #mod_name::#variant_fields_struct_ident + #type_generics, #field_index ) } @@ -311,20 +345,21 @@ fn generate_variants<'a>( } }; - let enum_ident = &input.ident; - let discriminant = match variant.fields { syn::Fields::Unit => { quote! { let mut buf = [0u8; std::mem::size_of::()]; - let discriminant = #enum_ident::#variant_ident; + // Self or any of it's type parameters might have a Drop impl and + // since that can not be const evaluated, the discriminant value has + // to be wrapped in ManuallyDrop + let discriminant = std::mem::ManuallyDrop::new(Self::#variant_ident); unsafe { std::ptr::copy_nonoverlapping( (&raw const discriminant).cast::(), buf.as_mut_ptr(), - std::mem::size_of::<#enum_ident>(), + std::mem::size_of::(), ); } -- cgit v1.2.3-18-g5258