From 695f90bf900015df1e2728445f833dabced838a9 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 24 Sep 2022 13:13:20 +0200 Subject: refactor: reorganize modules in the macros crate --- macros/src/decl_def_factory_args.rs | 56 ------- macros/src/dependency.rs | 81 --------- macros/src/factory/declare_default_args.rs | 56 +++++++ macros/src/factory/macro_args.rs | 44 +++++ macros/src/factory/mod.rs | 3 + macros/src/factory/type_alias.rs | 35 ++++ macros/src/factory_macro_args.rs | 44 ----- macros/src/factory_type_alias.rs | 35 ---- macros/src/injectable/dependency.rs | 81 +++++++++ macros/src/injectable/implementation.rs | 261 +++++++++++++++++++++++++++++ macros/src/injectable/macro_args.rs | 67 ++++++++ macros/src/injectable/mod.rs | 4 + macros/src/injectable/named_attr_input.rs | 21 +++ macros/src/injectable_impl.rs | 261 ----------------------------- macros/src/injectable_macro_args.rs | 67 -------- macros/src/lib.rs | 24 ++- macros/src/named_attr_input.rs | 21 --- 17 files changed, 582 insertions(+), 579 deletions(-) delete mode 100644 macros/src/decl_def_factory_args.rs delete mode 100644 macros/src/dependency.rs create mode 100644 macros/src/factory/declare_default_args.rs create mode 100644 macros/src/factory/macro_args.rs create mode 100644 macros/src/factory/mod.rs create mode 100644 macros/src/factory/type_alias.rs delete mode 100644 macros/src/factory_macro_args.rs delete mode 100644 macros/src/factory_type_alias.rs create mode 100644 macros/src/injectable/dependency.rs create mode 100644 macros/src/injectable/implementation.rs create mode 100644 macros/src/injectable/macro_args.rs create mode 100644 macros/src/injectable/mod.rs create mode 100644 macros/src/injectable/named_attr_input.rs delete mode 100644 macros/src/injectable_impl.rs delete mode 100644 macros/src/injectable_macro_args.rs delete mode 100644 macros/src/named_attr_input.rs diff --git a/macros/src/decl_def_factory_args.rs b/macros/src/decl_def_factory_args.rs deleted file mode 100644 index 6450583..0000000 --- a/macros/src/decl_def_factory_args.rs +++ /dev/null @@ -1,56 +0,0 @@ -use syn::parse::Parse; -use syn::punctuated::Punctuated; -use syn::{Token, Type}; - -use crate::macro_flag::MacroFlag; -use crate::util::iterator_ext::IteratorExt; - -pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe"]; - -pub struct DeclareDefaultFactoryMacroArgs -{ - pub interface: Type, - pub flags: Punctuated, -} - -impl Parse for DeclareDefaultFactoryMacroArgs -{ - fn parse(input: syn::parse::ParseStream) -> syn::Result - { - let interface = input.parse().unwrap(); - - if !input.peek(Token![,]) { - return Ok(Self { - interface, - flags: Punctuated::new(), - }); - } - - input.parse::().unwrap(); - - let flags = Punctuated::::parse_terminated(input)?; - - for flag in &flags { - let flag_str = flag.flag.to_string(); - - if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { - return Err(input.error(format!( - "Unknown flag '{}'. Expected one of [ {} ]", - flag_str, - FACTORY_MACRO_FLAGS.join(",") - ))); - } - } - - let flag_names = flags - .iter() - .map(|flag| flag.flag.to_string()) - .collect::>(); - - if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { - return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); - } - - Ok(Self { interface, flags }) - } -} diff --git a/macros/src/dependency.rs b/macros/src/dependency.rs deleted file mode 100644 index d20af90..0000000 --- a/macros/src/dependency.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::error::Error; - -use proc_macro2::Ident; -use syn::{parse2, FnArg, GenericArgument, LitStr, PathArguments, Type}; - -use crate::named_attr_input::NamedAttrInput; -use crate::util::syn_path::syn_path_to_string; - -pub struct Dependency -{ - pub interface: Type, - pub ptr: Ident, - pub name: Option, -} - -impl Dependency -{ - pub fn build(new_method_arg: &FnArg) -> Result> - { - let typed_new_method_arg = match new_method_arg { - FnArg::Typed(typed_arg) => Ok(typed_arg), - FnArg::Receiver(_) => Err("Unexpected self argument in 'new' method"), - }?; - - let ptr_type_path = match typed_new_method_arg.ty.as_ref() { - Type::Path(arg_type_path) => Ok(arg_type_path), - Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { - Type::Path(arg_type_path) => Ok(arg_type_path), - &_ => Err("Unexpected reference to non-path type"), - }, - &_ => Err("Expected a path or a reference type"), - }?; - - let ptr_path_segment = ptr_type_path.path.segments.last().map_or_else( - || Err("Expected pointer type path to have a last segment"), - Ok, - )?; - - let ptr = ptr_path_segment.ident.clone(); - - let ptr_path_generic_args = &match &ptr_path_segment.arguments { - PathArguments::AngleBracketed(generic_args) => Ok(generic_args), - &_ => Err("Expected pointer type to have a generic type argument"), - }? - .args; - - let interface = if let Some(GenericArgument::Type(interface)) = - ptr_path_generic_args.first() - { - Ok(interface.clone()) - } else { - Err("Expected pointer type to have a generic type argument") - }?; - - let arg_attrs = &typed_new_method_arg.attrs; - - let opt_named_attr = arg_attrs.iter().find(|attr| { - attr.path.get_ident().map_or_else( - || false, - |attr_ident| attr_ident.to_string().as_str() == "named", - ) || syn_path_to_string(&attr.path) == "syrette::named" - }); - - let opt_named_attr_tokens = opt_named_attr.map(|attr| &attr.tokens); - - let opt_named_attr_input = - if let Some(named_attr_tokens) = opt_named_attr_tokens { - Some(parse2::(named_attr_tokens.clone()).map_err( - |err| format!("Invalid input for 'named' attribute. {}", err), - )?) - } else { - None - }; - - Ok(Self { - interface, - ptr, - name: opt_named_attr_input.map(|named_attr_input| named_attr_input.name), - }) - } -} diff --git a/macros/src/factory/declare_default_args.rs b/macros/src/factory/declare_default_args.rs new file mode 100644 index 0000000..6450583 --- /dev/null +++ b/macros/src/factory/declare_default_args.rs @@ -0,0 +1,56 @@ +use syn::parse::Parse; +use syn::punctuated::Punctuated; +use syn::{Token, Type}; + +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe"]; + +pub struct DeclareDefaultFactoryMacroArgs +{ + pub interface: Type, + pub flags: Punctuated, +} + +impl Parse for DeclareDefaultFactoryMacroArgs +{ + fn parse(input: syn::parse::ParseStream) -> syn::Result + { + let interface = input.parse().unwrap(); + + if !input.peek(Token![,]) { + return Ok(Self { + interface, + flags: Punctuated::new(), + }); + } + + input.parse::().unwrap(); + + let flags = Punctuated::::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + FACTORY_MACRO_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + Ok(Self { interface, flags }) + } +} diff --git a/macros/src/factory/macro_args.rs b/macros/src/factory/macro_args.rs new file mode 100644 index 0000000..dd80c1c --- /dev/null +++ b/macros/src/factory/macro_args.rs @@ -0,0 +1,44 @@ +use syn::parse::Parse; +use syn::punctuated::Punctuated; +use syn::Token; + +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe", "async"]; + +pub struct FactoryMacroArgs +{ + pub flags: Punctuated, +} + +impl Parse for FactoryMacroArgs +{ + fn parse(input: syn::parse::ParseStream) -> syn::Result + { + let flags = Punctuated::::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + FACTORY_MACRO_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + Ok(Self { flags }) + } +} diff --git a/macros/src/factory/mod.rs b/macros/src/factory/mod.rs new file mode 100644 index 0000000..a8947c5 --- /dev/null +++ b/macros/src/factory/mod.rs @@ -0,0 +1,3 @@ +pub mod declare_default_args; +pub mod macro_args; +pub mod type_alias; diff --git a/macros/src/factory/type_alias.rs b/macros/src/factory/type_alias.rs new file mode 100644 index 0000000..64afe57 --- /dev/null +++ b/macros/src/factory/type_alias.rs @@ -0,0 +1,35 @@ +use quote::ToTokens; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{parse, ItemType, Token, Type}; + +use crate::fn_trait::FnTrait; + +pub struct FactoryTypeAlias +{ + pub type_alias: ItemType, + pub factory_interface: FnTrait, + pub arg_types: Punctuated, + pub return_type: Type, +} + +impl Parse for FactoryTypeAlias +{ + fn parse(input: ParseStream) -> syn::Result + { + let type_alias = match input.parse::() { + Ok(type_alias) => Ok(type_alias), + Err(_) => Err(input.error("Expected a type alias")), + }?; + + let aliased_fn_trait = + parse::(type_alias.ty.as_ref().to_token_stream().into())?; + + Ok(Self { + type_alias, + factory_interface: aliased_fn_trait.clone(), + arg_types: aliased_fn_trait.inputs, + return_type: aliased_fn_trait.output, + }) + } +} diff --git a/macros/src/factory_macro_args.rs b/macros/src/factory_macro_args.rs deleted file mode 100644 index dd80c1c..0000000 --- a/macros/src/factory_macro_args.rs +++ /dev/null @@ -1,44 +0,0 @@ -use syn::parse::Parse; -use syn::punctuated::Punctuated; -use syn::Token; - -use crate::macro_flag::MacroFlag; -use crate::util::iterator_ext::IteratorExt; - -pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe", "async"]; - -pub struct FactoryMacroArgs -{ - pub flags: Punctuated, -} - -impl Parse for FactoryMacroArgs -{ - fn parse(input: syn::parse::ParseStream) -> syn::Result - { - let flags = Punctuated::::parse_terminated(input)?; - - for flag in &flags { - let flag_str = flag.flag.to_string(); - - if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { - return Err(input.error(format!( - "Unknown flag '{}'. Expected one of [ {} ]", - flag_str, - FACTORY_MACRO_FLAGS.join(",") - ))); - } - } - - let flag_names = flags - .iter() - .map(|flag| flag.flag.to_string()) - .collect::>(); - - if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { - return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); - } - - Ok(Self { flags }) - } -} diff --git a/macros/src/factory_type_alias.rs b/macros/src/factory_type_alias.rs deleted file mode 100644 index 64afe57..0000000 --- a/macros/src/factory_type_alias.rs +++ /dev/null @@ -1,35 +0,0 @@ -use quote::ToTokens; -use syn::parse::{Parse, ParseStream}; -use syn::punctuated::Punctuated; -use syn::{parse, ItemType, Token, Type}; - -use crate::fn_trait::FnTrait; - -pub struct FactoryTypeAlias -{ - pub type_alias: ItemType, - pub factory_interface: FnTrait, - pub arg_types: Punctuated, - pub return_type: Type, -} - -impl Parse for FactoryTypeAlias -{ - fn parse(input: ParseStream) -> syn::Result - { - let type_alias = match input.parse::() { - Ok(type_alias) => Ok(type_alias), - Err(_) => Err(input.error("Expected a type alias")), - }?; - - let aliased_fn_trait = - parse::(type_alias.ty.as_ref().to_token_stream().into())?; - - Ok(Self { - type_alias, - factory_interface: aliased_fn_trait.clone(), - arg_types: aliased_fn_trait.inputs, - return_type: aliased_fn_trait.output, - }) - } -} diff --git a/macros/src/injectable/dependency.rs b/macros/src/injectable/dependency.rs new file mode 100644 index 0000000..2c5e0fd --- /dev/null +++ b/macros/src/injectable/dependency.rs @@ -0,0 +1,81 @@ +use std::error::Error; + +use proc_macro2::Ident; +use syn::{parse2, FnArg, GenericArgument, LitStr, PathArguments, Type}; + +use crate::injectable::named_attr_input::NamedAttrInput; +use crate::util::syn_path::syn_path_to_string; + +pub struct Dependency +{ + pub interface: Type, + pub ptr: Ident, + pub name: Option, +} + +impl Dependency +{ + pub fn build(new_method_arg: &FnArg) -> Result> + { + let typed_new_method_arg = match new_method_arg { + FnArg::Typed(typed_arg) => Ok(typed_arg), + FnArg::Receiver(_) => Err("Unexpected self argument in 'new' method"), + }?; + + let ptr_type_path = match typed_new_method_arg.ty.as_ref() { + Type::Path(arg_type_path) => Ok(arg_type_path), + Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { + Type::Path(arg_type_path) => Ok(arg_type_path), + &_ => Err("Unexpected reference to non-path type"), + }, + &_ => Err("Expected a path or a reference type"), + }?; + + let ptr_path_segment = ptr_type_path.path.segments.last().map_or_else( + || Err("Expected pointer type path to have a last segment"), + Ok, + )?; + + let ptr = ptr_path_segment.ident.clone(); + + let ptr_path_generic_args = &match &ptr_path_segment.arguments { + PathArguments::AngleBracketed(generic_args) => Ok(generic_args), + &_ => Err("Expected pointer type to have a generic type argument"), + }? + .args; + + let interface = if let Some(GenericArgument::Type(interface)) = + ptr_path_generic_args.first() + { + Ok(interface.clone()) + } else { + Err("Expected pointer type to have a generic type argument") + }?; + + let arg_attrs = &typed_new_method_arg.attrs; + + let opt_named_attr = arg_attrs.iter().find(|attr| { + attr.path.get_ident().map_or_else( + || false, + |attr_ident| attr_ident.to_string().as_str() == "named", + ) || syn_path_to_string(&attr.path) == "syrette::named" + }); + + let opt_named_attr_tokens = opt_named_attr.map(|attr| &attr.tokens); + + let opt_named_attr_input = + if let Some(named_attr_tokens) = opt_named_attr_tokens { + Some(parse2::(named_attr_tokens.clone()).map_err( + |err| format!("Invalid input for 'named' attribute. {}", err), + )?) + } else { + None + }; + + Ok(Self { + interface, + ptr, + name: opt_named_attr_input.map(|named_attr_input| named_attr_input.name), + }) + } +} diff --git a/macros/src/injectable/implementation.rs b/macros/src/injectable/implementation.rs new file mode 100644 index 0000000..a84e798 --- /dev/null +++ b/macros/src/injectable/implementation.rs @@ -0,0 +1,261 @@ +use std::error::Error; + +use quote::{format_ident, quote, ToTokens}; +use syn::parse::{Parse, ParseStream}; +use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type}; + +use crate::injectable::dependency::Dependency; +use crate::util::item_impl::find_impl_method_by_name_mut; +use crate::util::string::camelcase_to_snakecase; +use crate::util::syn_path::syn_path_to_string; + +const DI_CONTAINER_VAR_NAME: &str = "di_container"; +const DEPENDENCY_HISTORY_VAR_NAME: &str = "dependency_history"; + +pub struct InjectableImpl +{ + pub dependencies: Vec, + pub self_type: Type, + pub generics: Generics, + pub original_impl: ItemImpl, +} + +impl Parse for InjectableImpl +{ + fn parse(input: ParseStream) -> syn::Result + { + let mut impl_parsed_input = input.parse::()?; + + let dependencies = Self::build_dependencies(&mut impl_parsed_input) + .map_err(|err| input.error(err))?; + + Ok(Self { + dependencies, + self_type: impl_parsed_input.self_ty.as_ref().clone(), + generics: impl_parsed_input.generics.clone(), + original_impl: impl_parsed_input, + }) + } +} + +impl InjectableImpl +{ + pub fn expand(&self, no_doc_hidden: bool, is_async: bool) + -> proc_macro2::TokenStream + { + let Self { + dependencies, + self_type, + generics, + original_impl, + } = self; + + let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); + let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); + + let maybe_doc_hidden = if no_doc_hidden { + quote! {} + } else { + quote! { + #[doc(hidden)] + } + }; + + let maybe_prevent_circular_deps = if cfg!(feature = "prevent-circular") { + quote! { + if #dependency_history_var.contains(&self_type_name) { + #dependency_history_var.push(self_type_name); + + let dependency_trace = + syrette::dependency_trace::create_dependency_trace( + #dependency_history_var.as_slice(), + self_type_name + ); + + return Err(InjectableError::DetectedCircular {dependency_trace }); + } + + #dependency_history_var.push(self_type_name); + } + } else { + quote! {} + }; + + let injectable_impl = if is_async { + let async_get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, true); + + quote! { + #maybe_doc_hidden + #[syrette::libs::async_trait::async_trait] + impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type + { + async fn resolve( + #di_container_var: &std::sync::Arc, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; + + use syrette::errors::injectable::InjectableError; + + let self_type_name = type_name::<#self_type>(); + + #maybe_prevent_circular_deps + + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#async_get_dep_method_calls),* + ))); + } + } + + } + } else { + let get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, false); + + quote! { + #maybe_doc_hidden + impl #generics syrette::interfaces::injectable::Injectable for #self_type + { + fn resolve( + #di_container_var: &std::rc::Rc, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; + + use syrette::errors::injectable::InjectableError; + + let self_type_name = type_name::<#self_type>(); + + #maybe_prevent_circular_deps + + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#get_dep_method_calls),* + ))); + } + } + } + }; + + quote! { + #original_impl + + #injectable_impl + } + } + + fn create_get_dep_method_calls( + dependencies: &[Dependency], + is_async: bool, + ) -> Vec + { + dependencies + .iter() + .filter_map(|dependency| { + let dep_interface_str = match &dependency.interface { + Type::TraitObject(interface_trait) => { + Some(interface_trait.to_token_stream().to_string()) + } + Type::Path(path_interface) => { + Some(syn_path_to_string(&path_interface.path)) + } + &_ => None, + }?; + + let method_call = parse_str::( + format!( + "{}.get_bound::<{}>({}.clone(), {})", + DI_CONTAINER_VAR_NAME, + dep_interface_str, + DEPENDENCY_HISTORY_VAR_NAME, + dependency.name.as_ref().map_or_else( + || "None".to_string(), + |name| format!("Some(\"{}\")", name.value()) + ) + ) + .as_str(), + ) + .ok()?; + + Some((method_call, dependency)) + }) + .map(|(method_call, dep_type)| { + let ptr_name = dep_type.ptr.to_string(); + + let to_ptr = format_ident!( + "{}", + camelcase_to_snakecase(&ptr_name.replace("Ptr", "")) + ); + + let do_method_call = if is_async { + quote! { #method_call.await } + } else { + quote! { #method_call } + }; + + let resolve_failed_error = if is_async { + quote! { InjectableError::AsyncResolveFailed } + } else { + quote! { InjectableError::ResolveFailed } + }; + + quote! { + #do_method_call.map_err(|err| #resolve_failed_error { + reason: Box::new(err), + affected: self_type_name + })?.#to_ptr().unwrap() + } + }) + .collect() + } + + fn build_dependencies( + item_impl: &mut ItemImpl, + ) -> Result, Box> + { + let new_method_impl_item = find_impl_method_by_name_mut(item_impl, "new") + .map_or_else(|| Err("Missing a 'new' method"), Ok)?; + + let new_method_args = &mut new_method_impl_item.sig.inputs; + + let dependencies: Result, _> = + new_method_args.iter().map(Dependency::build).collect(); + + for arg in new_method_args { + let typed_arg = if let FnArg::Typed(typed_arg) = arg { + typed_arg + } else { + continue; + }; + + let attrs_to_remove: Vec<_> = typed_arg + .attrs + .iter() + .enumerate() + .filter_map(|(index, attr)| { + if syn_path_to_string(&attr.path).as_str() == "syrette::named" { + return Some(index); + } + + if attr.path.get_ident()?.to_string().as_str() == "named" { + return Some(index); + } + + None + }) + .collect(); + + for attr_index in attrs_to_remove { + typed_arg.attrs.remove(attr_index); + } + } + + dependencies + } +} diff --git a/macros/src/injectable/macro_args.rs b/macros/src/injectable/macro_args.rs new file mode 100644 index 0000000..50d4087 --- /dev/null +++ b/macros/src/injectable/macro_args.rs @@ -0,0 +1,67 @@ +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{Token, TypePath}; + +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"]; + +pub struct InjectableMacroArgs +{ + pub interface: Option, + pub flags: Punctuated, +} + +impl Parse for InjectableMacroArgs +{ + fn parse(input: ParseStream) -> syn::Result + { + let interface = input.parse::().ok(); + + if interface.is_some() { + let comma_input_lookahead = input.lookahead1(); + + if !comma_input_lookahead.peek(Token![,]) { + return Ok(Self { + interface, + flags: Punctuated::new(), + }); + } + + input.parse::()?; + } + + if input.is_empty() { + return Ok(Self { + interface, + flags: Punctuated::new(), + }); + } + + let flags = Punctuated::::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + INJECTABLE_MACRO_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + Ok(Self { interface, flags }) + } +} diff --git a/macros/src/injectable/mod.rs b/macros/src/injectable/mod.rs new file mode 100644 index 0000000..b713aeb --- /dev/null +++ b/macros/src/injectable/mod.rs @@ -0,0 +1,4 @@ +pub mod dependency; +pub mod implementation; +pub mod macro_args; +pub mod named_attr_input; diff --git a/macros/src/injectable/named_attr_input.rs b/macros/src/injectable/named_attr_input.rs new file mode 100644 index 0000000..5f7123c --- /dev/null +++ b/macros/src/injectable/named_attr_input.rs @@ -0,0 +1,21 @@ +use syn::parse::Parse; +use syn::{parenthesized, LitStr}; + +pub struct NamedAttrInput +{ + pub name: LitStr, +} + +impl Parse for NamedAttrInput +{ + fn parse(input: syn::parse::ParseStream) -> syn::Result + { + let content; + + parenthesized!(content in input); + + Ok(Self { + name: content.parse()?, + }) + } +} diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs deleted file mode 100644 index bf5c96c..0000000 --- a/macros/src/injectable_impl.rs +++ /dev/null @@ -1,261 +0,0 @@ -use std::error::Error; - -use quote::{format_ident, quote, ToTokens}; -use syn::parse::{Parse, ParseStream}; -use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type}; - -use crate::dependency::Dependency; -use crate::util::item_impl::find_impl_method_by_name_mut; -use crate::util::string::camelcase_to_snakecase; -use crate::util::syn_path::syn_path_to_string; - -const DI_CONTAINER_VAR_NAME: &str = "di_container"; -const DEPENDENCY_HISTORY_VAR_NAME: &str = "dependency_history"; - -pub struct InjectableImpl -{ - pub dependencies: Vec, - pub self_type: Type, - pub generics: Generics, - pub original_impl: ItemImpl, -} - -impl Parse for InjectableImpl -{ - fn parse(input: ParseStream) -> syn::Result - { - let mut impl_parsed_input = input.parse::()?; - - let dependencies = Self::build_dependencies(&mut impl_parsed_input) - .map_err(|err| input.error(err))?; - - Ok(Self { - dependencies, - self_type: impl_parsed_input.self_ty.as_ref().clone(), - generics: impl_parsed_input.generics.clone(), - original_impl: impl_parsed_input, - }) - } -} - -impl InjectableImpl -{ - pub fn expand(&self, no_doc_hidden: bool, is_async: bool) - -> proc_macro2::TokenStream - { - let Self { - dependencies, - self_type, - generics, - original_impl, - } = self; - - let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); - let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); - - let maybe_doc_hidden = if no_doc_hidden { - quote! {} - } else { - quote! { - #[doc(hidden)] - } - }; - - let maybe_prevent_circular_deps = if cfg!(feature = "prevent-circular") { - quote! { - if #dependency_history_var.contains(&self_type_name) { - #dependency_history_var.push(self_type_name); - - let dependency_trace = - syrette::dependency_trace::create_dependency_trace( - #dependency_history_var.as_slice(), - self_type_name - ); - - return Err(InjectableError::DetectedCircular {dependency_trace }); - } - - #dependency_history_var.push(self_type_name); - } - } else { - quote! {} - }; - - let injectable_impl = if is_async { - let async_get_dep_method_calls = - Self::create_get_dep_method_calls(dependencies, true); - - quote! { - #maybe_doc_hidden - #[syrette::libs::async_trait::async_trait] - impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type - { - async fn resolve( - #di_container_var: &std::sync::Arc, - mut #dependency_history_var: Vec<&'static str>, - ) -> Result< - syrette::ptr::TransientPtr, - syrette::errors::injectable::InjectableError> - { - use std::any::type_name; - - use syrette::errors::injectable::InjectableError; - - let self_type_name = type_name::<#self_type>(); - - #maybe_prevent_circular_deps - - return Ok(syrette::ptr::TransientPtr::new(Self::new( - #(#async_get_dep_method_calls),* - ))); - } - } - - } - } else { - let get_dep_method_calls = - Self::create_get_dep_method_calls(dependencies, false); - - quote! { - #maybe_doc_hidden - impl #generics syrette::interfaces::injectable::Injectable for #self_type - { - fn resolve( - #di_container_var: &std::rc::Rc, - mut #dependency_history_var: Vec<&'static str>, - ) -> Result< - syrette::ptr::TransientPtr, - syrette::errors::injectable::InjectableError> - { - use std::any::type_name; - - use syrette::errors::injectable::InjectableError; - - let self_type_name = type_name::<#self_type>(); - - #maybe_prevent_circular_deps - - return Ok(syrette::ptr::TransientPtr::new(Self::new( - #(#get_dep_method_calls),* - ))); - } - } - } - }; - - quote! { - #original_impl - - #injectable_impl - } - } - - fn create_get_dep_method_calls( - dependencies: &[Dependency], - is_async: bool, - ) -> Vec - { - dependencies - .iter() - .filter_map(|dependency| { - let dep_interface_str = match &dependency.interface { - Type::TraitObject(interface_trait) => { - Some(interface_trait.to_token_stream().to_string()) - } - Type::Path(path_interface) => { - Some(syn_path_to_string(&path_interface.path)) - } - &_ => None, - }?; - - let method_call = parse_str::( - format!( - "{}.get_bound::<{}>({}.clone(), {})", - DI_CONTAINER_VAR_NAME, - dep_interface_str, - DEPENDENCY_HISTORY_VAR_NAME, - dependency.name.as_ref().map_or_else( - || "None".to_string(), - |name| format!("Some(\"{}\")", name.value()) - ) - ) - .as_str(), - ) - .ok()?; - - Some((method_call, dependency)) - }) - .map(|(method_call, dep_type)| { - let ptr_name = dep_type.ptr.to_string(); - - let to_ptr = format_ident!( - "{}", - camelcase_to_snakecase(&ptr_name.replace("Ptr", "")) - ); - - let do_method_call = if is_async { - quote! { #method_call.await } - } else { - quote! { #method_call } - }; - - let resolve_failed_error = if is_async { - quote! { InjectableError::AsyncResolveFailed } - } else { - quote! { InjectableError::ResolveFailed } - }; - - quote! { - #do_method_call.map_err(|err| #resolve_failed_error { - reason: Box::new(err), - affected: self_type_name - })?.#to_ptr().unwrap() - } - }) - .collect() - } - - fn build_dependencies( - item_impl: &mut ItemImpl, - ) -> Result, Box> - { - let new_method_impl_item = find_impl_method_by_name_mut(item_impl, "new") - .map_or_else(|| Err("Missing a 'new' method"), Ok)?; - - let new_method_args = &mut new_method_impl_item.sig.inputs; - - let dependencies: Result, _> = - new_method_args.iter().map(Dependency::build).collect(); - - for arg in new_method_args { - let typed_arg = if let FnArg::Typed(typed_arg) = arg { - typed_arg - } else { - continue; - }; - - let attrs_to_remove: Vec<_> = typed_arg - .attrs - .iter() - .enumerate() - .filter_map(|(index, attr)| { - if syn_path_to_string(&attr.path).as_str() == "syrette::named" { - return Some(index); - } - - if attr.path.get_ident()?.to_string().as_str() == "named" { - return Some(index); - } - - None - }) - .collect(); - - for attr_index in attrs_to_remove { - typed_arg.attrs.remove(attr_index); - } - } - - dependencies - } -} diff --git a/macros/src/injectable_macro_args.rs b/macros/src/injectable_macro_args.rs deleted file mode 100644 index 50d4087..0000000 --- a/macros/src/injectable_macro_args.rs +++ /dev/null @@ -1,67 +0,0 @@ -use syn::parse::{Parse, ParseStream}; -use syn::punctuated::Punctuated; -use syn::{Token, TypePath}; - -use crate::macro_flag::MacroFlag; -use crate::util::iterator_ext::IteratorExt; - -pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"]; - -pub struct InjectableMacroArgs -{ - pub interface: Option, - pub flags: Punctuated, -} - -impl Parse for InjectableMacroArgs -{ - fn parse(input: ParseStream) -> syn::Result - { - let interface = input.parse::().ok(); - - if interface.is_some() { - let comma_input_lookahead = input.lookahead1(); - - if !comma_input_lookahead.peek(Token![,]) { - return Ok(Self { - interface, - flags: Punctuated::new(), - }); - } - - input.parse::()?; - } - - if input.is_empty() { - return Ok(Self { - interface, - flags: Punctuated::new(), - }); - } - - let flags = Punctuated::::parse_terminated(input)?; - - for flag in &flags { - let flag_str = flag.flag.to_string(); - - if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { - return Err(input.error(format!( - "Unknown flag '{}'. Expected one of [ {} ]", - flag_str, - INJECTABLE_MACRO_FLAGS.join(",") - ))); - } - } - - let flag_names = flags - .iter() - .map(|flag| flag.flag.to_string()) - .collect::>(); - - if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { - return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); - } - - Ok(Self { interface, flags }) - } -} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index b0ccc86..390d239 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -8,23 +8,18 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse, parse_macro_input, parse_str}; -mod decl_def_factory_args; mod declare_interface_args; -mod dependency; -mod factory_macro_args; -mod factory_type_alias; +mod factory; mod fn_trait; -mod injectable_impl; -mod injectable_macro_args; +mod injectable; mod libs; mod macro_flag; -mod named_attr_input; mod util; -use declare_interface_args::DeclareInterfaceArgs; -use injectable_impl::InjectableImpl; -use injectable_macro_args::InjectableMacroArgs; -use libs::intertrait_macros::gen_caster::generate_caster; +use crate::declare_interface_args::DeclareInterfaceArgs; +use crate::injectable::implementation::InjectableImpl; +use crate::injectable::macro_args::InjectableMacroArgs; +use crate::libs::intertrait_macros::gen_caster::generate_caster; /// Makes a struct injectable. Thereby usable with [`DIContainer`]. /// @@ -195,7 +190,8 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke use quote::ToTokens; use syn::Type; - use crate::factory_macro_args::FactoryMacroArgs; + use crate::factory::macro_args::FactoryMacroArgs; + use crate::factory::type_alias::FactoryTypeAlias; let FactoryMacroArgs { flags } = parse(args_stream).unwrap(); @@ -213,7 +209,7 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke is_threadsafe = true; } - let factory_type_alias::FactoryTypeAlias { + let FactoryTypeAlias { mut type_alias, mut factory_interface, arg_types: _, @@ -327,7 +323,7 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke #[cfg(feature = "factory")] pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream { - use crate::decl_def_factory_args::DeclareDefaultFactoryMacroArgs; + use crate::factory::declare_default_args::DeclareDefaultFactoryMacroArgs; let DeclareDefaultFactoryMacroArgs { interface, flags } = parse(args_stream).unwrap(); diff --git a/macros/src/named_attr_input.rs b/macros/src/named_attr_input.rs deleted file mode 100644 index 5f7123c..0000000 --- a/macros/src/named_attr_input.rs +++ /dev/null @@ -1,21 +0,0 @@ -use syn::parse::Parse; -use syn::{parenthesized, LitStr}; - -pub struct NamedAttrInput -{ - pub name: LitStr, -} - -impl Parse for NamedAttrInput -{ - fn parse(input: syn::parse::ParseStream) -> syn::Result - { - let content; - - parenthesized!(content in input); - - Ok(Self { - name: content.parse()?, - }) - } -} -- cgit v1.2.3-18-g5258