#![deny(clippy::all, clippy::pedantic)] use std::fmt::Write; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse, Field as SynField, Fields, Item, ItemEnum, ItemStruct, LitStr, Path as SynPath, Visibility as SynVisibility, }; macro_rules! syn_path { ($first_segment: ident $(::$segment: ident)*) => { ::syn::Path { leading_colon: None, segments: ::syn::punctuated::Punctuated::from_iter([ syn_path_segment!($first_segment), $(syn_path_segment!($segment),)* ]) } }; } macro_rules! syn_path_segment { ($segment: ident) => { ::syn::PathSegment { ident: ::proc_macro2::Ident::new( stringify!($segment), ::proc_macro2::Span::call_site(), ), arguments: ::syn::PathArguments::None, } }; } #[proc_macro_derive(Reflection)] pub fn reflection_derive(input: TokenStream) -> TokenStream { let input = parse::(input).unwrap(); match input { Item::Struct(input) => generate_struct_reflection_impl(input), Item::Enum(input) => generate_enum_reflection_impl(input), _ => panic!("Invalid input"), } } fn generate_struct_reflection_impl(input: ItemStruct) -> TokenStream { let engine_crate_path = find_engine_crate_path().unwrap(); let input_ident = input.ident; let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); let fields = input .fields .into_iter() .enumerate() .map(|(field_index, field)| { gen_reflection_field( &field, field_index, &engine_crate_path, ReflectionFieldGenOptions { field_vis_override: None, include_byte_offset: true, }, ) }); quote! { unsafe impl #impl_generics #engine_crate_path::reflection::Reflection for #input_ident #type_generics #where_clause { const TYPE_REFLECTION: &#engine_crate_path::reflection::Type = &const { #engine_crate_path::reflection::Type::Struct( #engine_crate_path::reflection::Struct { fields: &[ #(#fields),* ] } ) }; } } .into() } fn generate_enum_reflection_impl(input: ItemEnum) -> TokenStream { let engine_crate_path = find_engine_crate_path().unwrap(); let input_ident = input.ident; let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); let variants = input.variants.iter().map(|variant| { let variant_name = LitStr::new(&variant.ident.to_string(), variant.ident.span()); let fields = match &variant.fields { Fields::Unit => quote! { None }, Fields::Named(named_fields) => { let fields = named_fields.named.iter().enumerate().map( |(variant_field_index, variant_field)| { gen_reflection_field( variant_field, variant_field_index, &engine_crate_path, ReflectionFieldGenOptions { field_vis_override: None, include_byte_offset: false, }, ) }, ); quote! { Some(#engine_crate_path::reflection::EnumVariantFields::Named { fields: &[#(#fields),*] }) } } Fields::Unnamed(unnamed_fields) => { let fields = unnamed_fields.unnamed.iter().enumerate().map( |(variant_field_index, variant_field)| { gen_reflection_field( variant_field, variant_field_index, &engine_crate_path, ReflectionFieldGenOptions { field_vis_override: None, include_byte_offset: false, }, ) }, ); quote! { Some(#engine_crate_path::reflection::EnumVariantFields::Unnamed { fields: &[#(#fields),*] }) } } }; quote! { #engine_crate_path::reflection::EnumVariant { name: #variant_name, fields: #fields } } }); let variant_lookup_match_arms = input.variants.iter().enumerate().map(|(index, variant)| { let variant_ident = &variant.ident; let pattern = match variant.fields { Fields::Unit => quote! { Self::#variant_ident }, Fields::Named(_) => quote! { Self::#variant_ident { .. } }, Fields::Unnamed(_) => quote! { Self::#variant_ident(..) }, }; quote! { #pattern => &enum_reflection.variants[#index] } }); let is_unit_only = input .variants .iter() .all(|variant| matches!(variant.fields, Fields::Unit)); quote! { unsafe impl #impl_generics #engine_crate_path::reflection::Reflection for #input_ident #type_generics #where_clause { const TYPE_REFLECTION: &#engine_crate_path::reflection::Type = &const { #engine_crate_path::reflection::Type::Enum( #engine_crate_path::reflection::Enum { variants: &[#(#variants),*], is_unit_only: #is_unit_only } ) }; } unsafe impl #impl_generics #engine_crate_path::reflection::EnumReflectionExt for #input_ident #type_generics #where_clause { fn get_variant_reflection(&self) -> &'static #engine_crate_path::reflection::EnumVariant { let enum_reflection = unsafe { ::TYPE_REFLECTION .as_enum() .unwrap_unchecked() }; match self { #(#variant_lookup_match_arms),* } } } } .into() } struct ReflectionFieldGenOptions { field_vis_override: Option, include_byte_offset: bool, } fn gen_reflection_field( field: &SynField, field_index: usize, engine_crate_path: &SynPath, options: ReflectionFieldGenOptions, ) -> proc_macro2::TokenStream { let field_ident = &field.ident; let field_type = &field.ty; let field_name = if let Some(field_ident) = field_ident { let field_name = LitStr::new(&field_ident.to_string(), field_ident.span()); quote! { Some(#field_name) } } else { quote! { None } }; let field_byte_offset = if options.include_byte_offset { if let Some(field_ident) = field_ident { quote! { Some(std::mem::offset_of!(Self, #field_ident)) } } else { quote! { Some(std::mem::offset_of!(Self, #field_index)) } } } else { quote! { None } }; // since std::any::type_name as const is not stable yet let field_type_name = field_type.to_token_stream().to_string(); let field_vis = options.field_vis_override.as_ref().unwrap_or(&field.vis); let field_reflection_vis = gen_reflection_visibility_path(field_vis, &engine_crate_path); quote! { #engine_crate_path::reflection::Field { name: #field_name, index: #field_index, layout: std::alloc::Layout::new::<#field_type>(), byte_offset: #field_byte_offset, type_id: std::any::TypeId::of::<#field_type>(), type_name: #field_type_name, get_type: #engine_crate_path::reflection::FnWithDebug::new(|| { struct SpecializationTarget(std::marker::PhantomData); trait FieldHasReflection { fn field_type_reflection(&self) -> Option<&'static #engine_crate_path::reflection::Type>; } trait FieldDoesNotHaveReflection { fn field_type_reflection(&self) -> Option<&'static #engine_crate_path::reflection::Type>; } impl FieldDoesNotHaveReflection for &SpecializationTarget { fn field_type_reflection(&self) -> Option<&'static #engine_crate_path::reflection::Type> { None } } impl FieldHasReflection for SpecializationTarget where Field: #engine_crate_path::reflection::Reflection { fn field_type_reflection(&self) -> Option<&'static #engine_crate_path::reflection::Type> { Some(Field::type_reflection()) } } (&SpecializationTarget::<#field_type>(std::marker::PhantomData)) .field_type_reflection() }), visibility: #field_reflection_vis } } } fn gen_reflection_visibility_path( visibility: &SynVisibility, engine_crate_path: &SynPath, ) -> proc_macro2::TokenStream { match visibility { SynVisibility::Public(_) => { quote! { #engine_crate_path::reflection::Visibility::Pub } } SynVisibility::Restricted(vis_restricted) => { let vis_scope = if vis_restricted.in_token.is_some() { let in_path = syn_path_to_string(&vis_restricted.path); quote! { #engine_crate_path::reflection::VisibilityScope::In( std::borrow::Cow::Borrowed(#in_path) ) } } else { let Some(scope) = vis_restricted.path.get_ident() else { unreachable!(); }; if scope == "crate" { quote! { #engine_crate_path::reflection::VisibilityScope::Crate } } else if scope == "super" { quote! { #engine_crate_path::reflection::VisibilityScope::Super } } else if scope == "self" { quote! { #engine_crate_path::reflection::VisibilityScope::SelfModule } } else { unreachable!(); } }; quote! { #engine_crate_path::reflection::Visibility::PubScoped(#vis_scope) } } SynVisibility::Inherited => { quote! { #engine_crate_path::reflection::Visibility::Private } } } } fn find_engine_crate_path() -> Option { let cargo_crate_name = std::env::var("CARGO_CRATE_NAME").ok()?; let cargo_pkg_name = std::env::var("CARGO_PKG_NAME").ok()?; if cargo_pkg_name == "engine" && cargo_crate_name != "engine" { // Macro is used by a crate example/test/benchmark return Some(syn_path!(engine)); } if cargo_crate_name == "engine" { return Some(syn_path!(crate)); } Some(syn_path!(engine)) } fn syn_path_to_string(path: &syn::Path) -> String { let mut output = String::with_capacity(2 + path.segments.len() * 8); if let Some(leading_colon) = path.leading_colon { write!(output, "{}", leading_colon.to_token_stream()).unwrap(); } for (segment, punct) in path.segments.pairs().map(syn::punctuated::Pair::into_tuple) { let segment_ident = &segment.ident; write!(output, "{segment_ident}",).unwrap(); if let Some(punct) = punct { write!(output, "{}", punct.to_token_stream()).unwrap(); } } output }