diff options
Diffstat (limited to 'syrette_macros/src')
-rw-r--r-- | syrette_macros/src/lib.rs | 149 |
1 files changed, 138 insertions, 11 deletions
diff --git a/syrette_macros/src/lib.rs b/syrette_macros/src/lib.rs index 1761534..0302c07 100644 --- a/syrette_macros/src/lib.rs +++ b/syrette_macros/src/lib.rs @@ -2,8 +2,8 @@ use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse, parse_macro_input, parse_str, punctuated::Punctuated, token::Comma, - AttributeArgs, FnArg, GenericArgument, ImplItem, ItemImpl, Meta, NestedMeta, Path, - PathArguments, Type, TypePath, + AttributeArgs, ExprMethodCall, FnArg, GenericArgument, ImplItem, ItemImpl, ItemType, + Meta, NestedMeta, Path, PathArguments, Type, TypeParamBound, TypePath, }; mod libs; @@ -27,8 +27,16 @@ const IMPL_NO_NEW_METHOD_ERR_MESSAGE: &str = const IMPL_NEW_METHOD_SELF_PARAM_ERR_MESSAGE: &str = "The new method of the attached to trait implementation cannot have a self parameter"; -const IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE: &str = - "All parameters of the new method of the attached to trait implementation must be std::boxed::Box"; +const IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE: &str = concat!( + "All parameters of the new method of the attached to trait implementation ", + "must be either std::boxed::Box or std::rc:Rc (for factories)" +); + +const INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE: &str = + "Invalid aliased trait. Must be 'dyn IFactory'"; + +const INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE: &str = + "Invalid arguments for 'dyn IFactory'"; fn path_to_string(path: &Path) -> String { @@ -67,6 +75,10 @@ fn get_fn_arg_type_paths(fn_args: &Punctuated<FnArg, Comma>) -> Vec<TypePath> match arg { FnArg::Typed(typed_fn_arg) => match typed_fn_arg.ty.as_ref() { Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), + Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { + Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), + &_ => {} + }, &_ => {} }, FnArg::Receiver(_receiver_fn_arg) => {} @@ -109,8 +121,11 @@ fn get_dependency_types(item_impl: &ItemImpl) -> Vec<Type> if arg_type_path_string != "Box" && arg_type_path_string != "std::boxed::Box" && arg_type_path_string != "boxed::Box" + && arg_type_path_string != "Rc" + && arg_type_path_string != "std::rc::Rc" + && arg_type_path_string != "rc::Rc" { - panic!("{}", IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE); + panic!("{}", IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE); } // Assume the type path has a last segment. @@ -202,6 +217,49 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt let dependency_types = get_dependency_types(&item_impl); + let get_dependencies = dependency_types.iter().fold( + Vec::<ExprMethodCall>::new(), + |mut acc, dep_type| { + match dep_type { + Type::TraitObject(dep_type_trait) => { + acc.push( + parse_str( + format!( + "di_container.get::<{}>()", + dep_type_trait.to_token_stream() + ) + .as_str(), + ) + .unwrap(), + ); + } + Type::Path(dep_type_path) => { + let dep_type_path_str = path_to_string(&dep_type_path.path); + + let get_method_name = if dep_type_path_str.ends_with("Factory") { + "get_factory" + } else { + "get" + }; + + acc.push( + parse_str( + format!( + "di_container.{}::<{}>()", + get_method_name, dep_type_path_str + ) + .as_str(), + ) + .unwrap(), + ); + } + &_ => {} + } + + acc + }, + ); + quote! { #item_impl @@ -213,13 +271,15 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt use error_stack::ResultExt; return Ok(Box::new(Self::new( - #(di_container.get::<#dependency_types>() + #(#get_dependencies .change_context(syrette::errors::injectable::ResolveError) - .attach_printable(format!( - "Unable to resolve a dependency of {}", - std::any::type_name::<#self_type_path>() - ))?, - )* + .attach_printable( + format!( + "Unable to resolve a dependency of {}", + std::any::type_name::<#self_type_path>() + ) + )? + ),* ))); } } @@ -229,6 +289,73 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt .into() } +#[proc_macro_attribute] +pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream +{ + let type_alias: ItemType = parse(type_alias_stream).unwrap(); + + let aliased_trait = match &type_alias.ty.as_ref() { + Type::TraitObject(alias_type) => alias_type, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), + }; + + if aliased_trait.bounds.len() != 1 { + panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); + } + + let type_bound = aliased_trait.bounds.first().unwrap(); + + let trait_bound = match type_bound { + TypeParamBound::Trait(trait_bound) => trait_bound, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), + }; + + let trait_bound_path = &trait_bound.path; + + if trait_bound_path.segments.is_empty() + || trait_bound_path.segments.last().unwrap().ident != "IFactory" + { + panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); + } + + let factory_path_segment = trait_bound_path.segments.last().unwrap(); + + let factory_path_segment_args = &match &factory_path_segment.arguments { + syn::PathArguments::AngleBracketed(args) => args, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), + } + .args; + + let factory_arg_types_type = match &factory_path_segment_args[0] { + GenericArgument::Type(arg_type) => arg_type, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), + }; + + let factory_return_type = match &factory_path_segment_args[1] { + GenericArgument::Type(arg_type) => arg_type, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), + }; + + quote! { + #type_alias + + syrette::castable_to!( + syrette::castable_factory::CastableFactory< + #factory_arg_types_type, + #factory_return_type + > => #trait_bound_path + ); + + syrette::castable_to!( + syrette::castable_factory::CastableFactory< + #factory_arg_types_type, + #factory_return_type + > => syrette::castable_factory::AnyFactory + ); + } + .into() +} + #[doc(hidden)] #[proc_macro] pub fn castable_to(input: TokenStream) -> TokenStream |