From 080cc42bb1da09059dbc35049a7ded0649961e0c Mon Sep 17 00:00:00 2001 From: HampusM Date: Mon, 29 Aug 2022 20:52:56 +0200 Subject: feat: implement async functionality --- macros/Cargo.toml | 2 + macros/src/declare_interface_args.rs | 43 +++++++++- macros/src/factory_macro_args.rs | 44 ++++++++++ macros/src/injectable_impl.rs | 102 +++++++++++++++++----- macros/src/injectable_macro_args.rs | 55 ++++-------- macros/src/lib.rs | 108 +++++++++++++++++++----- macros/src/libs/intertrait_macros/gen_caster.rs | 26 ++++-- macros/src/macro_flag.rs | 27 ++++++ macros/src/util/mod.rs | 1 + macros/src/util/string.rs | 12 +++ 10 files changed, 330 insertions(+), 90 deletions(-) create mode 100644 macros/src/factory_macro_args.rs create mode 100644 macros/src/macro_flag.rs create mode 100644 macros/src/util/string.rs (limited to 'macros') diff --git a/macros/Cargo.toml b/macros/Cargo.toml index a929b08..28cb4c0 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -23,6 +23,8 @@ syn = { version = "1.0.96", features = ["full"] } quote = "1.0.18" proc-macro2 = "1.0.40" uuid = { version = "0.8", features = ["v4"] } +regex = "1.6.0" +once_cell = "1.13.1" [dev_dependencies] syrette = { version = "0.3.0", path = "..", features = ["factory"] } diff --git a/macros/src/declare_interface_args.rs b/macros/src/declare_interface_args.rs index b54f458..bd2f24e 100644 --- a/macros/src/declare_interface_args.rs +++ b/macros/src/declare_interface_args.rs @@ -1,10 +1,17 @@ use syn::parse::{Parse, ParseStream, Result}; +use syn::punctuated::Punctuated; use syn::{Path, Token, Type}; +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const DECLARE_INTERFACE_FLAGS: &[&str] = &["async"]; + pub struct DeclareInterfaceArgs { pub implementation: Type, pub interface: Path, + pub flags: Punctuated, } impl Parse for DeclareInterfaceArgs @@ -15,9 +22,43 @@ impl Parse for DeclareInterfaceArgs input.parse::]>()?; + let interface: Path = input.parse()?; + + let flags = if input.peek(Token![,]) { + input.parse::()?; + + let flags = Punctuated::::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !DECLARE_INTERFACE_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + DECLARE_INTERFACE_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + flags + } else { + Punctuated::new() + }; + Ok(Self { implementation, - interface: input.parse()?, + interface, + flags, }) } } diff --git a/macros/src/factory_macro_args.rs b/macros/src/factory_macro_args.rs new file mode 100644 index 0000000..57517d6 --- /dev/null +++ b/macros/src/factory_macro_args.rs @@ -0,0 +1,44 @@ +use syn::parse::Parse; +use syn::punctuated::Punctuated; +use syn::Token; + +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const FACTORY_MACRO_FLAGS: &[&str] = &["async"]; + +pub struct FactoryMacroArgs +{ + pub flags: Punctuated, +} + +impl Parse for FactoryMacroArgs +{ + fn parse(input: syn::parse::ParseStream) -> syn::Result + { + let flags = Punctuated::::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + FACTORY_MACRO_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + Ok(Self { flags }) + } +} diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 990b148..3565ef9 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -6,6 +6,7 @@ use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type}; use crate::dependency::Dependency; use crate::util::item_impl::find_impl_method_by_name_mut; +use crate::util::string::camelcase_to_snakecase; use crate::util::syn_path::syn_path_to_string; const DI_CONTAINER_VAR_NAME: &str = "di_container"; @@ -39,7 +40,8 @@ impl Parse for InjectableImpl impl InjectableImpl { - pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream + pub fn expand(&self, no_doc_hidden: bool, is_async: bool) + -> proc_macro2::TokenStream { let Self { dependencies, @@ -51,8 +53,6 @@ impl InjectableImpl let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); - let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies); - let maybe_doc_hidden = if no_doc_hidden { quote! {} } else { @@ -81,36 +81,78 @@ impl InjectableImpl quote! {} }; - quote! { - #original_impl + let injectable_impl = if is_async { + let async_get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, true); + + quote! { + #maybe_doc_hidden + #[syrette::libs::async_trait::async_trait] + impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type + { + async fn resolve( + #di_container_var: &syrette::async_di_container::AsyncDIContainer, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; + + use syrette::errors::injectable::InjectableError; + + let self_type_name = type_name::<#self_type>(); + + #maybe_prevent_circular_deps + + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#async_get_dep_method_calls),* + ))); + } + } - #maybe_doc_hidden - impl #generics syrette::interfaces::injectable::Injectable for #self_type { - fn resolve( - #di_container_var: &syrette::DIContainer, - mut #dependency_history_var: Vec<&'static str>, - ) -> Result< - syrette::ptr::TransientPtr, - syrette::errors::injectable::InjectableError> + } + } else { + let get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, false); + + quote! { + #maybe_doc_hidden + impl #generics syrette::interfaces::injectable::Injectable for #self_type { - use std::any::type_name; + fn resolve( + #di_container_var: &syrette::DIContainer, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; - use syrette::errors::injectable::InjectableError; + use syrette::errors::injectable::InjectableError; - let self_type_name = type_name::<#self_type>(); + let self_type_name = type_name::<#self_type>(); - #maybe_prevent_circular_deps + #maybe_prevent_circular_deps - return Ok(syrette::ptr::TransientPtr::new(Self::new( - #(#get_dep_method_calls),* - ))); + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#get_dep_method_calls),* + ))); + } } } + }; + + quote! { + #original_impl + + #injectable_impl } } fn create_get_dep_method_calls( dependencies: &[Dependency], + is_async: bool, ) -> Vec { dependencies @@ -146,11 +188,25 @@ impl InjectableImpl .map(|(method_call, dep_type)| { let ptr_name = dep_type.ptr.to_string(); - let to_ptr = - format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase()); + let to_ptr = format_ident!( + "{}", + camelcase_to_snakecase(&ptr_name.replace("Ptr", "")) + ); + + let do_method_call = if is_async { + quote! { #method_call.await } + } else { + quote! { #method_call } + }; + + let resolve_failed_error = if is_async { + quote! { InjectableError::AsyncResolveFailed } + } else { + quote! { InjectableError::ResolveFailed } + }; quote! { - #method_call.map_err(|err| InjectableError::ResolveFailed { + #do_method_call.map_err(|err| #resolve_failed_error { reason: Box::new(err), affected: self_type_name })?.#to_ptr().unwrap() diff --git a/macros/src/injectable_macro_args.rs b/macros/src/injectable_macro_args.rs index 43f8e11..6cc1d7e 100644 --- a/macros/src/injectable_macro_args.rs +++ b/macros/src/injectable_macro_args.rs @@ -1,49 +1,16 @@ use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::{braced, Ident, LitBool, Token, TypePath}; +use syn::{braced, Token, TypePath}; +use crate::macro_flag::MacroFlag; use crate::util::iterator_ext::IteratorExt; -pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden"]; - -pub struct InjectableMacroFlag -{ - pub flag: Ident, - pub is_on: LitBool, -} - -impl Parse for InjectableMacroFlag -{ - fn parse(input: ParseStream) -> syn::Result - { - let input_forked = input.fork(); - - let flag: Ident = input_forked.parse()?; - - let flag_str = flag.to_string(); - - if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { - return Err(input.error(format!( - "Unknown flag '{}'. Expected one of [ {} ]", - flag_str, - INJECTABLE_MACRO_FLAGS.join(",") - ))); - } - - input.parse::()?; - - input.parse::()?; - - let is_on: LitBool = input.parse()?; - - Ok(Self { flag, is_on }) - } -} +pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"]; pub struct InjectableMacroArgs { pub interface: Option, - pub flags: Punctuated, + pub flags: Punctuated, } impl Parse for InjectableMacroArgs @@ -76,7 +43,19 @@ impl Parse for InjectableMacroArgs braced!(braced_content in input); - let flags = braced_content.parse_terminated(InjectableMacroFlag::parse)?; + let flags = braced_content.parse_terminated(MacroFlag::parse)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + INJECTABLE_MACRO_FLAGS.join(",") + ))); + } + } let flag_names = flags .iter() diff --git a/macros/src/lib.rs b/macros/src/lib.rs index eb3a2be..40fbb53 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -2,7 +2,7 @@ #![deny(clippy::pedantic)] #![deny(missing_docs)] -//! Macros for the [Syrette](https://crates.io/crates/syrette) crate. +//! Macros for the [Sy&rette](https://crates.io/crates/syrette) crate. use proc_macro::TokenStream; use quote::quote; @@ -10,10 +10,12 @@ use syn::{parse, parse_macro_input}; mod declare_interface_args; mod dependency; +mod factory_macro_args; mod factory_type_alias; mod injectable_impl; mod injectable_macro_args; mod libs; +mod macro_flag; mod named_attr_input; mod util; @@ -31,6 +33,7 @@ use libs::intertrait_macros::gen_caster::generate_caster; /// # Flags /// - `no_doc_hidden` - Don't hide the impl of the [`Injectable`] trait from /// documentation. +/// - `async` - Mark as async. /// /// # Panics /// If the attributed item is not a impl. @@ -107,21 +110,31 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt { let InjectableMacroArgs { interface, flags } = parse_macro_input!(args_stream); - let mut flags_iter = flags.iter(); - - let no_doc_hidden = flags_iter + let no_doc_hidden = flags + .iter() .find(|flag| flag.flag.to_string().as_str() == "no_doc_hidden") .map_or(false, |flag| flag.is_on.value); + let is_async = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async") + .map_or(false, |flag| flag.is_on.value); + let injectable_impl: InjectableImpl = parse(impl_stream).unwrap(); - let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden); + let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden, is_async); let maybe_decl_interface = if interface.is_some() { let self_type = &injectable_impl.self_type; - quote! { - syrette::declare_interface!(#self_type -> #interface); + if is_async { + quote! { + syrette::declare_interface!(#self_type -> #interface, async = true); + } + } else { + quote! { + syrette::declare_interface!(#self_type -> #interface); + } } } else { quote! {} @@ -139,6 +152,12 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt /// /// *This macro is only available if Syrette is built with the "factory" feature.* /// +/// # Arguments +/// * (Zero or more) Flags. Like `a = true, b = false` +/// +/// # Flags +/// - `async` - Mark as async. +/// /// # Panics /// If the attributed item is not a type alias. /// @@ -166,8 +185,17 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt /// ``` #[proc_macro_attribute] #[cfg(feature = "factory")] -pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream +pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> TokenStream { + use crate::factory_macro_args::FactoryMacroArgs; + + let FactoryMacroArgs { flags } = parse(args_stream).unwrap(); + + let is_async = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async") + .map_or(false, |flag| flag.is_on.value); + let factory_type_alias::FactoryTypeAlias { type_alias, factory_interface, @@ -175,22 +203,46 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream return_type, } = parse(type_alias_stream).unwrap(); + let decl_interfaces = if is_async { + quote! { + syrette::declare_interface!( + syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< + #arg_types, + #return_type + > -> #factory_interface, + async = true + ); + + syrette::declare_interface!( + syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< + #arg_types, + #return_type + > -> syrette::interfaces::any_factory::AnyThreadsafeFactory, + async = true + ) + } + } else { + quote! { + syrette::declare_interface!( + syrette::castable_factory::blocking::CastableFactory< + #arg_types, + #return_type + > -> #factory_interface + ); + + syrette::declare_interface!( + syrette::castable_factory::blocking::CastableFactory< + #arg_types, + #return_type + > -> syrette::interfaces::any_factory::AnyFactory + ); + } + }; + quote! { #type_alias - syrette::declare_interface!( - syrette::castable_factory::CastableFactory< - #arg_types, - #return_type - > -> #factory_interface - ); - - syrette::declare_interface!( - syrette::castable_factory::CastableFactory< - #arg_types, - #return_type - > -> syrette::interfaces::any_factory::AnyFactory - ); + #decl_interfaces } .into() } @@ -199,6 +251,10 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream /// /// # Arguments /// {Implementation} -> {Interface} +/// * (Zero or more) Flags. Like `a = true, b = false` +/// +/// # Flags +/// - `async` - Mark as async. /// /// # Examples /// ``` @@ -218,9 +274,17 @@ pub fn declare_interface(input: TokenStream) -> TokenStream let DeclareInterfaceArgs { implementation, interface, + flags, } = parse_macro_input!(input); - generate_caster(&implementation, &interface).into() + let opt_async_flag = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async"); + + let is_async = + opt_async_flag.map_or_else(|| false, |async_flag| async_flag.is_on.value); + + generate_caster(&implementation, &interface, is_async).into() } /// Declares the name of a dependency. diff --git a/macros/src/libs/intertrait_macros/gen_caster.rs b/macros/src/libs/intertrait_macros/gen_caster.rs index 9bac09e..df743e2 100644 --- a/macros/src/libs/intertrait_macros/gen_caster.rs +++ b/macros/src/libs/intertrait_macros/gen_caster.rs @@ -22,15 +22,29 @@ const CASTER_FN_NAME_PREFIX: &[u8] = b"__"; const FN_BUF_LEN: usize = CASTER_FN_NAME_PREFIX.len() + Simple::LENGTH; -pub fn generate_caster(ty: &impl ToTokens, dst_trait: &impl ToTokens) -> TokenStream +pub fn generate_caster( + ty: &impl ToTokens, + dst_trait: &impl ToTokens, + sync: bool, +) -> TokenStream { let fn_ident = create_caster_fn_ident(); - let new_caster = quote! { - syrette::libs::intertrait::Caster::::new( - |from| from.downcast::<#ty>().unwrap(), - |from| from.downcast::<#ty>().unwrap(), - ) + let new_caster = if sync { + quote! { + syrette::libs::intertrait::Caster::::new_sync( + |from| from.downcast::<#ty>().unwrap(), + |from| from.downcast::<#ty>().unwrap(), + |from| from.downcast::<#ty>().unwrap() + ) + } + } else { + quote! { + syrette::libs::intertrait::Caster::::new( + |from| from.downcast::<#ty>().unwrap(), + |from| from.downcast::<#ty>().unwrap(), + ) + } }; quote! { diff --git a/macros/src/macro_flag.rs b/macros/src/macro_flag.rs new file mode 100644 index 0000000..257a059 --- /dev/null +++ b/macros/src/macro_flag.rs @@ -0,0 +1,27 @@ +use syn::parse::{Parse, ParseStream}; +use syn::{Ident, LitBool, Token}; + +#[derive(Debug)] +pub struct MacroFlag +{ + pub flag: Ident, + pub is_on: LitBool, +} + +impl Parse for MacroFlag +{ + fn parse(input: ParseStream) -> syn::Result + { + let input_forked = input.fork(); + + let flag: Ident = input_forked.parse()?; + + input.parse::()?; + + input.parse::()?; + + let is_on: LitBool = input.parse()?; + + Ok(Self { flag, is_on }) + } +} diff --git a/macros/src/util/mod.rs b/macros/src/util/mod.rs index 4f2a594..0705853 100644 --- a/macros/src/util/mod.rs +++ b/macros/src/util/mod.rs @@ -1,3 +1,4 @@ pub mod item_impl; pub mod iterator_ext; +pub mod string; pub mod syn_path; diff --git a/macros/src/util/string.rs b/macros/src/util/string.rs new file mode 100644 index 0000000..90cccee --- /dev/null +++ b/macros/src/util/string.rs @@ -0,0 +1,12 @@ +use once_cell::sync::Lazy; +use regex::Regex; + +static CAMELCASE_RE: Lazy = Lazy::new(|| Regex::new(r"([a-z])([A-Z])").unwrap()); + +pub fn camelcase_to_snakecase(camelcased: &str) -> String +{ + CAMELCASE_RE + .replace(camelcased, "${1}_$2") + .to_string() + .to_lowercase() +} -- cgit v1.2.3-18-g5258