diff options
Diffstat (limited to 'macros/src')
| -rw-r--r-- | macros/src/declare_interface_args.rs | 43 | ||||
| -rw-r--r-- | macros/src/factory_macro_args.rs | 44 | ||||
| -rw-r--r-- | macros/src/injectable_impl.rs | 102 | ||||
| -rw-r--r-- | macros/src/injectable_macro_args.rs | 55 | ||||
| -rw-r--r-- | macros/src/lib.rs | 108 | ||||
| -rw-r--r-- | macros/src/libs/intertrait_macros/gen_caster.rs | 26 | ||||
| -rw-r--r-- | macros/src/macro_flag.rs | 27 | ||||
| -rw-r--r-- | macros/src/util/mod.rs | 1 | ||||
| -rw-r--r-- | macros/src/util/string.rs | 12 | 
9 files changed, 328 insertions, 90 deletions
diff --git a/macros/src/declare_interface_args.rs b/macros/src/declare_interface_args.rs index b54f458..bd2f24e 100644 --- a/macros/src/declare_interface_args.rs +++ b/macros/src/declare_interface_args.rs @@ -1,10 +1,17 @@  use syn::parse::{Parse, ParseStream, Result}; +use syn::punctuated::Punctuated;  use syn::{Path, Token, Type}; +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const DECLARE_INTERFACE_FLAGS: &[&str] = &["async"]; +  pub struct DeclareInterfaceArgs  {      pub implementation: Type,      pub interface: Path, +    pub flags: Punctuated<MacroFlag, Token![,]>,  }  impl Parse for DeclareInterfaceArgs @@ -15,9 +22,43 @@ impl Parse for DeclareInterfaceArgs          input.parse::<Token![->]>()?; +        let interface: Path = input.parse()?; + +        let flags = if input.peek(Token![,]) { +            input.parse::<Token![,]>()?; + +            let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?; + +            for flag in &flags { +                let flag_str = flag.flag.to_string(); + +                if !DECLARE_INTERFACE_FLAGS.contains(&flag_str.as_str()) { +                    return Err(input.error(format!( +                        "Unknown flag '{}'. Expected one of [ {} ]", +                        flag_str, +                        DECLARE_INTERFACE_FLAGS.join(",") +                    ))); +                } +            } + +            let flag_names = flags +                .iter() +                .map(|flag| flag.flag.to_string()) +                .collect::<Vec<_>>(); + +            if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { +                return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); +            } + +            flags +        } else { +            Punctuated::new() +        }; +          Ok(Self {              implementation, -            interface: input.parse()?, +            interface, +            flags,          })      }  } diff --git a/macros/src/factory_macro_args.rs b/macros/src/factory_macro_args.rs new file mode 100644 index 0000000..57517d6 --- /dev/null +++ b/macros/src/factory_macro_args.rs @@ -0,0 +1,44 @@ +use syn::parse::Parse; +use syn::punctuated::Punctuated; +use syn::Token; + +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const FACTORY_MACRO_FLAGS: &[&str] = &["async"]; + +pub struct FactoryMacroArgs +{ +    pub flags: Punctuated<MacroFlag, Token![,]>, +} + +impl Parse for FactoryMacroArgs +{ +    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> +    { +        let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?; + +        for flag in &flags { +            let flag_str = flag.flag.to_string(); + +            if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { +                return Err(input.error(format!( +                    "Unknown flag '{}'. Expected one of [ {} ]", +                    flag_str, +                    FACTORY_MACRO_FLAGS.join(",") +                ))); +            } +        } + +        let flag_names = flags +            .iter() +            .map(|flag| flag.flag.to_string()) +            .collect::<Vec<_>>(); + +        if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { +            return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); +        } + +        Ok(Self { flags }) +    } +} diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 990b148..3565ef9 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -6,6 +6,7 @@ use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type};  use crate::dependency::Dependency;  use crate::util::item_impl::find_impl_method_by_name_mut; +use crate::util::string::camelcase_to_snakecase;  use crate::util::syn_path::syn_path_to_string;  const DI_CONTAINER_VAR_NAME: &str = "di_container"; @@ -39,7 +40,8 @@ impl Parse for InjectableImpl  impl InjectableImpl  { -    pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream +    pub fn expand(&self, no_doc_hidden: bool, is_async: bool) +        -> proc_macro2::TokenStream      {          let Self {              dependencies, @@ -51,8 +53,6 @@ impl InjectableImpl          let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME);          let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); -        let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies); -          let maybe_doc_hidden = if no_doc_hidden {              quote! {}          } else { @@ -81,36 +81,78 @@ impl InjectableImpl              quote! {}          }; -        quote! { -            #original_impl +        let injectable_impl = if is_async { +            let async_get_dep_method_calls = +                Self::create_get_dep_method_calls(dependencies, true); + +            quote! { +                #maybe_doc_hidden +                #[syrette::libs::async_trait::async_trait] +                impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type +                { +                    async fn resolve( +                        #di_container_var: &syrette::async_di_container::AsyncDIContainer, +                        mut #dependency_history_var: Vec<&'static str>, +                    ) -> Result< +                        syrette::ptr::TransientPtr<Self>, +                        syrette::errors::injectable::InjectableError> +                    { +                        use std::any::type_name; + +                        use syrette::errors::injectable::InjectableError; + +                        let self_type_name = type_name::<#self_type>(); + +                        #maybe_prevent_circular_deps + +                        return Ok(syrette::ptr::TransientPtr::new(Self::new( +                            #(#async_get_dep_method_calls),* +                        ))); +                    } +                } -            #maybe_doc_hidden -            impl #generics syrette::interfaces::injectable::Injectable for #self_type { -                fn resolve( -                    #di_container_var: &syrette::DIContainer, -                    mut #dependency_history_var: Vec<&'static str>, -                ) -> Result< -                    syrette::ptr::TransientPtr<Self>, -                    syrette::errors::injectable::InjectableError> +            } +        } else { +            let get_dep_method_calls = +                Self::create_get_dep_method_calls(dependencies, false); + +            quote! { +                #maybe_doc_hidden +                impl #generics syrette::interfaces::injectable::Injectable for #self_type                  { -                    use std::any::type_name; +                    fn resolve( +                        #di_container_var: &syrette::DIContainer, +                        mut #dependency_history_var: Vec<&'static str>, +                    ) -> Result< +                        syrette::ptr::TransientPtr<Self>, +                        syrette::errors::injectable::InjectableError> +                    { +                        use std::any::type_name; -                    use syrette::errors::injectable::InjectableError; +                        use syrette::errors::injectable::InjectableError; -                    let self_type_name = type_name::<#self_type>(); +                        let self_type_name = type_name::<#self_type>(); -                    #maybe_prevent_circular_deps +                        #maybe_prevent_circular_deps -                    return Ok(syrette::ptr::TransientPtr::new(Self::new( -                        #(#get_dep_method_calls),* -                    ))); +                        return Ok(syrette::ptr::TransientPtr::new(Self::new( +                            #(#get_dep_method_calls),* +                        ))); +                    }                  }              } +        }; + +        quote! { +            #original_impl + +            #injectable_impl          }      }      fn create_get_dep_method_calls(          dependencies: &[Dependency], +        is_async: bool,      ) -> Vec<proc_macro2::TokenStream>      {          dependencies @@ -146,11 +188,25 @@ impl InjectableImpl              .map(|(method_call, dep_type)| {                  let ptr_name = dep_type.ptr.to_string(); -                let to_ptr = -                    format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase()); +                let to_ptr = format_ident!( +                    "{}", +                    camelcase_to_snakecase(&ptr_name.replace("Ptr", "")) +                ); + +                let do_method_call = if is_async { +                    quote! { #method_call.await } +                } else { +                    quote! { #method_call } +                }; + +                let resolve_failed_error = if is_async { +                    quote! { InjectableError::AsyncResolveFailed } +                } else { +                    quote! { InjectableError::ResolveFailed } +                };                  quote! { -                    #method_call.map_err(|err| InjectableError::ResolveFailed { +                    #do_method_call.map_err(|err| #resolve_failed_error {                          reason: Box::new(err),                          affected: self_type_name                      })?.#to_ptr().unwrap() diff --git a/macros/src/injectable_macro_args.rs b/macros/src/injectable_macro_args.rs index 43f8e11..6cc1d7e 100644 --- a/macros/src/injectable_macro_args.rs +++ b/macros/src/injectable_macro_args.rs @@ -1,49 +1,16 @@  use syn::parse::{Parse, ParseStream};  use syn::punctuated::Punctuated; -use syn::{braced, Ident, LitBool, Token, TypePath}; +use syn::{braced, Token, TypePath}; +use crate::macro_flag::MacroFlag;  use crate::util::iterator_ext::IteratorExt; -pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden"]; - -pub struct InjectableMacroFlag -{ -    pub flag: Ident, -    pub is_on: LitBool, -} - -impl Parse for InjectableMacroFlag -{ -    fn parse(input: ParseStream) -> syn::Result<Self> -    { -        let input_forked = input.fork(); - -        let flag: Ident = input_forked.parse()?; - -        let flag_str = flag.to_string(); - -        if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { -            return Err(input.error(format!( -                "Unknown flag '{}'. Expected one of [ {} ]", -                flag_str, -                INJECTABLE_MACRO_FLAGS.join(",") -            ))); -        } - -        input.parse::<Ident>()?; - -        input.parse::<Token![=]>()?; - -        let is_on: LitBool = input.parse()?; - -        Ok(Self { flag, is_on }) -    } -} +pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"];  pub struct InjectableMacroArgs  {      pub interface: Option<TypePath>, -    pub flags: Punctuated<InjectableMacroFlag, Token![,]>, +    pub flags: Punctuated<MacroFlag, Token![,]>,  }  impl Parse for InjectableMacroArgs @@ -76,7 +43,19 @@ impl Parse for InjectableMacroArgs          braced!(braced_content in input); -        let flags = braced_content.parse_terminated(InjectableMacroFlag::parse)?; +        let flags = braced_content.parse_terminated(MacroFlag::parse)?; + +        for flag in &flags { +            let flag_str = flag.flag.to_string(); + +            if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { +                return Err(input.error(format!( +                    "Unknown flag '{}'. Expected one of [ {} ]", +                    flag_str, +                    INJECTABLE_MACRO_FLAGS.join(",") +                ))); +            } +        }          let flag_names = flags              .iter() diff --git a/macros/src/lib.rs b/macros/src/lib.rs index eb3a2be..40fbb53 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -2,7 +2,7 @@  #![deny(clippy::pedantic)]  #![deny(missing_docs)] -//! Macros for the [Syrette](https://crates.io/crates/syrette) crate. +//! Macros for the [Sy&rette](https://crates.io/crates/syrette) crate.  use proc_macro::TokenStream;  use quote::quote; @@ -10,10 +10,12 @@ use syn::{parse, parse_macro_input};  mod declare_interface_args;  mod dependency; +mod factory_macro_args;  mod factory_type_alias;  mod injectable_impl;  mod injectable_macro_args;  mod libs; +mod macro_flag;  mod named_attr_input;  mod util; @@ -31,6 +33,7 @@ use libs::intertrait_macros::gen_caster::generate_caster;  /// # Flags  /// - `no_doc_hidden` - Don't hide the impl of the [`Injectable`] trait from  ///   documentation. +/// - `async` - Mark as async.  ///  /// # Panics  /// If the attributed item is not a impl. @@ -107,21 +110,31 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt  {      let InjectableMacroArgs { interface, flags } = parse_macro_input!(args_stream); -    let mut flags_iter = flags.iter(); - -    let no_doc_hidden = flags_iter +    let no_doc_hidden = flags +        .iter()          .find(|flag| flag.flag.to_string().as_str() == "no_doc_hidden")          .map_or(false, |flag| flag.is_on.value); +    let is_async = flags +        .iter() +        .find(|flag| flag.flag.to_string().as_str() == "async") +        .map_or(false, |flag| flag.is_on.value); +      let injectable_impl: InjectableImpl = parse(impl_stream).unwrap(); -    let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden); +    let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden, is_async);      let maybe_decl_interface = if interface.is_some() {          let self_type = &injectable_impl.self_type; -        quote! { -            syrette::declare_interface!(#self_type -> #interface); +        if is_async { +            quote! { +                syrette::declare_interface!(#self_type -> #interface, async = true); +            } +        } else { +            quote! { +                syrette::declare_interface!(#self_type -> #interface); +            }          }      } else {          quote! {} @@ -139,6 +152,12 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt  ///  /// *This macro is only available if Syrette is built with the "factory" feature.*  /// +/// # Arguments +/// * (Zero or more) Flags. Like `a = true, b = false` +/// +/// # Flags +/// - `async` - Mark as async. +///  /// # Panics  /// If the attributed item is not a type alias.  /// @@ -166,8 +185,17 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt  /// ```  #[proc_macro_attribute]  #[cfg(feature = "factory")] -pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream +pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> TokenStream  { +    use crate::factory_macro_args::FactoryMacroArgs; + +    let FactoryMacroArgs { flags } = parse(args_stream).unwrap(); + +    let is_async = flags +        .iter() +        .find(|flag| flag.flag.to_string().as_str() == "async") +        .map_or(false, |flag| flag.is_on.value); +      let factory_type_alias::FactoryTypeAlias {          type_alias,          factory_interface, @@ -175,22 +203,46 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream          return_type,      } = parse(type_alias_stream).unwrap(); +    let decl_interfaces = if is_async { +        quote! { +            syrette::declare_interface!( +                syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< +                    #arg_types, +                    #return_type +                > -> #factory_interface, +                async = true +            ); + +            syrette::declare_interface!( +                syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< +                    #arg_types, +                    #return_type +                > -> syrette::interfaces::any_factory::AnyThreadsafeFactory, +                async = true +            ) +        } +    } else { +        quote! { +            syrette::declare_interface!( +                syrette::castable_factory::blocking::CastableFactory< +                    #arg_types, +                    #return_type +                > -> #factory_interface +            ); + +            syrette::declare_interface!( +                syrette::castable_factory::blocking::CastableFactory< +                    #arg_types, +                    #return_type +                > -> syrette::interfaces::any_factory::AnyFactory +            ); +        } +    }; +      quote! {          #type_alias -        syrette::declare_interface!( -            syrette::castable_factory::CastableFactory< -                #arg_types, -                #return_type -            > -> #factory_interface -        ); - -        syrette::declare_interface!( -            syrette::castable_factory::CastableFactory< -                #arg_types, -                #return_type -            > -> syrette::interfaces::any_factory::AnyFactory -        ); +        #decl_interfaces      }      .into()  } @@ -199,6 +251,10 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream  ///  /// # Arguments  /// {Implementation} -> {Interface} +/// * (Zero or more) Flags. Like `a = true, b = false` +/// +/// # Flags +/// - `async` - Mark as async.  ///  /// # Examples  /// ``` @@ -218,9 +274,17 @@ pub fn declare_interface(input: TokenStream) -> TokenStream      let DeclareInterfaceArgs {          implementation,          interface, +        flags,      } = parse_macro_input!(input); -    generate_caster(&implementation, &interface).into() +    let opt_async_flag = flags +        .iter() +        .find(|flag| flag.flag.to_string().as_str() == "async"); + +    let is_async = +        opt_async_flag.map_or_else(|| false, |async_flag| async_flag.is_on.value); + +    generate_caster(&implementation, &interface, is_async).into()  }  /// Declares the name of a dependency. diff --git a/macros/src/libs/intertrait_macros/gen_caster.rs b/macros/src/libs/intertrait_macros/gen_caster.rs index 9bac09e..df743e2 100644 --- a/macros/src/libs/intertrait_macros/gen_caster.rs +++ b/macros/src/libs/intertrait_macros/gen_caster.rs @@ -22,15 +22,29 @@ const CASTER_FN_NAME_PREFIX: &[u8] = b"__";  const FN_BUF_LEN: usize = CASTER_FN_NAME_PREFIX.len() + Simple::LENGTH; -pub fn generate_caster(ty: &impl ToTokens, dst_trait: &impl ToTokens) -> TokenStream +pub fn generate_caster( +    ty: &impl ToTokens, +    dst_trait: &impl ToTokens, +    sync: bool, +) -> TokenStream  {      let fn_ident = create_caster_fn_ident(); -    let new_caster = quote! { -        syrette::libs::intertrait::Caster::<dyn #dst_trait>::new( -            |from| from.downcast::<#ty>().unwrap(), -            |from| from.downcast::<#ty>().unwrap(), -        ) +    let new_caster = if sync { +        quote! { +            syrette::libs::intertrait::Caster::<dyn #dst_trait>::new_sync( +                |from| from.downcast::<#ty>().unwrap(), +                |from| from.downcast::<#ty>().unwrap(), +                |from| from.downcast::<#ty>().unwrap() +            ) +        } +    } else { +        quote! { +            syrette::libs::intertrait::Caster::<dyn #dst_trait>::new( +                |from| from.downcast::<#ty>().unwrap(), +                |from| from.downcast::<#ty>().unwrap(), +            ) +        }      };      quote! { diff --git a/macros/src/macro_flag.rs b/macros/src/macro_flag.rs new file mode 100644 index 0000000..257a059 --- /dev/null +++ b/macros/src/macro_flag.rs @@ -0,0 +1,27 @@ +use syn::parse::{Parse, ParseStream}; +use syn::{Ident, LitBool, Token}; + +#[derive(Debug)] +pub struct MacroFlag +{ +    pub flag: Ident, +    pub is_on: LitBool, +} + +impl Parse for MacroFlag +{ +    fn parse(input: ParseStream) -> syn::Result<Self> +    { +        let input_forked = input.fork(); + +        let flag: Ident = input_forked.parse()?; + +        input.parse::<Ident>()?; + +        input.parse::<Token![=]>()?; + +        let is_on: LitBool = input.parse()?; + +        Ok(Self { flag, is_on }) +    } +} diff --git a/macros/src/util/mod.rs b/macros/src/util/mod.rs index 4f2a594..0705853 100644 --- a/macros/src/util/mod.rs +++ b/macros/src/util/mod.rs @@ -1,3 +1,4 @@  pub mod item_impl;  pub mod iterator_ext; +pub mod string;  pub mod syn_path; diff --git a/macros/src/util/string.rs b/macros/src/util/string.rs new file mode 100644 index 0000000..90cccee --- /dev/null +++ b/macros/src/util/string.rs @@ -0,0 +1,12 @@ +use once_cell::sync::Lazy; +use regex::Regex; + +static CAMELCASE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"([a-z])([A-Z])").unwrap()); + +pub fn camelcase_to_snakecase(camelcased: &str) -> String +{ +    CAMELCASE_RE +        .replace(camelcased, "${1}_$2") +        .to_string() +        .to_lowercase() +}  | 
