diff options
Diffstat (limited to 'macros/src/injectable_impl.rs')
-rw-r--r-- | macros/src/injectable_impl.rs | 79 |
1 files changed, 33 insertions, 46 deletions
diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 89346e8..f510407 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -2,16 +2,17 @@ use quote::{quote, ToTokens}; use syn::parse::{Parse, ParseStream}; use syn::Generics; use syn::{ - parse_str, punctuated::Punctuated, token::Comma, ExprMethodCall, FnArg, - GenericArgument, Ident, ImplItem, ImplItemMethod, ItemImpl, Path, PathArguments, - Type, TypePath, + parse_str, punctuated::Punctuated, token::Comma, ExprMethodCall, FnArg, Ident, + ImplItem, ImplItemMethod, ItemImpl, Path, Type, TypePath, }; +use crate::dependency_type::DependencyType; + const DI_CONTAINER_VAR_NAME: &str = "di_container"; pub struct InjectableImpl { - pub dependency_types: Vec<Type>, + pub dependency_types: Vec<DependencyType>, pub self_type: Type, pub generics: Generics, pub original_impl: ItemImpl, @@ -21,20 +22,17 @@ impl Parse for InjectableImpl { fn parse(input: ParseStream) -> syn::Result<Self> { - match input.parse::<ItemImpl>() { - Ok(impl_parsed_input) => { - match Self::_get_dependency_types(&impl_parsed_input) { - Ok(dependency_types) => Ok(Self { - dependency_types, - self_type: impl_parsed_input.self_ty.as_ref().clone(), - generics: impl_parsed_input.generics.clone(), - original_impl: impl_parsed_input, - }), - Err(error_msg) => Err(input.error(error_msg)), - } - } - Err(_) => Err(input.error("Expected an impl")), - } + let impl_parsed_input = input.parse::<ItemImpl>()?; + + let dependency_types = Self::_get_dependency_types(&impl_parsed_input) + .map_err(|err| input.error(err))?; + + Ok(Self { + dependency_types, + self_type: impl_parsed_input.self_ty.as_ref().clone(), + generics: impl_parsed_input.generics.clone(), + original_impl: impl_parsed_input, + }) } } @@ -81,16 +79,23 @@ impl InjectableImpl } } - fn _create_get_dependencies(dependency_types: &[Type]) -> Vec<ExprMethodCall> + fn _create_get_dependencies( + dependency_types: &[DependencyType], + ) -> Vec<ExprMethodCall> { dependency_types .iter() - .filter_map(|dep_type| match dep_type { + .filter_map(|dep_type| match &dep_type.interface { Type::TraitObject(dep_type_trait) => Some( parse_str( format!( - "{}.get::<{}>()", + "{}.get{}::<{}>()", DI_CONTAINER_VAR_NAME, + if dep_type.ptr == "SingletonPtr" { + "_singleton" + } else { + "" + }, dep_type_trait.to_token_stream() ) .as_str(), @@ -194,12 +199,17 @@ impl InjectableImpl arg_type_path_string == "TransientPtr" || arg_type_path_string == "ptr::TransientPtr" || arg_type_path_string == "syrrete::ptr::TransientPtr" + || arg_type_path_string == "SingletonPtr" + || arg_type_path_string == "ptr::SingletonPtr" + || arg_type_path_string == "syrrete::ptr::SingletonPtr" || arg_type_path_string == "FactoryPtr" || arg_type_path_string == "ptr::FactoryPtr" || arg_type_path_string == "syrrete::ptr::FactoryPtr" } - fn _get_dependency_types(item_impl: &ItemImpl) -> Result<Vec<Type>, &'static str> + fn _get_dependency_types( + item_impl: &ItemImpl, + ) -> Result<Vec<DependencyType>, &'static str> { let new_method_impl_item = match Self::_find_method_by_name(item_impl, "new") { Some(method_item) => Ok(method_item), @@ -223,30 +233,7 @@ impl InjectableImpl Ok(new_method_arg_type_paths .iter() - .filter_map(|arg_type_path| { - // Assume the type path has a last segment. - let last_path_segment = arg_type_path.path.segments.last().unwrap(); - - match &last_path_segment.arguments { - PathArguments::AngleBracketed(angle_bracketed_generic_args) => { - let generic_args = &angle_bracketed_generic_args.args; - - let opt_first_generic_arg = generic_args.first(); - - // Assume a first generic argument exists because TransientPtr and - // FactoryPtr requires one - let first_generic_arg = opt_first_generic_arg.as_ref().unwrap(); - - match first_generic_arg { - GenericArgument::Type(first_generic_arg_type) => { - Some(first_generic_arg_type.clone()) - } - &_ => None, - } - } - &_ => None, - } - }) + .filter_map(|arg_type_path| DependencyType::from_type_path(arg_type_path)) .collect()) } } |