diff options
Diffstat (limited to 'engine-macros/src/lib.rs')
| -rw-r--r-- | engine-macros/src/lib.rs | 174 |
1 files changed, 174 insertions, 0 deletions
diff --git a/engine-macros/src/lib.rs b/engine-macros/src/lib.rs new file mode 100644 index 0000000..ad6c15f --- /dev/null +++ b/engine-macros/src/lib.rs @@ -0,0 +1,174 @@ +#![deny(clippy::all, clippy::pedantic)] + +use proc_macro::TokenStream; +use quote::{ToTokens, format_ident, quote}; +use syn::punctuated::Punctuated; +use syn::{ + Item, + LitStr, + Path as SynPath, + PredicateType, + Token, + TraitBound, + TypeParamBound, + WhereClause, + WherePredicate, + parse, +}; + +macro_rules! syn_path { + ($first_segment: ident $(::$segment: ident)*) => { + ::syn::Path { + leading_colon: None, + segments: ::syn::punctuated::Punctuated::from_iter([ + syn_path_segment!($first_segment), + $(syn_path_segment!($segment),)* + ]) + } + }; +} + +macro_rules! syn_path_segment { + ($segment: ident) => { + ::syn::PathSegment { + ident: ::proc_macro2::Ident::new( + stringify!($segment), + ::proc_macro2::Span::call_site(), + ), + arguments: ::syn::PathArguments::None, + } + }; +} + +#[proc_macro_derive(Reflection)] +pub fn reflection_derive(input: TokenStream) -> TokenStream +{ + let input = parse::<Item>(input).unwrap(); + + let input = match input { + Item::Struct(input) => input, + Item::Enum(_) => unimplemented!(), + _ => panic!("Invalid input"), + }; + + let engine_crate_path = find_engine_crate_path().unwrap(); + + let input_ident = input.ident; + + let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); + + let mut where_clause = where_clause.cloned().unwrap_or_else(|| WhereClause { + where_token: <Token![where]>::default(), + predicates: Punctuated::new(), + }); + + where_clause + .predicates + .extend(input.fields.iter().map(|field| { + WherePredicate::Type(PredicateType { + lifetimes: None, + bounded_ty: field.ty.clone(), + colon_token: <Token![:]>::default(), + bounds: [TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: syn::TraitBoundModifier::None, + lifetimes: None, + path: engine_crate_path.join(syn_path!(reflection::With)), + })] + .into_iter() + .collect(), + }) + })); + + let fields = input.fields.into_iter().enumerate().map(|(index, field)| { + let field_ident = field.ident.unwrap_or_else(|| format_ident!("{index}")); + + let field_type = &field.ty; + + let field_name = LitStr::new(&field_ident.to_string(), field_ident.span()); + + // since std::any::type_name as const is not stable yet + let field_type_name = field_type.to_token_stream().to_string(); + + quote! { + #engine_crate_path::reflection::StructField { + name: #field_name, + index: #index, + layout: std::alloc::Layout::new::<#field_type>(), + byte_offset: std::mem::offset_of!(Self, #field_ident), + type_id: std::any::TypeId::of::<#field_type>(), + type_name: #field_type_name, + reflection: + #engine_crate_path::reflection::__private::get_type_reflection::< + #field_type + >() + } + } + }); + + quote! { + impl #impl_generics #engine_crate_path::reflection::With for + #input_ident #type_generics #where_clause + { + const REFLECTION: &#engine_crate_path::reflection::Reflection = + &const { + #engine_crate_path::reflection::Reflection::Struct( + #engine_crate_path::reflection::Struct { + fields: &[ + #(#fields),* + ] + } + ) + }; + + fn reflection() -> &'static #engine_crate_path::reflection::Reflection + { + Self::REFLECTION + } + + fn get_reflection(&self) -> &'static #engine_crate_path::reflection::Reflection + { + Self::reflection() + } + } + } + .into() +} + +fn find_engine_crate_path() -> Option<SynPath> +{ + let cargo_crate_name = std::env::var("CARGO_CRATE_NAME").ok()?; + let cargo_pkg_name = std::env::var("CARGO_PKG_NAME").ok()?; + + if cargo_pkg_name == "engine" && cargo_crate_name != "engine" { + // Macro is used by a crate example/test/benchmark + return Some(syn_path!(engine)); + } + + if cargo_crate_name == "engine" { + return Some(syn_path!(crate)); + } + + Some(syn_path!(engine)) +} + +trait SynPathExt +{ + fn join(&self, other: Self) -> Self; +} + +impl SynPathExt for SynPath +{ + fn join(&self, other: Self) -> Self + { + Self { + leading_colon: self.leading_colon.clone(), + segments: self + .segments + .iter() + .chain(&other.segments) + .cloned() + .collect(), + } + } +} |
