From 080cc42bb1da09059dbc35049a7ded0649961e0c Mon Sep 17 00:00:00 2001 From: HampusM Date: Mon, 29 Aug 2022 20:52:56 +0200 Subject: feat: implement async functionality --- macros/src/injectable_impl.rs | 102 ++++++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 23 deletions(-) (limited to 'macros/src/injectable_impl.rs') diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 990b148..3565ef9 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -6,6 +6,7 @@ use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type}; use crate::dependency::Dependency; use crate::util::item_impl::find_impl_method_by_name_mut; +use crate::util::string::camelcase_to_snakecase; use crate::util::syn_path::syn_path_to_string; const DI_CONTAINER_VAR_NAME: &str = "di_container"; @@ -39,7 +40,8 @@ impl Parse for InjectableImpl impl InjectableImpl { - pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream + pub fn expand(&self, no_doc_hidden: bool, is_async: bool) + -> proc_macro2::TokenStream { let Self { dependencies, @@ -51,8 +53,6 @@ impl InjectableImpl let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); - let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies); - let maybe_doc_hidden = if no_doc_hidden { quote! {} } else { @@ -81,36 +81,78 @@ impl InjectableImpl quote! {} }; - quote! { - #original_impl + let injectable_impl = if is_async { + let async_get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, true); + + quote! { + #maybe_doc_hidden + #[syrette::libs::async_trait::async_trait] + impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type + { + async fn resolve( + #di_container_var: &syrette::async_di_container::AsyncDIContainer, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; + + use syrette::errors::injectable::InjectableError; + + let self_type_name = type_name::<#self_type>(); + + #maybe_prevent_circular_deps + + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#async_get_dep_method_calls),* + ))); + } + } - #maybe_doc_hidden - impl #generics syrette::interfaces::injectable::Injectable for #self_type { - fn resolve( - #di_container_var: &syrette::DIContainer, - mut #dependency_history_var: Vec<&'static str>, - ) -> Result< - syrette::ptr::TransientPtr, - syrette::errors::injectable::InjectableError> + } + } else { + let get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, false); + + quote! { + #maybe_doc_hidden + impl #generics syrette::interfaces::injectable::Injectable for #self_type { - use std::any::type_name; + fn resolve( + #di_container_var: &syrette::DIContainer, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; - use syrette::errors::injectable::InjectableError; + use syrette::errors::injectable::InjectableError; - let self_type_name = type_name::<#self_type>(); + let self_type_name = type_name::<#self_type>(); - #maybe_prevent_circular_deps + #maybe_prevent_circular_deps - return Ok(syrette::ptr::TransientPtr::new(Self::new( - #(#get_dep_method_calls),* - ))); + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#get_dep_method_calls),* + ))); + } } } + }; + + quote! { + #original_impl + + #injectable_impl } } fn create_get_dep_method_calls( dependencies: &[Dependency], + is_async: bool, ) -> Vec { dependencies @@ -146,11 +188,25 @@ impl InjectableImpl .map(|(method_call, dep_type)| { let ptr_name = dep_type.ptr.to_string(); - let to_ptr = - format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase()); + let to_ptr = format_ident!( + "{}", + camelcase_to_snakecase(&ptr_name.replace("Ptr", "")) + ); + + let do_method_call = if is_async { + quote! { #method_call.await } + } else { + quote! { #method_call } + }; + + let resolve_failed_error = if is_async { + quote! { InjectableError::AsyncResolveFailed } + } else { + quote! { InjectableError::ResolveFailed } + }; quote! { - #method_call.map_err(|err| InjectableError::ResolveFailed { + #do_method_call.map_err(|err| #resolve_failed_error { reason: Box::new(err), affected: self_type_name })?.#to_ptr().unwrap() -- cgit v1.2.3-18-g5258