From c4eccc81d9bfa472197a4f302df1c967081a0be5 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 20 Aug 2022 15:11:54 +0200 Subject: fix: make DI container get_factory calls in the injectable macro valid --- macros/src/injectable_impl.rs | 89 +++++++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 38 deletions(-) (limited to 'macros/src') diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 29e0094..d125f05 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -1,14 +1,15 @@ -use quote::{quote, ToTokens}; +use quote::{format_ident, quote, ToTokens}; use syn::parse::{Parse, ParseStream}; use syn::Generics; use syn::{ - parse_str, punctuated::Punctuated, token::Comma, ExprMethodCall, FnArg, Ident, - ImplItem, ImplItemMethod, ItemImpl, Path, Type, TypePath, + parse_str, punctuated::Punctuated, token::Comma, ExprMethodCall, FnArg, ImplItem, + ImplItemMethod, ItemImpl, Path, Type, TypePath, }; use crate::dependency_type::DependencyType; const DI_CONTAINER_VAR_NAME: &str = "di_container"; +const DEPENDENCY_HISTORY_VAR_NAME: &str = "dependency_history"; pub struct InjectableImpl { @@ -47,16 +48,10 @@ impl InjectableImpl original_impl, } = self; - let di_container_var: Ident = parse_str(DI_CONTAINER_VAR_NAME).unwrap(); + let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); + let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); - let get_dep_method_names = Self::_create_get_dep_method_names(dependency_types); - - let get_dependencies = get_dep_method_names.iter().map(|get_dep_method_name| { - parse_str::( - format!("{}(dependency_history.clone())", get_dep_method_name).as_str(), - ) - .unwrap() - }); + let get_dep_method_calls = Self::_create_get_dep_method_calls(dependency_types); let maybe_doc_hidden = if no_doc_hidden { quote! {} @@ -68,22 +63,22 @@ impl InjectableImpl let maybe_prevent_circular_deps = if cfg!(feature = "prevent-circular") { quote! { - if dependency_history.contains(&self_type_name) { - dependency_history.push(self_type_name); + if #dependency_history_var.contains(&self_type_name) { + #dependency_history_var.push(self_type_name); return Err( report!(ResolveError) .attach_printable(format!( "Detected circular dependencies. {}", syrette::dependency_trace::create_dependency_trace( - dependency_history.as_slice(), + #dependency_history_var.as_slice(), self_type_name ) )) ); } - dependency_history.push(self_type_name); + #dependency_history_var.push(self_type_name); } } else { quote! {} @@ -96,7 +91,7 @@ impl InjectableImpl impl #generics syrette::interfaces::injectable::Injectable for #self_type { fn resolve( #di_container_var: &syrette::DIContainer, - mut dependency_history: Vec<&'static str>, + mut #dependency_history_var: Vec<&'static str>, ) -> syrette::libs::error_stack::Result< syrette::ptr::TransientPtr, syrette::errors::injectable::ResolveError> @@ -109,7 +104,7 @@ impl InjectableImpl #maybe_prevent_circular_deps return Ok(syrette::ptr::TransientPtr::new(Self::new( - #(#get_dependencies + #(#get_dep_method_calls .change_context(ResolveError) .attach_printable( format!( @@ -124,34 +119,52 @@ impl InjectableImpl } } - fn _create_get_dep_method_names(dependency_types: &[DependencyType]) -> Vec + fn _create_get_dep_method_calls( + dependency_types: &[DependencyType], + ) -> Vec { dependency_types .iter() .filter_map(|dep_type| match &dep_type.interface { - Type::TraitObject(dep_type_trait) => Some(format!( - "{}.get_{}::<{}>", - DI_CONTAINER_VAR_NAME, - if dep_type.ptr == "SingletonPtr" { - "singleton_with_history" - } else { - "with_history" - }, - dep_type_trait.to_token_stream() - )), + Type::TraitObject(dep_type_trait) => parse_str::( + format!( + "{}.get_{}::<{}>({}.clone())", + DI_CONTAINER_VAR_NAME, + if dep_type.ptr == "SingletonPtr" { + "singleton_with_history" + } else { + "with_history" + }, + dep_type_trait.to_token_stream(), + DEPENDENCY_HISTORY_VAR_NAME + ) + .as_str(), + ) + .ok(), Type::Path(dep_type_path) => { let dep_type_path_str = Self::_path_to_string(&dep_type_path.path); - let get_method_name = if dep_type_path_str.ends_with("Factory") { - "factory" + if dep_type_path_str.ends_with("Factory") { + parse_str( + format!( + "{}.get_factory::<{}>()", + DI_CONTAINER_VAR_NAME, dep_type_path_str + ) + .as_str(), + ) + .ok() } else { - "with_history" - }; - - Some(format!( - "get_{}.{}::<{}>", - DI_CONTAINER_VAR_NAME, get_method_name, dep_type_path_str - )) + parse_str( + format!( + "{}.get_with_history::<{}>({}.clone())", + DI_CONTAINER_VAR_NAME, + dep_type_path_str, + DEPENDENCY_HISTORY_VAR_NAME + ) + .as_str(), + ) + .ok() + } } &_ => None, }) -- cgit v1.2.3-18-g5258