diff options
| author | HampusM <hampus@hampusmat.com> | 2022-07-16 12:02:54 +0200 | 
|---|---|---|
| committer | HampusM <hampus@hampusmat.com> | 2022-07-16 12:02:54 +0200 | 
| commit | 05be92b334af1beab3e7a3f2ee7626eb26c47e22 (patch) | |
| tree | 43883a89985b29721961f2001a88db9985bd3485 /syrette_macros/src | |
| parent | 5129384fc0b6f51d315fd528d7769dd638018b88 (diff) | |
feat: add binding factories to DI container
Diffstat (limited to 'syrette_macros/src')
| -rw-r--r-- | syrette_macros/src/lib.rs | 149 | 
1 files changed, 138 insertions, 11 deletions
diff --git a/syrette_macros/src/lib.rs b/syrette_macros/src/lib.rs index 1761534..0302c07 100644 --- a/syrette_macros/src/lib.rs +++ b/syrette_macros/src/lib.rs @@ -2,8 +2,8 @@ use proc_macro::TokenStream;  use quote::{quote, ToTokens};  use syn::{      parse, parse_macro_input, parse_str, punctuated::Punctuated, token::Comma, -    AttributeArgs, FnArg, GenericArgument, ImplItem, ItemImpl, Meta, NestedMeta, Path, -    PathArguments, Type, TypePath, +    AttributeArgs, ExprMethodCall, FnArg, GenericArgument, ImplItem, ItemImpl, ItemType, +    Meta, NestedMeta, Path, PathArguments, Type, TypeParamBound, TypePath,  };  mod libs; @@ -27,8 +27,16 @@ const IMPL_NO_NEW_METHOD_ERR_MESSAGE: &str =  const IMPL_NEW_METHOD_SELF_PARAM_ERR_MESSAGE: &str =      "The new method of the attached to trait implementation cannot have a self parameter"; -const IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE: &str = -    "All parameters of the new method of the attached to trait implementation must be std::boxed::Box"; +const IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE: &str = concat!( +    "All parameters of the new method of the attached to trait implementation ", +    "must be either std::boxed::Box or std::rc:Rc (for factories)" +); + +const INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE: &str = +    "Invalid aliased trait. Must be 'dyn IFactory'"; + +const INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE: &str = +    "Invalid arguments for 'dyn IFactory'";  fn path_to_string(path: &Path) -> String  { @@ -67,6 +75,10 @@ fn get_fn_arg_type_paths(fn_args: &Punctuated<FnArg, Comma>) -> Vec<TypePath>          match arg {              FnArg::Typed(typed_fn_arg) => match typed_fn_arg.ty.as_ref() {                  Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), +                Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { +                    Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), +                    &_ => {} +                },                  &_ => {}              },              FnArg::Receiver(_receiver_fn_arg) => {} @@ -109,8 +121,11 @@ fn get_dependency_types(item_impl: &ItemImpl) -> Vec<Type>              if arg_type_path_string != "Box"                  && arg_type_path_string != "std::boxed::Box"                  && arg_type_path_string != "boxed::Box" +                && arg_type_path_string != "Rc" +                && arg_type_path_string != "std::rc::Rc" +                && arg_type_path_string != "rc::Rc"              { -                panic!("{}", IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE); +                panic!("{}", IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE);              }              // Assume the type path has a last segment. @@ -202,6 +217,49 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt      let dependency_types = get_dependency_types(&item_impl); +    let get_dependencies = dependency_types.iter().fold( +        Vec::<ExprMethodCall>::new(), +        |mut acc, dep_type| { +            match dep_type { +                Type::TraitObject(dep_type_trait) => { +                    acc.push( +                        parse_str( +                            format!( +                                "di_container.get::<{}>()", +                                dep_type_trait.to_token_stream() +                            ) +                            .as_str(), +                        ) +                        .unwrap(), +                    ); +                } +                Type::Path(dep_type_path) => { +                    let dep_type_path_str = path_to_string(&dep_type_path.path); + +                    let get_method_name = if dep_type_path_str.ends_with("Factory") { +                        "get_factory" +                    } else { +                        "get" +                    }; + +                    acc.push( +                        parse_str( +                            format!( +                                "di_container.{}::<{}>()", +                                get_method_name, dep_type_path_str +                            ) +                            .as_str(), +                        ) +                        .unwrap(), +                    ); +                } +                &_ => {} +            } + +            acc +        }, +    ); +      quote! {          #item_impl @@ -213,13 +271,15 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt                  use error_stack::ResultExt;                  return Ok(Box::new(Self::new( -                    #(di_container.get::<#dependency_types>() +                    #(#get_dependencies                          .change_context(syrette::errors::injectable::ResolveError) -                        .attach_printable(format!( -                            "Unable to resolve a dependency of {}", -                            std::any::type_name::<#self_type_path>() -                        ))?, -                    )* +                        .attach_printable( +                            format!( +                                "Unable to resolve a dependency of {}", +                                std::any::type_name::<#self_type_path>() +                            ) +                        )? +                    ),*                  )));              }          } @@ -229,6 +289,73 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt      .into()  } +#[proc_macro_attribute] +pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream +{ +    let type_alias: ItemType = parse(type_alias_stream).unwrap(); + +    let aliased_trait = match &type_alias.ty.as_ref() { +        Type::TraitObject(alias_type) => alias_type, +        &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), +    }; + +    if aliased_trait.bounds.len() != 1 { +        panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); +    } + +    let type_bound = aliased_trait.bounds.first().unwrap(); + +    let trait_bound = match type_bound { +        TypeParamBound::Trait(trait_bound) => trait_bound, +        &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), +    }; + +    let trait_bound_path = &trait_bound.path; + +    if trait_bound_path.segments.is_empty() +        || trait_bound_path.segments.last().unwrap().ident != "IFactory" +    { +        panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); +    } + +    let factory_path_segment = trait_bound_path.segments.last().unwrap(); + +    let factory_path_segment_args = &match &factory_path_segment.arguments { +        syn::PathArguments::AngleBracketed(args) => args, +        &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), +    } +    .args; + +    let factory_arg_types_type = match &factory_path_segment_args[0] { +        GenericArgument::Type(arg_type) => arg_type, +        &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), +    }; + +    let factory_return_type = match &factory_path_segment_args[1] { +        GenericArgument::Type(arg_type) => arg_type, +        &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), +    }; + +    quote! { +        #type_alias + +        syrette::castable_to!( +            syrette::castable_factory::CastableFactory< +                #factory_arg_types_type, +                #factory_return_type +            > => #trait_bound_path +        ); + +        syrette::castable_to!( +            syrette::castable_factory::CastableFactory< +                #factory_arg_types_type, +                #factory_return_type +            > => syrette::castable_factory::AnyFactory +        ); +    } +    .into() +} +  #[doc(hidden)]  #[proc_macro]  pub fn castable_to(input: TokenStream) -> TokenStream  | 
