diff options
author | HampusM <hampus@hampusmat.com> | 2022-09-24 13:13:20 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-09-24 13:13:20 +0200 |
commit | 695f90bf900015df1e2728445f833dabced838a9 (patch) | |
tree | c68f2b483e3d20f400d27d4df159b2aec94d072f /macros/src/injectable | |
parent | 3ed020425bfd1fc5fedfa89a7ce20207bedcf5bc (diff) |
refactor: reorganize modules in the macros crate
Diffstat (limited to 'macros/src/injectable')
-rw-r--r-- | macros/src/injectable/dependency.rs | 81 | ||||
-rw-r--r-- | macros/src/injectable/implementation.rs | 261 | ||||
-rw-r--r-- | macros/src/injectable/macro_args.rs | 67 | ||||
-rw-r--r-- | macros/src/injectable/mod.rs | 4 | ||||
-rw-r--r-- | macros/src/injectable/named_attr_input.rs | 21 |
5 files changed, 434 insertions, 0 deletions
diff --git a/macros/src/injectable/dependency.rs b/macros/src/injectable/dependency.rs new file mode 100644 index 0000000..2c5e0fd --- /dev/null +++ b/macros/src/injectable/dependency.rs @@ -0,0 +1,81 @@ +use std::error::Error; + +use proc_macro2::Ident; +use syn::{parse2, FnArg, GenericArgument, LitStr, PathArguments, Type}; + +use crate::injectable::named_attr_input::NamedAttrInput; +use crate::util::syn_path::syn_path_to_string; + +pub struct Dependency +{ + pub interface: Type, + pub ptr: Ident, + pub name: Option<LitStr>, +} + +impl Dependency +{ + pub fn build(new_method_arg: &FnArg) -> Result<Self, Box<dyn Error>> + { + let typed_new_method_arg = match new_method_arg { + FnArg::Typed(typed_arg) => Ok(typed_arg), + FnArg::Receiver(_) => Err("Unexpected self argument in 'new' method"), + }?; + + let ptr_type_path = match typed_new_method_arg.ty.as_ref() { + Type::Path(arg_type_path) => Ok(arg_type_path), + Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { + Type::Path(arg_type_path) => Ok(arg_type_path), + &_ => Err("Unexpected reference to non-path type"), + }, + &_ => Err("Expected a path or a reference type"), + }?; + + let ptr_path_segment = ptr_type_path.path.segments.last().map_or_else( + || Err("Expected pointer type path to have a last segment"), + Ok, + )?; + + let ptr = ptr_path_segment.ident.clone(); + + let ptr_path_generic_args = &match &ptr_path_segment.arguments { + PathArguments::AngleBracketed(generic_args) => Ok(generic_args), + &_ => Err("Expected pointer type to have a generic type argument"), + }? + .args; + + let interface = if let Some(GenericArgument::Type(interface)) = + ptr_path_generic_args.first() + { + Ok(interface.clone()) + } else { + Err("Expected pointer type to have a generic type argument") + }?; + + let arg_attrs = &typed_new_method_arg.attrs; + + let opt_named_attr = arg_attrs.iter().find(|attr| { + attr.path.get_ident().map_or_else( + || false, + |attr_ident| attr_ident.to_string().as_str() == "named", + ) || syn_path_to_string(&attr.path) == "syrette::named" + }); + + let opt_named_attr_tokens = opt_named_attr.map(|attr| &attr.tokens); + + let opt_named_attr_input = + if let Some(named_attr_tokens) = opt_named_attr_tokens { + Some(parse2::<NamedAttrInput>(named_attr_tokens.clone()).map_err( + |err| format!("Invalid input for 'named' attribute. {}", err), + )?) + } else { + None + }; + + Ok(Self { + interface, + ptr, + name: opt_named_attr_input.map(|named_attr_input| named_attr_input.name), + }) + } +} diff --git a/macros/src/injectable/implementation.rs b/macros/src/injectable/implementation.rs new file mode 100644 index 0000000..a84e798 --- /dev/null +++ b/macros/src/injectable/implementation.rs @@ -0,0 +1,261 @@ +use std::error::Error; + +use quote::{format_ident, quote, ToTokens}; +use syn::parse::{Parse, ParseStream}; +use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type}; + +use crate::injectable::dependency::Dependency; +use crate::util::item_impl::find_impl_method_by_name_mut; +use crate::util::string::camelcase_to_snakecase; +use crate::util::syn_path::syn_path_to_string; + +const DI_CONTAINER_VAR_NAME: &str = "di_container"; +const DEPENDENCY_HISTORY_VAR_NAME: &str = "dependency_history"; + +pub struct InjectableImpl +{ + pub dependencies: Vec<Dependency>, + pub self_type: Type, + pub generics: Generics, + pub original_impl: ItemImpl, +} + +impl Parse for InjectableImpl +{ + fn parse(input: ParseStream) -> syn::Result<Self> + { + let mut impl_parsed_input = input.parse::<ItemImpl>()?; + + let dependencies = Self::build_dependencies(&mut impl_parsed_input) + .map_err(|err| input.error(err))?; + + Ok(Self { + dependencies, + self_type: impl_parsed_input.self_ty.as_ref().clone(), + generics: impl_parsed_input.generics.clone(), + original_impl: impl_parsed_input, + }) + } +} + +impl InjectableImpl +{ + pub fn expand(&self, no_doc_hidden: bool, is_async: bool) + -> proc_macro2::TokenStream + { + let Self { + dependencies, + self_type, + generics, + original_impl, + } = self; + + let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); + let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); + + let maybe_doc_hidden = if no_doc_hidden { + quote! {} + } else { + quote! { + #[doc(hidden)] + } + }; + + let maybe_prevent_circular_deps = if cfg!(feature = "prevent-circular") { + quote! { + if #dependency_history_var.contains(&self_type_name) { + #dependency_history_var.push(self_type_name); + + let dependency_trace = + syrette::dependency_trace::create_dependency_trace( + #dependency_history_var.as_slice(), + self_type_name + ); + + return Err(InjectableError::DetectedCircular {dependency_trace }); + } + + #dependency_history_var.push(self_type_name); + } + } else { + quote! {} + }; + + let injectable_impl = if is_async { + let async_get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, true); + + quote! { + #maybe_doc_hidden + #[syrette::libs::async_trait::async_trait] + impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type + { + async fn resolve( + #di_container_var: &std::sync::Arc<syrette::async_di_container::AsyncDIContainer>, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr<Self>, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; + + use syrette::errors::injectable::InjectableError; + + let self_type_name = type_name::<#self_type>(); + + #maybe_prevent_circular_deps + + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#async_get_dep_method_calls),* + ))); + } + } + + } + } else { + let get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, false); + + quote! { + #maybe_doc_hidden + impl #generics syrette::interfaces::injectable::Injectable for #self_type + { + fn resolve( + #di_container_var: &std::rc::Rc<syrette::DIContainer>, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr<Self>, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; + + use syrette::errors::injectable::InjectableError; + + let self_type_name = type_name::<#self_type>(); + + #maybe_prevent_circular_deps + + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#get_dep_method_calls),* + ))); + } + } + } + }; + + quote! { + #original_impl + + #injectable_impl + } + } + + fn create_get_dep_method_calls( + dependencies: &[Dependency], + is_async: bool, + ) -> Vec<proc_macro2::TokenStream> + { + dependencies + .iter() + .filter_map(|dependency| { + let dep_interface_str = match &dependency.interface { + Type::TraitObject(interface_trait) => { + Some(interface_trait.to_token_stream().to_string()) + } + Type::Path(path_interface) => { + Some(syn_path_to_string(&path_interface.path)) + } + &_ => None, + }?; + + let method_call = parse_str::<ExprMethodCall>( + format!( + "{}.get_bound::<{}>({}.clone(), {})", + DI_CONTAINER_VAR_NAME, + dep_interface_str, + DEPENDENCY_HISTORY_VAR_NAME, + dependency.name.as_ref().map_or_else( + || "None".to_string(), + |name| format!("Some(\"{}\")", name.value()) + ) + ) + .as_str(), + ) + .ok()?; + + Some((method_call, dependency)) + }) + .map(|(method_call, dep_type)| { + let ptr_name = dep_type.ptr.to_string(); + + let to_ptr = format_ident!( + "{}", + camelcase_to_snakecase(&ptr_name.replace("Ptr", "")) + ); + + let do_method_call = if is_async { + quote! { #method_call.await } + } else { + quote! { #method_call } + }; + + let resolve_failed_error = if is_async { + quote! { InjectableError::AsyncResolveFailed } + } else { + quote! { InjectableError::ResolveFailed } + }; + + quote! { + #do_method_call.map_err(|err| #resolve_failed_error { + reason: Box::new(err), + affected: self_type_name + })?.#to_ptr().unwrap() + } + }) + .collect() + } + + fn build_dependencies( + item_impl: &mut ItemImpl, + ) -> Result<Vec<Dependency>, Box<dyn Error>> + { + let new_method_impl_item = find_impl_method_by_name_mut(item_impl, "new") + .map_or_else(|| Err("Missing a 'new' method"), Ok)?; + + let new_method_args = &mut new_method_impl_item.sig.inputs; + + let dependencies: Result<Vec<_>, _> = + new_method_args.iter().map(Dependency::build).collect(); + + for arg in new_method_args { + let typed_arg = if let FnArg::Typed(typed_arg) = arg { + typed_arg + } else { + continue; + }; + + let attrs_to_remove: Vec<_> = typed_arg + .attrs + .iter() + .enumerate() + .filter_map(|(index, attr)| { + if syn_path_to_string(&attr.path).as_str() == "syrette::named" { + return Some(index); + } + + if attr.path.get_ident()?.to_string().as_str() == "named" { + return Some(index); + } + + None + }) + .collect(); + + for attr_index in attrs_to_remove { + typed_arg.attrs.remove(attr_index); + } + } + + dependencies + } +} diff --git a/macros/src/injectable/macro_args.rs b/macros/src/injectable/macro_args.rs new file mode 100644 index 0000000..50d4087 --- /dev/null +++ b/macros/src/injectable/macro_args.rs @@ -0,0 +1,67 @@ +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{Token, TypePath}; + +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"]; + +pub struct InjectableMacroArgs +{ + pub interface: Option<TypePath>, + pub flags: Punctuated<MacroFlag, Token![,]>, +} + +impl Parse for InjectableMacroArgs +{ + fn parse(input: ParseStream) -> syn::Result<Self> + { + let interface = input.parse::<TypePath>().ok(); + + if interface.is_some() { + let comma_input_lookahead = input.lookahead1(); + + if !comma_input_lookahead.peek(Token![,]) { + return Ok(Self { + interface, + flags: Punctuated::new(), + }); + } + + input.parse::<Token![,]>()?; + } + + if input.is_empty() { + return Ok(Self { + interface, + flags: Punctuated::new(), + }); + } + + let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + INJECTABLE_MACRO_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::<Vec<_>>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + Ok(Self { interface, flags }) + } +} diff --git a/macros/src/injectable/mod.rs b/macros/src/injectable/mod.rs new file mode 100644 index 0000000..b713aeb --- /dev/null +++ b/macros/src/injectable/mod.rs @@ -0,0 +1,4 @@ +pub mod dependency; +pub mod implementation; +pub mod macro_args; +pub mod named_attr_input; diff --git a/macros/src/injectable/named_attr_input.rs b/macros/src/injectable/named_attr_input.rs new file mode 100644 index 0000000..5f7123c --- /dev/null +++ b/macros/src/injectable/named_attr_input.rs @@ -0,0 +1,21 @@ +use syn::parse::Parse; +use syn::{parenthesized, LitStr}; + +pub struct NamedAttrInput +{ + pub name: LitStr, +} + +impl Parse for NamedAttrInput +{ + fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> + { + let content; + + parenthesized!(content in input); + + Ok(Self { + name: content.parse()?, + }) + } +} |