diff options
author | HampusM <hampus@hampusmat.com> | 2022-08-29 20:52:56 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-08-29 21:01:32 +0200 |
commit | 080cc42bb1da09059dbc35049a7ded0649961e0c (patch) | |
tree | 307ee564124373616022c1ba2b4d5af80845cd92 | |
parent | 6e31d8f9e46fece348f329763b39b9c6f2741c07 (diff) |
feat: implement async functionality
41 files changed, 2132 insertions, 181 deletions
@@ -15,6 +15,7 @@ all-features = true default = ["prevent-circular"] factory = ["syrette_macros/factory"] prevent-circular = ["syrette_macros/prevent-circular"] +async = ["async-trait"] [[example]] name = "factory" @@ -24,6 +25,10 @@ required-features = ["factory"] name = "with-3rd-party" required-features = ["factory"] +[[example]] +name = "async" +required-features = ["async"] + [dependencies] syrette_macros = { path = "./macros", version = "0.3.0" } linkme = "0.3.0" @@ -33,11 +38,13 @@ thiserror = "1.0.32" strum = "0.24.1" strum_macros = "0.24.3" paste = "1.0.8" +async-trait = { version = "0.1.57", optional = true } [dev_dependencies] mockall = "0.11.1" anyhow = "1.0.62" third-party-lib = { path = "./examples/with-3rd-party/third-party-lib" } +tokio = { version = "1.20.1", features = ["full"] } [workspace] members = [ diff --git a/examples/async/animals/cat.rs b/examples/async/animals/cat.rs new file mode 100644 index 0000000..b1e6f27 --- /dev/null +++ b/examples/async/animals/cat.rs @@ -0,0 +1,22 @@ +use syrette::injectable; + +use crate::interfaces::cat::ICat; + +pub struct Cat {} + +#[injectable(ICat, { async = true })] +impl Cat +{ + pub fn new() -> Self + { + Self {} + } +} + +impl ICat for Cat +{ + fn meow(&self) + { + println!("Meow!"); + } +} diff --git a/examples/async/animals/dog.rs b/examples/async/animals/dog.rs new file mode 100644 index 0000000..d1b33f9 --- /dev/null +++ b/examples/async/animals/dog.rs @@ -0,0 +1,22 @@ +use syrette::injectable; + +use crate::interfaces::dog::IDog; + +pub struct Dog {} + +#[injectable(IDog, { async = true })] +impl Dog +{ + pub fn new() -> Self + { + Self {} + } +} + +impl IDog for Dog +{ + fn woof(&self) + { + println!("Woof!"); + } +} diff --git a/examples/async/animals/human.rs b/examples/async/animals/human.rs new file mode 100644 index 0000000..140f27c --- /dev/null +++ b/examples/async/animals/human.rs @@ -0,0 +1,36 @@ +use syrette::injectable; +use syrette::ptr::{ThreadsafeSingletonPtr, TransientPtr}; + +use crate::interfaces::cat::ICat; +use crate::interfaces::dog::IDog; +use crate::interfaces::human::IHuman; + +pub struct Human +{ + dog: ThreadsafeSingletonPtr<dyn IDog>, + cat: TransientPtr<dyn ICat>, +} + +#[injectable(IHuman, { async = true })] +impl Human +{ + pub fn new(dog: ThreadsafeSingletonPtr<dyn IDog>, cat: TransientPtr<dyn ICat>) + -> Self + { + Self { dog, cat } + } +} + +impl IHuman for Human +{ + fn make_pets_make_sounds(&self) + { + println!("Hi doggy!"); + + self.dog.woof(); + + println!("Hi kitty!"); + + self.cat.meow(); + } +} diff --git a/examples/async/animals/mod.rs b/examples/async/animals/mod.rs new file mode 100644 index 0000000..5444978 --- /dev/null +++ b/examples/async/animals/mod.rs @@ -0,0 +1,3 @@ +pub mod cat; +pub mod dog; +pub mod human; diff --git a/examples/async/bootstrap.rs b/examples/async/bootstrap.rs new file mode 100644 index 0000000..b640712 --- /dev/null +++ b/examples/async/bootstrap.rs @@ -0,0 +1,28 @@ +use anyhow::Result; +use syrette::async_di_container::AsyncDIContainer; + +// Concrete implementations +use crate::animals::cat::Cat; +use crate::animals::dog::Dog; +use crate::animals::human::Human; +// +// Interfaces +use crate::interfaces::cat::ICat; +use crate::interfaces::dog::IDog; +use crate::interfaces::human::IHuman; + +pub async fn bootstrap() -> Result<AsyncDIContainer> +{ + let mut di_container = AsyncDIContainer::new(); + + di_container + .bind::<dyn IDog>() + .to::<Dog>()? + .in_singleton_scope() + .await?; + + di_container.bind::<dyn ICat>().to::<Cat>()?; + di_container.bind::<dyn IHuman>().to::<Human>()?; + + Ok(di_container) +} diff --git a/examples/async/interfaces/cat.rs b/examples/async/interfaces/cat.rs new file mode 100644 index 0000000..478f7e0 --- /dev/null +++ b/examples/async/interfaces/cat.rs @@ -0,0 +1,4 @@ +pub trait ICat: Send + Sync +{ + fn meow(&self); +} diff --git a/examples/async/interfaces/dog.rs b/examples/async/interfaces/dog.rs new file mode 100644 index 0000000..a6ed111 --- /dev/null +++ b/examples/async/interfaces/dog.rs @@ -0,0 +1,4 @@ +pub trait IDog: Send + Sync +{ + fn woof(&self); +} diff --git a/examples/async/interfaces/human.rs b/examples/async/interfaces/human.rs new file mode 100644 index 0000000..18f9d63 --- /dev/null +++ b/examples/async/interfaces/human.rs @@ -0,0 +1,4 @@ +pub trait IHuman: Send + Sync +{ + fn make_pets_make_sounds(&self); +} diff --git a/examples/async/interfaces/mod.rs b/examples/async/interfaces/mod.rs new file mode 100644 index 0000000..5444978 --- /dev/null +++ b/examples/async/interfaces/mod.rs @@ -0,0 +1,3 @@ +pub mod cat; +pub mod dog; +pub mod human; diff --git a/examples/async/main.rs b/examples/async/main.rs new file mode 100644 index 0000000..f72ff39 --- /dev/null +++ b/examples/async/main.rs @@ -0,0 +1,52 @@ +#![deny(clippy::all)] +#![deny(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + +use std::sync::Arc; + +use anyhow::Result; +use tokio::spawn; +use tokio::sync::Mutex; + +mod animals; +mod bootstrap; +mod interfaces; + +use bootstrap::bootstrap; +use interfaces::dog::IDog; +use interfaces::human::IHuman; + +#[tokio::main] +async fn main() -> Result<()> +{ + println!("Hello, world!"); + + let di_container = Arc::new(Mutex::new(bootstrap().await?)); + + { + let dog = di_container + .lock() + .await + .get::<dyn IDog>() + .await? + .threadsafe_singleton()?; + + dog.woof(); + } + + spawn(async move { + let human = di_container + .lock() + .await + .get::<dyn IHuman>() + .await? + .transient()?; + + human.make_pets_make_sounds(); + + Ok::<_, anyhow::Error>(()) + }) + .await??; + + Ok(()) +} diff --git a/macros/Cargo.toml b/macros/Cargo.toml index a929b08..28cb4c0 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -23,6 +23,8 @@ syn = { version = "1.0.96", features = ["full"] } quote = "1.0.18" proc-macro2 = "1.0.40" uuid = { version = "0.8", features = ["v4"] } +regex = "1.6.0" +once_cell = "1.13.1" [dev_dependencies] syrette = { version = "0.3.0", path = "..", features = ["factory"] } diff --git a/macros/src/declare_interface_args.rs b/macros/src/declare_interface_args.rs index b54f458..bd2f24e 100644 --- a/macros/src/declare_interface_args.rs +++ b/macros/src/declare_interface_args.rs @@ -1,10 +1,17 @@ use syn::parse::{Parse, ParseStream, Result}; +use syn::punctuated::Punctuated; use syn::{Path, Token, Type}; +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const DECLARE_INTERFACE_FLAGS: &[&str] = &["async"]; + pub struct DeclareInterfaceArgs { pub implementation: Type, pub interface: Path, + pub flags: Punctuated<MacroFlag, Token![,]>, } impl Parse for DeclareInterfaceArgs @@ -15,9 +22,43 @@ impl Parse for DeclareInterfaceArgs input.parse::<Token![->]>()?; + let interface: Path = input.parse()?; + + let flags = if input.peek(Token![,]) { + input.parse::<Token![,]>()?; + + let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !DECLARE_INTERFACE_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + DECLARE_INTERFACE_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::<Vec<_>>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + flags + } else { + Punctuated::new() + }; + Ok(Self { implementation, - interface: input.parse()?, + interface, + flags, }) } } diff --git a/macros/src/factory_macro_args.rs b/macros/src/factory_macro_args.rs new file mode 100644 index 0000000..57517d6 --- /dev/null +++ b/macros/src/factory_macro_args.rs @@ -0,0 +1,44 @@ +use syn::parse::Parse; +use syn::punctuated::Punctuated; +use syn::Token; + +use crate::macro_flag::MacroFlag; +use crate::util::iterator_ext::IteratorExt; + +pub const FACTORY_MACRO_FLAGS: &[&str] = &["async"]; + +pub struct FactoryMacroArgs +{ + pub flags: Punctuated<MacroFlag, Token![,]>, +} + +impl Parse for FactoryMacroArgs +{ + fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> + { + let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + FACTORY_MACRO_FLAGS.join(",") + ))); + } + } + + let flag_names = flags + .iter() + .map(|flag| flag.flag.to_string()) + .collect::<Vec<_>>(); + + if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() { + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name))); + } + + Ok(Self { flags }) + } +} diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 990b148..3565ef9 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -6,6 +6,7 @@ use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type}; use crate::dependency::Dependency; use crate::util::item_impl::find_impl_method_by_name_mut; +use crate::util::string::camelcase_to_snakecase; use crate::util::syn_path::syn_path_to_string; const DI_CONTAINER_VAR_NAME: &str = "di_container"; @@ -39,7 +40,8 @@ impl Parse for InjectableImpl impl InjectableImpl { - pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream + pub fn expand(&self, no_doc_hidden: bool, is_async: bool) + -> proc_macro2::TokenStream { let Self { dependencies, @@ -51,8 +53,6 @@ impl InjectableImpl let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); - let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies); - let maybe_doc_hidden = if no_doc_hidden { quote! {} } else { @@ -81,36 +81,78 @@ impl InjectableImpl quote! {} }; - quote! { - #original_impl + let injectable_impl = if is_async { + let async_get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, true); + + quote! { + #maybe_doc_hidden + #[syrette::libs::async_trait::async_trait] + impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type + { + async fn resolve( + #di_container_var: &syrette::async_di_container::AsyncDIContainer, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr<Self>, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; + + use syrette::errors::injectable::InjectableError; + + let self_type_name = type_name::<#self_type>(); + + #maybe_prevent_circular_deps + + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#async_get_dep_method_calls),* + ))); + } + } - #maybe_doc_hidden - impl #generics syrette::interfaces::injectable::Injectable for #self_type { - fn resolve( - #di_container_var: &syrette::DIContainer, - mut #dependency_history_var: Vec<&'static str>, - ) -> Result< - syrette::ptr::TransientPtr<Self>, - syrette::errors::injectable::InjectableError> + } + } else { + let get_dep_method_calls = + Self::create_get_dep_method_calls(dependencies, false); + + quote! { + #maybe_doc_hidden + impl #generics syrette::interfaces::injectable::Injectable for #self_type { - use std::any::type_name; + fn resolve( + #di_container_var: &syrette::DIContainer, + mut #dependency_history_var: Vec<&'static str>, + ) -> Result< + syrette::ptr::TransientPtr<Self>, + syrette::errors::injectable::InjectableError> + { + use std::any::type_name; - use syrette::errors::injectable::InjectableError; + use syrette::errors::injectable::InjectableError; - let self_type_name = type_name::<#self_type>(); + let self_type_name = type_name::<#self_type>(); - #maybe_prevent_circular_deps + #maybe_prevent_circular_deps - return Ok(syrette::ptr::TransientPtr::new(Self::new( - #(#get_dep_method_calls),* - ))); + return Ok(syrette::ptr::TransientPtr::new(Self::new( + #(#get_dep_method_calls),* + ))); + } } } + }; + + quote! { + #original_impl + + #injectable_impl } } fn create_get_dep_method_calls( dependencies: &[Dependency], + is_async: bool, ) -> Vec<proc_macro2::TokenStream> { dependencies @@ -146,11 +188,25 @@ impl InjectableImpl .map(|(method_call, dep_type)| { let ptr_name = dep_type.ptr.to_string(); - let to_ptr = - format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase()); + let to_ptr = format_ident!( + "{}", + camelcase_to_snakecase(&ptr_name.replace("Ptr", "")) + ); + + let do_method_call = if is_async { + quote! { #method_call.await } + } else { + quote! { #method_call } + }; + + let resolve_failed_error = if is_async { + quote! { InjectableError::AsyncResolveFailed } + } else { + quote! { InjectableError::ResolveFailed } + }; quote! { - #method_call.map_err(|err| InjectableError::ResolveFailed { + #do_method_call.map_err(|err| #resolve_failed_error { reason: Box::new(err), affected: self_type_name })?.#to_ptr().unwrap() diff --git a/macros/src/injectable_macro_args.rs b/macros/src/injectable_macro_args.rs index 43f8e11..6cc1d7e 100644 --- a/macros/src/injectable_macro_args.rs +++ b/macros/src/injectable_macro_args.rs @@ -1,49 +1,16 @@ use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::{braced, Ident, LitBool, Token, TypePath}; +use syn::{braced, Token, TypePath}; +use crate::macro_flag::MacroFlag; use crate::util::iterator_ext::IteratorExt; -pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden"]; - -pub struct InjectableMacroFlag -{ - pub flag: Ident, - pub is_on: LitBool, -} - -impl Parse for InjectableMacroFlag -{ - fn parse(input: ParseStream) -> syn::Result<Self> - { - let input_forked = input.fork(); - - let flag: Ident = input_forked.parse()?; - - let flag_str = flag.to_string(); - - if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { - return Err(input.error(format!( - "Unknown flag '{}'. Expected one of [ {} ]", - flag_str, - INJECTABLE_MACRO_FLAGS.join(",") - ))); - } - - input.parse::<Ident>()?; - - input.parse::<Token![=]>()?; - - let is_on: LitBool = input.parse()?; - - Ok(Self { flag, is_on }) - } -} +pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"]; pub struct InjectableMacroArgs { pub interface: Option<TypePath>, - pub flags: Punctuated<InjectableMacroFlag, Token![,]>, + pub flags: Punctuated<MacroFlag, Token![,]>, } impl Parse for InjectableMacroArgs @@ -76,7 +43,19 @@ impl Parse for InjectableMacroArgs braced!(braced_content in input); - let flags = braced_content.parse_terminated(InjectableMacroFlag::parse)?; + let flags = braced_content.parse_terminated(MacroFlag::parse)?; + + for flag in &flags { + let flag_str = flag.flag.to_string(); + + if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) { + return Err(input.error(format!( + "Unknown flag '{}'. Expected one of [ {} ]", + flag_str, + INJECTABLE_MACRO_FLAGS.join(",") + ))); + } + } let flag_names = flags .iter() diff --git a/macros/src/lib.rs b/macros/src/lib.rs index eb3a2be..40fbb53 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -2,7 +2,7 @@ #![deny(clippy::pedantic)] #![deny(missing_docs)] -//! Macros for the [Syrette](https://crates.io/crates/syrette) crate. +//! Macros for the [Sy&rette](https://crates.io/crates/syrette) crate. use proc_macro::TokenStream; use quote::quote; @@ -10,10 +10,12 @@ use syn::{parse, parse_macro_input}; mod declare_interface_args; mod dependency; +mod factory_macro_args; mod factory_type_alias; mod injectable_impl; mod injectable_macro_args; mod libs; +mod macro_flag; mod named_attr_input; mod util; @@ -31,6 +33,7 @@ use libs::intertrait_macros::gen_caster::generate_caster; /// # Flags /// - `no_doc_hidden` - Don't hide the impl of the [`Injectable`] trait from /// documentation. +/// - `async` - Mark as async. /// /// # Panics /// If the attributed item is not a impl. @@ -107,21 +110,31 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt { let InjectableMacroArgs { interface, flags } = parse_macro_input!(args_stream); - let mut flags_iter = flags.iter(); - - let no_doc_hidden = flags_iter + let no_doc_hidden = flags + .iter() .find(|flag| flag.flag.to_string().as_str() == "no_doc_hidden") .map_or(false, |flag| flag.is_on.value); + let is_async = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async") + .map_or(false, |flag| flag.is_on.value); + let injectable_impl: InjectableImpl = parse(impl_stream).unwrap(); - let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden); + let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden, is_async); let maybe_decl_interface = if interface.is_some() { let self_type = &injectable_impl.self_type; - quote! { - syrette::declare_interface!(#self_type -> #interface); + if is_async { + quote! { + syrette::declare_interface!(#self_type -> #interface, async = true); + } + } else { + quote! { + syrette::declare_interface!(#self_type -> #interface); + } } } else { quote! {} @@ -139,6 +152,12 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt /// /// *This macro is only available if Syrette is built with the "factory" feature.* /// +/// # Arguments +/// * (Zero or more) Flags. Like `a = true, b = false` +/// +/// # Flags +/// - `async` - Mark as async. +/// /// # Panics /// If the attributed item is not a type alias. /// @@ -166,8 +185,17 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt /// ``` #[proc_macro_attribute] #[cfg(feature = "factory")] -pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream +pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> TokenStream { + use crate::factory_macro_args::FactoryMacroArgs; + + let FactoryMacroArgs { flags } = parse(args_stream).unwrap(); + + let is_async = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async") + .map_or(false, |flag| flag.is_on.value); + let factory_type_alias::FactoryTypeAlias { type_alias, factory_interface, @@ -175,22 +203,46 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream return_type, } = parse(type_alias_stream).unwrap(); + let decl_interfaces = if is_async { + quote! { + syrette::declare_interface!( + syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< + #arg_types, + #return_type + > -> #factory_interface, + async = true + ); + + syrette::declare_interface!( + syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< + #arg_types, + #return_type + > -> syrette::interfaces::any_factory::AnyThreadsafeFactory, + async = true + ) + } + } else { + quote! { + syrette::declare_interface!( + syrette::castable_factory::blocking::CastableFactory< + #arg_types, + #return_type + > -> #factory_interface + ); + + syrette::declare_interface!( + syrette::castable_factory::blocking::CastableFactory< + #arg_types, + #return_type + > -> syrette::interfaces::any_factory::AnyFactory + ); + } + }; + quote! { #type_alias - syrette::declare_interface!( - syrette::castable_factory::CastableFactory< - #arg_types, - #return_type - > -> #factory_interface - ); - - syrette::declare_interface!( - syrette::castable_factory::CastableFactory< - #arg_types, - #return_type - > -> syrette::interfaces::any_factory::AnyFactory - ); + #decl_interfaces } .into() } @@ -199,6 +251,10 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream /// /// # Arguments /// {Implementation} -> {Interface} +/// * (Zero or more) Flags. Like `a = true, b = false` +/// +/// # Flags +/// - `async` - Mark as async. /// /// # Examples /// ``` @@ -218,9 +274,17 @@ pub fn declare_interface(input: TokenStream) -> TokenStream let DeclareInterfaceArgs { implementation, interface, + flags, } = parse_macro_input!(input); - generate_caster(&implementation, &interface).into() + let opt_async_flag = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async"); + + let is_async = + opt_async_flag.map_or_else(|| false, |async_flag| async_flag.is_on.value); + + generate_caster(&implementation, &interface, is_async).into() } /// Declares the name of a dependency. diff --git a/macros/src/libs/intertrait_macros/gen_caster.rs b/macros/src/libs/intertrait_macros/gen_caster.rs index 9bac09e..df743e2 100644 --- a/macros/src/libs/intertrait_macros/gen_caster.rs +++ b/macros/src/libs/intertrait_macros/gen_caster.rs @@ -22,15 +22,29 @@ const CASTER_FN_NAME_PREFIX: &[u8] = b"__"; const FN_BUF_LEN: usize = CASTER_FN_NAME_PREFIX.len() + Simple::LENGTH; -pub fn generate_caster(ty: &impl ToTokens, dst_trait: &impl ToTokens) -> TokenStream +pub fn generate_caster( + ty: &impl ToTokens, + dst_trait: &impl ToTokens, + sync: bool, +) -> TokenStream { let fn_ident = create_caster_fn_ident(); - let new_caster = quote! { - syrette::libs::intertrait::Caster::<dyn #dst_trait>::new( - |from| from.downcast::<#ty>().unwrap(), - |from| from.downcast::<#ty>().unwrap(), - ) + let new_caster = if sync { + quote! { + syrette::libs::intertrait::Caster::<dyn #dst_trait>::new_sync( + |from| from.downcast::<#ty>().unwrap(), + |from| from.downcast::<#ty>().unwrap(), + |from| from.downcast::<#ty>().unwrap() + ) + } + } else { + quote! { + syrette::libs::intertrait::Caster::<dyn #dst_trait>::new( + |from| from.downcast::<#ty>().unwrap(), + |from| from.downcast::<#ty>().unwrap(), + ) + } }; quote! { diff --git a/macros/src/macro_flag.rs b/macros/src/macro_flag.rs new file mode 100644 index 0000000..257a059 --- /dev/null +++ b/macros/src/macro_flag.rs @@ -0,0 +1,27 @@ +use syn::parse::{Parse, ParseStream}; +use syn::{Ident, LitBool, Token}; + +#[derive(Debug)] +pub struct MacroFlag +{ + pub flag: Ident, + pub is_on: LitBool, +} + +impl Parse for MacroFlag +{ + fn parse(input: ParseStream) -> syn::Result<Self> + { + let input_forked = input.fork(); + + let flag: Ident = input_forked.parse()?; + + input.parse::<Ident>()?; + + input.parse::<Token![=]>()?; + + let is_on: LitBool = input.parse()?; + + Ok(Self { flag, is_on }) + } +} diff --git a/macros/src/util/mod.rs b/macros/src/util/mod.rs index 4f2a594..0705853 100644 --- a/macros/src/util/mod.rs +++ b/macros/src/util/mod.rs @@ -1,3 +1,4 @@ pub mod item_impl; pub mod iterator_ext; +pub mod string; pub mod syn_path; diff --git a/macros/src/util/string.rs b/macros/src/util/string.rs new file mode 100644 index 0000000..90cccee --- /dev/null +++ b/macros/src/util/string.rs @@ -0,0 +1,12 @@ +use once_cell::sync::Lazy; +use regex::Regex; + +static CAMELCASE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"([a-z])([A-Z])").unwrap()); + +pub fn camelcase_to_snakecase(camelcased: &str) -> String +{ + CAMELCASE_RE + .replace(camelcased, "${1}_$2") + .to_string() + .to_lowercase() +} diff --git a/src/async_di_container.rs b/src/async_di_container.rs new file mode 100644 index 0000000..374746f --- /dev/null +++ b/src/async_di_container.rs @@ -0,0 +1,1110 @@ +//! Asynchronous dependency injection container. +//! +//! # Examples +//! ``` +//! use std::collections::HashMap; +//! use std::error::Error; +//! +//! use syrette::{injectable, AsyncDIContainer}; +//! +//! trait IDatabaseService +//! { +//! fn get_all_records(&self, table_name: String) -> HashMap<String, String>; +//! } +//! +//! struct DatabaseService {} +//! +//! #[injectable(IDatabaseService, { async = true })] +//! impl DatabaseService +//! { +//! fn new() -> Self +//! { +//! Self {} +//! } +//! } +//! +//! impl IDatabaseService for DatabaseService +//! { +//! fn get_all_records(&self, table_name: String) -> HashMap<String, String> +//! { +//! // Do stuff here +//! HashMap::<String, String>::new() +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn Error>> +//! { +//! let mut di_container = AsyncDIContainer::new(); +//! +//! di_container +//! .bind::<dyn IDatabaseService>() +//! .to::<DatabaseService>()?; +//! +//! let database_service = di_container +//! .get::<dyn IDatabaseService>() +//! .await? +//! .transient()?; +//! +//! Ok(()) +//! } +//! ``` +//! +//! --- +//! +//! *This module is only available if Syrette is built with the "async" feature.* +use std::any::type_name; +use std::marker::PhantomData; + +#[cfg(feature = "factory")] +use crate::castable_factory::threadsafe::ThreadsafeCastableFactory; +use crate::di_container_binding_map::DIContainerBindingMap; +use crate::errors::async_di_container::{ + AsyncBindingBuilderError, + AsyncBindingScopeConfiguratorError, + AsyncBindingWhenConfiguratorError, + AsyncDIContainerError, +}; +use crate::interfaces::async_injectable::AsyncInjectable; +use crate::libs::intertrait::cast::{CastArc, CastBox}; +use crate::provider::r#async::{ + AsyncProvidable, + AsyncSingletonProvider, + AsyncTransientTypeProvider, + IAsyncProvider, +}; +use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr}; + +/// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. +pub struct AsyncBindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + di_container: &'di_container mut AsyncDIContainer, + interface_phantom: PhantomData<Interface>, +} + +impl<'di_container, Interface> AsyncBindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + { + Self { + di_container, + interface_phantom: PhantomData, + } + } + + /// Configures the binding to have a name. + /// + /// # Errors + /// Will return Err if no binding for the interface already exists. + pub fn when_named( + &mut self, + name: &'static str, + ) -> Result<(), AsyncBindingWhenConfiguratorError> + { + let binding = self + .di_container + .bindings + .remove::<Interface>(None) + .map_or_else( + || { + Err(AsyncBindingWhenConfiguratorError::BindingNotFound( + type_name::<Interface>(), + )) + }, + Ok, + )?; + + self.di_container + .bindings + .set::<Interface>(Some(name), binding); + + Ok(()) + } +} + +/// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. +pub struct AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +where + Interface: 'static + ?Sized, + Implementation: AsyncInjectable, +{ + di_container: &'di_container mut AsyncDIContainer, + interface_phantom: PhantomData<Interface>, + implementation_phantom: PhantomData<Implementation>, +} + +impl<'di_container, Interface, Implementation> + AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +where + Interface: 'static + ?Sized, + Implementation: AsyncInjectable, +{ + fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + { + Self { + di_container, + interface_phantom: PhantomData, + implementation_phantom: PhantomData, + } + } + + /// Configures the binding to be in a transient scope. + /// + /// This is the default. + pub fn in_transient_scope(&mut self) -> AsyncBindingWhenConfigurator<Interface> + { + self.di_container.bindings.set::<Interface>( + None, + Box::new(AsyncTransientTypeProvider::<Implementation>::new()), + ); + + AsyncBindingWhenConfigurator::new(self.di_container) + } + + /// Configures the binding to be in a singleton scope. + /// + /// # Errors + /// Will return Err if resolving the implementation fails. + pub async fn in_singleton_scope( + &mut self, + ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingScopeConfiguratorError> + { + let singleton: ThreadsafeSingletonPtr<Implementation> = + ThreadsafeSingletonPtr::from( + Implementation::resolve(self.di_container, Vec::new()) + .await + .map_err( + AsyncBindingScopeConfiguratorError::SingletonResolveFailed, + )?, + ); + + self.di_container + .bindings + .set::<Interface>(None, Box::new(AsyncSingletonProvider::new(singleton))); + + Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + } +} + +/// Binding builder for type `Interface` inside a [`AsyncDIContainer`]. +pub struct AsyncBindingBuilder<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + di_container: &'di_container mut AsyncDIContainer, + interface_phantom: PhantomData<Interface>, +} + +impl<'di_container, Interface> AsyncBindingBuilder<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + { + Self { + di_container, + interface_phantom: PhantomData, + } + } + + /// Creates a binding of type `Interface` to type `Implementation` inside of the + /// associated [`AsyncDIContainer`]. + /// + /// The scope of the binding is transient. But that can be changed by using the + /// returned [`AsyncBindingScopeConfigurator`] + /// + /// # Errors + /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for + /// the interface. + pub fn to<Implementation>( + &mut self, + ) -> Result< + AsyncBindingScopeConfigurator<Interface, Implementation>, + AsyncBindingBuilderError, + > + where + Implementation: AsyncInjectable, + { + if self.di_container.bindings.has::<Interface>(None) { + return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< + Interface, + >( + ))); + } + + let mut binding_scope_configurator = + AsyncBindingScopeConfigurator::new(self.di_container); + + binding_scope_configurator.in_transient_scope(); + + Ok(binding_scope_configurator) + } + + /// Creates a binding of factory type `Interface` to a factory inside of the + /// associated [`AsyncDIContainer`]. + /// + /// *This function is only available if Syrette is built with the "factory" feature.* + /// + /// # Errors + /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for + /// the interface. + #[cfg(feature = "factory")] + pub fn to_factory<Args, Return>( + &mut self, + factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>> + + Send + + Sync), + ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError> + where + Args: 'static, + Return: 'static + ?Sized, + Interface: crate::interfaces::factory::IFactory<Args, Return>, + { + if self.di_container.bindings.has::<Interface>(None) { + return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< + Interface, + >( + ))); + } + + let factory_impl = ThreadsafeCastableFactory::new(factory_func); + + self.di_container.bindings.set::<Interface>( + None, + Box::new(crate::provider::r#async::AsyncFactoryProvider::new( + crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), + )), + ); + + Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + } + + /// Creates a binding of type `Interface` to a factory that takes no arguments + /// inside of the associated [`AsyncDIContainer`]. + /// + /// *This function is only available if Syrette is built with the "factory" feature.* + /// + /// # Errors + /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for + /// the interface. + #[cfg(feature = "factory")] + pub fn to_default_factory<Return>( + &mut self, + factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>> + + Send + + Sync), + ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError> + where + Return: 'static + ?Sized, + { + if self.di_container.bindings.has::<Interface>(None) { + return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< + Interface, + >( + ))); + } + + let factory_impl = ThreadsafeCastableFactory::new(factory_func); + + self.di_container.bindings.set::<Interface>( + None, + Box::new(crate::provider::r#async::AsyncFactoryProvider::new( + crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), + )), + ); + + Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + } +} + +/// Dependency injection container. +pub struct AsyncDIContainer +{ + bindings: DIContainerBindingMap<dyn IAsyncProvider>, +} + +impl AsyncDIContainer +{ + /// Returns a new `AsyncDIContainer`. + #[must_use] + pub fn new() -> Self + { + Self { + bindings: DIContainerBindingMap::new(), + } + } + + /// Returns a new [`AsyncBindingBuilder`] for the given interface. + pub fn bind<Interface>(&mut self) -> AsyncBindingBuilder<Interface> + where + Interface: 'static + ?Sized, + { + AsyncBindingBuilder::<Interface>::new(self) + } + + /// Returns the type bound with `Interface`. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` exists + /// - Resolving the binding for fails + /// - Casting the binding for fails + pub async fn get<Interface>( + &self, + ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> + where + Interface: 'static + ?Sized, + { + self.get_bound::<Interface>(Vec::new(), None).await + } + + /// Returns the type bound with `Interface` and the specified name. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` with name `name` exists + /// - Resolving the binding fails + /// - Casting the binding for fails + pub async fn get_named<Interface>( + &self, + name: &'static str, + ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> + where + Interface: 'static + ?Sized, + { + self.get_bound::<Interface>(Vec::new(), Some(name)).await + } + + #[doc(hidden)] + pub async fn get_bound<Interface>( + &self, + dependency_history: Vec<&'static str>, + name: Option<&'static str>, + ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> + where + Interface: 'static + ?Sized, + { + let binding_providable = self + .get_binding_providable::<Interface>(name, dependency_history) + .await?; + + Self::handle_binding_providable(binding_providable) + } + + fn handle_binding_providable<Interface>( + binding_providable: AsyncProvidable, + ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> + where + Interface: 'static + ?Sized, + { + match binding_providable { + AsyncProvidable::Transient(transient_binding) => { + Ok(SomeThreadsafePtr::Transient( + transient_binding.cast::<Interface>().map_err(|_| { + AsyncDIContainerError::CastFailed(type_name::<Interface>()) + })?, + )) + } + AsyncProvidable::Singleton(singleton_binding) => { + Ok(SomeThreadsafePtr::ThreadsafeSingleton( + singleton_binding.cast::<Interface>().map_err(|_| { + AsyncDIContainerError::CastFailed(type_name::<Interface>()) + })?, + )) + } + #[cfg(feature = "factory")] + AsyncProvidable::Factory(factory_binding) => { + match factory_binding.clone().cast::<Interface>() { + Ok(factory) => Ok(SomeThreadsafePtr::ThreadsafeFactory(factory)), + Err(_err) => { + use crate::interfaces::factory::IFactory; + + let default_factory = + factory_binding + .cast::<dyn IFactory<(), Interface>>() + .map_err(|_| { + AsyncDIContainerError::CastFailed(type_name::< + Interface, + >( + )) + })?; + + Ok(SomeThreadsafePtr::Transient(default_factory())) + } + } + } + } + } + + async fn get_binding_providable<Interface>( + &self, + name: Option<&'static str>, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, AsyncDIContainerError> + where + Interface: 'static + ?Sized, + { + self.bindings + .get::<Interface>(name) + .map_or_else( + || { + Err(AsyncDIContainerError::BindingNotFound { + interface: type_name::<Interface>(), + name, + }) + }, + Ok, + )? + .provide(self, dependency_history) + .await + .map_err(|err| AsyncDIContainerError::BindingResolveFailed { + reason: err, + interface: type_name::<Interface>(), + }) + } +} + +impl Default for AsyncDIContainer +{ + fn default() -> Self + { + Self::new() + } +} + +#[cfg(test)] +mod tests +{ + use std::error::Error; + + use async_trait::async_trait; + use mockall::mock; + + use super::*; + use crate::errors::injectable::InjectableError; + use crate::ptr::TransientPtr; + + mod subjects + { + //! Test subjects. + + use std::fmt::Debug; + + use async_trait::async_trait; + use syrette_macros::declare_interface; + + use super::AsyncDIContainer; + use crate::interfaces::async_injectable::AsyncInjectable; + use crate::ptr::TransientPtr; + + pub trait IUserManager + { + fn add_user(&self, user_id: i128); + + fn remove_user(&self, user_id: i128); + } + + pub struct UserManager {} + + impl UserManager + { + pub fn new() -> Self + { + Self {} + } + } + + impl IUserManager for UserManager + { + fn add_user(&self, _user_id: i128) + { + // ... + } + + fn remove_user(&self, _user_id: i128) + { + // ... + } + } + + use crate as syrette; + + declare_interface!(UserManager -> IUserManager); + + #[async_trait] + impl AsyncInjectable for UserManager + { + async fn resolve( + _: &AsyncDIContainer, + _dependency_history: Vec<&'static str>, + ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError> + where + Self: Sized, + { + Ok(TransientPtr::new(Self::new())) + } + } + + pub trait INumber + { + fn get(&self) -> i32; + + fn set(&mut self, number: i32); + } + + impl PartialEq for dyn INumber + { + fn eq(&self, other: &Self) -> bool + { + self.get() == other.get() + } + } + + impl Debug for dyn INumber + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + f.write_str(format!("{}", self.get()).as_str()) + } + } + + pub struct Number + { + pub num: i32, + } + + impl Number + { + pub fn new() -> Self + { + Self { num: 0 } + } + } + + impl INumber for Number + { + fn get(&self) -> i32 + { + self.num + } + + fn set(&mut self, number: i32) + { + self.num = number; + } + } + + declare_interface!(Number -> INumber, async = true); + + #[async_trait] + impl AsyncInjectable for Number + { + async fn resolve( + _: &AsyncDIContainer, + _dependency_history: Vec<&'static str>, + ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError> + where + Self: Sized, + { + Ok(TransientPtr::new(Self::new())) + } + } + } + + #[test] + fn can_bind_to() -> Result<(), Box<dyn Error>> + { + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] + fn can_bind_to_transient() -> Result<(), Box<dyn Error>> + { + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()? + .in_transient_scope(); + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] + fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>> + { + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()? + .in_transient_scope() + .when_named("regular")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[tokio::test] + async fn can_bind_to_singleton() -> Result<(), Box<dyn Error>> + { + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()? + .in_singleton_scope() + .await?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[tokio::test] + async fn can_bind_to_singleton_when_named() -> Result<(), Box<dyn Error>> + { + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()? + .in_singleton_scope() + .await? + .when_named("cool")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] + #[cfg(feature = "factory")] + fn can_bind_to_factory() -> Result<(), Box<dyn Error>> + { + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container.bind::<IUserManagerFactory>().to_factory(&|| { + let user_manager: TransientPtr<dyn subjects::IUserManager> = + TransientPtr::new(subjects::UserManager::new()); + + user_manager + })?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] + #[cfg(feature = "factory")] + fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>> + { + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<IUserManagerFactory>() + .to_factory(&|| { + let user_manager: TransientPtr<dyn subjects::IUserManager> = + TransientPtr::new(subjects::UserManager::new()); + + user_manager + })? + .when_named("awesome")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[tokio::test] + async fn can_get() -> Result<(), Box<dyn Error>> + { + mock! { + Provider {} + + #[async_trait] + impl IAsyncProvider for Provider + { + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError>; + } + } + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + di_container + .bindings + .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider)); + + di_container + .get::<dyn subjects::IUserManager>() + .await? + .transient()?; + + Ok(()) + } + + #[tokio::test] + async fn can_get_named() -> Result<(), Box<dyn Error>> + { + mock! { + Provider {} + + #[async_trait] + impl IAsyncProvider for Provider + { + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError>; + } + } + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + di_container + .bindings + .set::<dyn subjects::IUserManager>(Some("special"), Box::new(mock_provider)); + + di_container + .get_named::<dyn subjects::IUserManager>("special") + .await? + .transient()?; + + Ok(()) + } + + #[tokio::test] + async fn can_get_singleton() -> Result<(), Box<dyn Error>> + { + mock! { + Provider {} + + #[async_trait] + impl IAsyncProvider for Provider + { + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError>; + } + } + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = ThreadsafeSingletonPtr::new(subjects::Number::new()); + + ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; + + mock_provider + .expect_provide() + .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + + di_container + .bindings + .set::<dyn subjects::INumber>(None, Box::new(mock_provider)); + + let first_number_rc = di_container + .get::<dyn subjects::INumber>() + .await? + .threadsafe_singleton()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container + .get::<dyn subjects::INumber>() + .await? + .threadsafe_singleton()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + + #[tokio::test] + async fn can_get_singleton_named() -> Result<(), Box<dyn Error>> + { + mock! { + Provider {} + + #[async_trait] + impl IAsyncProvider for Provider + { + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError>; + } + } + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = ThreadsafeSingletonPtr::new(subjects::Number::new()); + + ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; + + mock_provider + .expect_provide() + .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + + di_container + .bindings + .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider)); + + let first_number_rc = di_container + .get_named::<dyn subjects::INumber>("cool") + .await? + .threadsafe_singleton()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container + .get_named::<dyn subjects::INumber>("cool") + .await? + .threadsafe_singleton()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + + #[tokio::test] + #[cfg(feature = "factory")] + async fn can_get_factory() -> Result<(), Box<dyn Error>> + { + trait IUserManager + { + fn add_user(&mut self, user_id: i128); + + fn remove_user(&mut self, user_id: i128); + } + + struct UserManager + { + users: Vec<i128>, + } + + impl UserManager + { + fn new(users: Vec<i128>) -> Self + { + Self { users } + } + } + + impl IUserManager for UserManager + { + fn add_user(&mut self, user_id: i128) + { + self.users.push(user_id); + } + + fn remove_user(&mut self, user_id: i128) + { + let user_index = + self.users.iter().position(|user| *user == user_id).unwrap(); + + self.users.remove(user_index); + } + } + + use crate as syrette; + + #[crate::factory(async = true)] + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>; + + mock! { + Provider {} + + #[async_trait] + impl IAsyncProvider for Provider + { + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError>; + } + } + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Factory( + crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new( + &|users| { + let user_manager: TransientPtr<dyn IUserManager> = + TransientPtr::new(UserManager::new(users)); + + user_manager + }, + )), + )) + }); + + di_container + .bindings + .set::<IUserManagerFactory>(None, Box::new(mock_provider)); + + di_container + .get::<IUserManagerFactory>() + .await? + .threadsafe_factory()?; + + Ok(()) + } + + #[tokio::test] + #[cfg(feature = "factory")] + async fn can_get_factory_named() -> Result<(), Box<dyn Error>> + { + trait IUserManager + { + fn add_user(&mut self, user_id: i128); + + fn remove_user(&mut self, user_id: i128); + } + + struct UserManager + { + users: Vec<i128>, + } + + impl UserManager + { + fn new(users: Vec<i128>) -> Self + { + Self { users } + } + } + + impl IUserManager for UserManager + { + fn add_user(&mut self, user_id: i128) + { + self.users.push(user_id); + } + + fn remove_user(&mut self, user_id: i128) + { + let user_index = + self.users.iter().position(|user| *user == user_id).unwrap(); + + self.users.remove(user_index); + } + } + + use crate as syrette; + + #[crate::factory(async = true)] + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>; + + mock! { + Provider {} + + #[async_trait] + impl IAsyncProvider for Provider + { + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError>; + } + } + + let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Factory( + crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new( + &|users| { + let user_manager: TransientPtr<dyn IUserManager> = + TransientPtr::new(UserManager::new(users)); + + user_manager + }, + )), + )) + }); + + di_container + .bindings + .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider)); + + di_container + .get_named::<IUserManagerFactory>("special") + .await? + .threadsafe_factory()?; + + Ok(()) + } +} diff --git a/src/castable_factory.rs b/src/castable_factory/blocking.rs index 5ff4db0..5ff4db0 100644 --- a/src/castable_factory.rs +++ b/src/castable_factory/blocking.rs diff --git a/src/castable_factory/mod.rs b/src/castable_factory/mod.rs new file mode 100644 index 0000000..530cc82 --- /dev/null +++ b/src/castable_factory/mod.rs @@ -0,0 +1,2 @@ +pub mod blocking; +pub mod threadsafe; diff --git a/src/castable_factory/threadsafe.rs b/src/castable_factory/threadsafe.rs new file mode 100644 index 0000000..7be055c --- /dev/null +++ b/src/castable_factory/threadsafe.rs @@ -0,0 +1,88 @@ +#![allow(clippy::module_name_repetitions)] +use crate::interfaces::any_factory::{AnyFactory, AnyThreadsafeFactory}; +use crate::interfaces::factory::IFactory; +use crate::ptr::TransientPtr; + +pub struct ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ + func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + Send + Sync), +} + +impl<Args, ReturnInterface> ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ + pub fn new( + func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + + Send + + Sync), + ) -> Self + { + Self { func } + } +} + +impl<Args, ReturnInterface> IFactory<Args, ReturnInterface> + for ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ +} + +impl<Args, ReturnInterface> Fn<Args> for ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ + extern "rust-call" fn call(&self, args: Args) -> Self::Output + { + self.func.call(args) + } +} + +impl<Args, ReturnInterface> FnMut<Args> + for ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ + extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output + { + self.call(args) + } +} + +impl<Args, ReturnInterface> FnOnce<Args> + for ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ + type Output = TransientPtr<ReturnInterface>; + + extern "rust-call" fn call_once(self, args: Args) -> Self::Output + { + self.call(args) + } +} + +impl<Args, ReturnInterface> AnyFactory + for ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ +} + +impl<Args, ReturnInterface> AnyThreadsafeFactory + for ThreadsafeCastableFactory<Args, ReturnInterface> +where + Args: 'static, + ReturnInterface: 'static + ?Sized, +{ +} diff --git a/src/di_container.rs b/src/di_container.rs index e42175b..b0e5af1 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -1,4 +1,4 @@ -//! Dependency injection container and other related utilities. +//! Dependency injection container. //! //! # Examples //! ``` @@ -53,7 +53,7 @@ use std::any::type_name; use std::marker::PhantomData; #[cfg(feature = "factory")] -use crate::castable_factory::CastableFactory; +use crate::castable_factory::blocking::CastableFactory; use crate::di_container_binding_map::DIContainerBindingMap; use crate::errors::di_container::{ BindingBuilderError, @@ -63,7 +63,12 @@ use crate::errors::di_container::{ }; use crate::interfaces::injectable::Injectable; use crate::libs::intertrait::cast::{CastBox, CastRc}; -use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; +use crate::provider::blocking::{ + IProvider, + Providable, + SingletonProvider, + TransientTypeProvider, +}; use crate::ptr::{SingletonPtr, SomePtr}; /// When configurator for a binding for type 'Interface' inside a [`DIContainer`]. @@ -256,7 +261,7 @@ where self.di_container.bindings.set::<Interface>( None, - Box::new(crate::provider::FactoryProvider::new( + Box::new(crate::provider::blocking::FactoryProvider::new( crate::ptr::FactoryPtr::new(factory_impl), )), ); @@ -290,7 +295,7 @@ where self.di_container.bindings.set::<Interface>( None, - Box::new(crate::provider::FactoryProvider::new( + Box::new(crate::provider::blocking::FactoryProvider::new( crate::ptr::FactoryPtr::new(factory_impl), )), ); @@ -302,7 +307,7 @@ where /// Dependency injection container. pub struct DIContainer { - bindings: DIContainerBindingMap, + bindings: DIContainerBindingMap<dyn IProvider>, } impl DIContainer @@ -416,7 +421,16 @@ impl DIContainer Interface: 'static + ?Sized, { self.bindings - .get::<Interface>(name)? + .get::<Interface>(name) + .map_or_else( + || { + Err(DIContainerError::BindingNotFound { + interface: type_name::<Interface>(), + name, + }) + }, + Ok, + )? .provide(self, dependency_history) .map_err(|err| DIContainerError::BindingResolveFailed { reason: err, @@ -442,7 +456,7 @@ mod tests use super::*; use crate::errors::injectable::InjectableError; - use crate::provider::IProvider; + use crate::provider::blocking::IProvider; use crate::ptr::TransientPtr; mod subjects diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs index 4df889d..4aa246e 100644 --- a/src/di_container_binding_map.rs +++ b/src/di_container_binding_map.rs @@ -1,10 +1,7 @@ -use std::any::{type_name, TypeId}; +use std::any::TypeId; use ahash::AHashMap; -use crate::errors::di_container::DIContainerError; -use crate::provider::IProvider; - #[derive(Debug, PartialEq, Eq, Hash)] struct DIContainerBindingKey { @@ -12,12 +9,16 @@ struct DIContainerBindingKey name: Option<&'static str>, } -pub struct DIContainerBindingMap +pub struct DIContainerBindingMap<Provider> +where + Provider: 'static + ?Sized, { - bindings: AHashMap<DIContainerBindingKey, Box<dyn IProvider>>, + bindings: AHashMap<DIContainerBindingKey, Box<Provider>>, } -impl DIContainerBindingMap +impl<Provider> DIContainerBindingMap<Provider> +where + Provider: 'static + ?Sized, { pub fn new() -> Self { @@ -26,33 +27,22 @@ impl DIContainerBindingMap } } - pub fn get<Interface>( - &self, - name: Option<&'static str>, - ) -> Result<&dyn IProvider, DIContainerError> + pub fn get<Interface>(&self, name: Option<&'static str>) -> Option<&Provider> where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::<Interface>(); - Ok(self - .bindings + self.bindings .get(&DIContainerBindingKey { type_id: interface_typeid, name, }) - .ok_or_else(|| DIContainerError::BindingNotFound { - interface: type_name::<Interface>(), - name, - })? - .as_ref()) + .map(|provider| provider.as_ref()) } - pub fn set<Interface>( - &mut self, - name: Option<&'static str>, - provider: Box<dyn IProvider>, - ) where + pub fn set<Interface>(&mut self, name: Option<&'static str>, provider: Box<Provider>) + where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::<Interface>(); @@ -69,7 +59,7 @@ impl DIContainerBindingMap pub fn remove<Interface>( &mut self, name: Option<&'static str>, - ) -> Option<Box<dyn IProvider>> + ) -> Option<Box<Provider>> where Interface: 'static + ?Sized, { diff --git a/src/errors/async_di_container.rs b/src/errors/async_di_container.rs new file mode 100644 index 0000000..4f5e50a --- /dev/null +++ b/src/errors/async_di_container.rs @@ -0,0 +1,79 @@ +//! Error types for [`AsyncDIContainer`] and it's related structs. +//! +//! --- +//! +//! *This module is only available if Syrette is built with the "async" feature.* +//! +//! [`AsyncDIContainer`]: crate::async_di_container::AsyncDIContainer + +use crate::errors::injectable::InjectableError; + +/// Error type for [`AsyncDIContainer`]. +/// +/// [`AsyncDIContainer`]: crate::async_di_container::AsyncDIContainer +#[derive(thiserror::Error, Debug)] +pub enum AsyncDIContainerError +{ + /// Unable to cast a binding for a interface. + #[error("Unable to cast binding for interface '{0}'")] + CastFailed(&'static str), + + /// Failed to resolve a binding for a interface. + #[error("Failed to resolve binding for interface '{interface}'")] + BindingResolveFailed + { + /// The reason for the problem. + #[source] + reason: InjectableError, + + /// The affected bound interface. + interface: &'static str, + }, + + /// No binding exists for a interface (and optionally a name). + #[error( + "No binding exists for interface '{interface}' {}", + .name.map_or_else(String::new, |name| format!("with name '{}'", name)) + )] + BindingNotFound + { + /// The interface that doesn't have a binding. + interface: &'static str, + + /// The name of the binding if one exists. + name: Option<&'static str>, + }, +} + +/// Error type for [`AsyncBindingBuilder`]. +/// +/// [`AsyncBindingBuilder`]: crate::async_di_container::AsyncBindingBuilder +#[derive(thiserror::Error, Debug)] +pub enum AsyncBindingBuilderError +{ + /// A binding already exists for a interface. + #[error("Binding already exists for interface '{0}'")] + BindingAlreadyExists(&'static str), +} + +/// Error type for [`AsyncBindingScopeConfigurator`]. +/// +/// [`AsyncBindingScopeConfigurator`]: crate::async_di_container::AsyncBindingScopeConfigurator +#[derive(thiserror::Error, Debug)] +pub enum AsyncBindingScopeConfiguratorError +{ + /// Resolving a singleton failed. + #[error("Resolving the given singleton failed")] + SingletonResolveFailed(#[from] InjectableError), +} + +/// Error type for [`AsyncBindingWhenConfigurator`]. +/// +/// [`AsyncBindingWhenConfigurator`]: crate::async_di_container::AsyncBindingWhenConfigurator +#[derive(thiserror::Error, Debug)] +pub enum AsyncBindingWhenConfiguratorError +{ + /// A binding for a interface wasn't found. + #[error("A binding for interface '{0}' wasn't found'")] + BindingNotFound(&'static str), +} diff --git a/src/errors/injectable.rs b/src/errors/injectable.rs index 4b9af96..ed161cb 100644 --- a/src/errors/injectable.rs +++ b/src/errors/injectable.rs @@ -3,7 +3,7 @@ //! //! [`Injectable`]: crate::interfaces::injectable::Injectable -use super::di_container::DIContainerError; +use crate::errors::di_container::DIContainerError; /// Error type for structs that implement [`Injectable`]. /// @@ -23,6 +23,18 @@ pub enum InjectableError affected: &'static str, }, + /// Failed to resolve dependencies. + #[cfg(feature = "async")] + #[error("Failed to resolve a dependency of '{affected}'")] + AsyncResolveFailed + { + /// The reason for the problem. + #[source] + reason: Box<crate::errors::async_di_container::AsyncDIContainerError>, + + /// The affected injectable type. + affected: &'static str, + }, /// Detected circular dependencies. #[error("Detected circular dependencies. {dependency_trace}")] DetectedCircular diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 7d66ddf..c3930b0 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -3,3 +3,6 @@ pub mod di_container; pub mod injectable; pub mod ptr; + +#[cfg(feature = "async")] +pub mod async_di_container; diff --git a/src/errors/ptr.rs b/src/errors/ptr.rs index e0c3d05..56621c1 100644 --- a/src/errors/ptr.rs +++ b/src/errors/ptr.rs @@ -17,3 +17,21 @@ pub enum SomePtrError found: &'static str, }, } + +/// Error type for [`SomeThreadsafePtr`]. +/// +/// [`SomeThreadsafePtr`]: crate::ptr::SomeThreadsafePtr +#[derive(thiserror::Error, Debug)] +pub enum SomeThreadsafePtrError +{ + /// Tried to get as a wrong threadsafe smart pointer type. + #[error("Wrong threadsafe smart pointer type. Expected {expected}, found {found}")] + WrongPtrType + { + /// The expected smart pointer type. + expected: &'static str, + + /// The found smart pointer type. + found: &'static str, + }, +} diff --git a/src/interfaces/any_factory.rs b/src/interfaces/any_factory.rs index 887bb61..1bf9208 100644 --- a/src/interfaces/any_factory.rs +++ b/src/interfaces/any_factory.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; -use crate::libs::intertrait::CastFrom; +use crate::libs::intertrait::{CastFrom, CastFromSync}; /// Interface for any factory to ever exist. pub trait AnyFactory: CastFrom {} @@ -14,3 +14,14 @@ impl Debug for dyn AnyFactory f.write_str("{}") } } + +/// Interface for any threadsafe factory to ever exist. +pub trait AnyThreadsafeFactory: CastFromSync {} + +impl Debug for dyn AnyThreadsafeFactory +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + f.write_str("{}") + } +} diff --git a/src/interfaces/async_injectable.rs b/src/interfaces/async_injectable.rs new file mode 100644 index 0000000..badc3c5 --- /dev/null +++ b/src/interfaces/async_injectable.rs @@ -0,0 +1,35 @@ +//! Interface for structs that can be injected into or be injected to. +//! +//! *This module is only available if Syrette is built with the "async" feature.* +use std::fmt::Debug; + +use async_trait::async_trait; + +use crate::async_di_container::AsyncDIContainer; +use crate::errors::injectable::InjectableError; +use crate::libs::intertrait::CastFromSync; +use crate::ptr::TransientPtr; + +/// Interface for structs that can be injected into or be injected to. +#[async_trait] +pub trait AsyncInjectable: CastFromSync +{ + /// Resolves the dependencies of the injectable. + /// + /// # Errors + /// Will return `Err` if resolving the dependencies fails. + async fn resolve( + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<TransientPtr<Self>, InjectableError> + where + Self: Sized; +} + +impl Debug for dyn AsyncInjectable +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + f.write_str("{}") + } +} diff --git a/src/interfaces/mod.rs b/src/interfaces/mod.rs index 73dde04..ddb3bba 100644 --- a/src/interfaces/mod.rs +++ b/src/interfaces/mod.rs @@ -8,3 +8,6 @@ pub mod any_factory; #[cfg(feature = "factory")] pub mod factory; + +#[cfg(feature = "async")] +pub mod async_injectable; @@ -12,6 +12,11 @@ pub mod errors; pub mod interfaces; pub mod ptr; +#[cfg(feature = "async")] +pub mod async_di_container; + +#[cfg(feature = "async")] +pub use async_di_container::AsyncDIContainer; pub use di_container::DIContainer; pub use syrette_macros::*; @@ -75,9 +80,8 @@ macro_rules! di_container_bind { /// /// A default factory is a factory that doesn't take any arguments. /// -/// More tedious ways to accomplish what this macro does would either be by using -/// the [`factory`] macro or by manually declaring the interfaces -/// with the [`declare_interface`] macro. +/// The more tedious way to accomplish what this macro does would be by using +/// the [`factory`] macro. /// /// *This macro is only available if Syrette is built with the "factory" feature.* /// @@ -95,43 +99,19 @@ macro_rules! di_container_bind { /// /// declare_default_factory!(dyn IParser); /// ``` -/// -/// The expanded equivelent of this would be -/// -/// ``` -/// # use syrette::declare_default_factory; -/// # -/// trait IParser { -/// // Methods and etc here... -/// } -/// -/// syrette::declare_interface!( -/// syrette::castable_factory::CastableFactory< -/// (), -/// dyn IParser, -/// > -> syrette::interfaces::factory::IFactory<(), dyn IParser> -/// ); -/// -/// syrette::declare_interface!( -/// syrette::castable_factory::CastableFactory< -/// (), -/// dyn IParser, -/// > -> syrette::interfaces::any_factory::AnyFactory -/// ); -/// ``` #[macro_export] #[cfg(feature = "factory")] macro_rules! declare_default_factory { ($interface: ty) => { syrette::declare_interface!( - syrette::castable_factory::CastableFactory< + syrette::castable_factory::blocking::CastableFactory< (), $interface, > -> syrette::interfaces::factory::IFactory<(), $interface> ); syrette::declare_interface!( - syrette::castable_factory::CastableFactory< + syrette::castable_factory::blocking::CastableFactory< (), $interface, > -> syrette::interfaces::any_factory::AnyFactory diff --git a/src/libs/intertrait/mod.rs b/src/libs/intertrait/mod.rs index 2d62871..bdae4c7 100644 --- a/src/libs/intertrait/mod.rs +++ b/src/libs/intertrait/mod.rs @@ -23,7 +23,7 @@ //! MIT license (LICENSE-MIT or <http://opensource.org/licenses/MIT>) //! //! at your option. -use std::any::{Any, TypeId}; +use std::any::{type_name, Any, TypeId}; use std::rc::Rc; use std::sync::Arc; @@ -60,7 +60,10 @@ static CASTER_MAP: Lazy<AHashMap<(TypeId, TypeId), BoxedCaster>> = Lazy::new(|| fn cast_arc_panic<Trait: ?Sized + 'static>(_: Arc<dyn Any + Sync + Send>) -> Arc<Trait> { - panic!("Prepend [sync] to the list of target traits for Sync + Send types") + panic!( + "Interface trait '{}' has not been marked async", + type_name::<Trait>() + ) } /// A `Caster` knows how to cast a reference to or `Box` of a trait object for `Any` diff --git a/src/libs/mod.rs b/src/libs/mod.rs index 8d5583d..b1c7a74 100644 --- a/src/libs/mod.rs +++ b/src/libs/mod.rs @@ -1,3 +1,5 @@ pub mod intertrait; +#[cfg(feature = "async")] +pub extern crate async_trait; pub extern crate linkme; diff --git a/src/provider/async.rs b/src/provider/async.rs new file mode 100644 index 0000000..93ae03a --- /dev/null +++ b/src/provider/async.rs @@ -0,0 +1,135 @@ +#![allow(clippy::module_name_repetitions)] +use std::marker::PhantomData; + +use async_trait::async_trait; + +use crate::async_di_container::AsyncDIContainer; +use crate::errors::injectable::InjectableError; +use crate::interfaces::async_injectable::AsyncInjectable; +use crate::ptr::{ThreadsafeSingletonPtr, TransientPtr}; + +#[derive(strum_macros::Display, Debug)] +pub enum AsyncProvidable +{ + Transient(TransientPtr<dyn AsyncInjectable>), + Singleton(ThreadsafeSingletonPtr<dyn AsyncInjectable>), + #[cfg(feature = "factory")] + Factory( + crate::ptr::ThreadsafeFactoryPtr< + dyn crate::interfaces::any_factory::AnyThreadsafeFactory, + >, + ), +} + +#[async_trait] +pub trait IAsyncProvider: Send + Sync +{ + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError>; +} + +pub struct AsyncTransientTypeProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + injectable_phantom: PhantomData<InjectableType>, +} + +impl<InjectableType> AsyncTransientTypeProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + pub fn new() -> Self + { + Self { + injectable_phantom: PhantomData, + } + } +} + +#[async_trait] +impl<InjectableType> IAsyncProvider for AsyncTransientTypeProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + async fn provide( + &self, + di_container: &AsyncDIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError> + { + Ok(AsyncProvidable::Transient( + InjectableType::resolve(di_container, dependency_history).await?, + )) + } +} + +pub struct AsyncSingletonProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + singleton: ThreadsafeSingletonPtr<InjectableType>, +} + +impl<InjectableType> AsyncSingletonProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + pub fn new(singleton: ThreadsafeSingletonPtr<InjectableType>) -> Self + { + Self { singleton } + } +} + +#[async_trait] +impl<InjectableType> IAsyncProvider for AsyncSingletonProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + async fn provide( + &self, + _di_container: &AsyncDIContainer, + _dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError> + { + Ok(AsyncProvidable::Singleton(self.singleton.clone())) + } +} + +#[cfg(feature = "factory")] +pub struct AsyncFactoryProvider +{ + factory: crate::ptr::ThreadsafeFactoryPtr< + dyn crate::interfaces::any_factory::AnyThreadsafeFactory, + >, +} + +#[cfg(feature = "factory")] +impl AsyncFactoryProvider +{ + pub fn new( + factory: crate::ptr::ThreadsafeFactoryPtr< + dyn crate::interfaces::any_factory::AnyThreadsafeFactory, + >, + ) -> Self + { + Self { factory } + } +} + +#[cfg(feature = "factory")] +#[async_trait] +impl IAsyncProvider for AsyncFactoryProvider +{ + async fn provide( + &self, + _di_container: &AsyncDIContainer, + _dependency_history: Vec<&'static str>, + ) -> Result<AsyncProvidable, InjectableError> + { + Ok(AsyncProvidable::Factory(self.factory.clone())) + } +} diff --git a/src/provider.rs b/src/provider/blocking.rs index 13674b9..13674b9 100644 --- a/src/provider.rs +++ b/src/provider/blocking.rs diff --git a/src/provider/mod.rs b/src/provider/mod.rs new file mode 100644 index 0000000..7fb96bb --- /dev/null +++ b/src/provider/mod.rs @@ -0,0 +1,4 @@ +pub mod blocking; + +#[cfg(feature = "async")] +pub mod r#async; @@ -2,10 +2,11 @@ //! Smart pointer type aliases. use std::rc::Rc; +use std::sync::Arc; use paste::paste; -use crate::errors::ptr::SomePtrError; +use crate::errors::ptr::{SomePtrError, SomeThreadsafePtrError}; /// A smart pointer for a interface in the transient scope. pub type TransientPtr<Interface> = Box<Interface>; @@ -13,44 +14,34 @@ pub type TransientPtr<Interface> = Box<Interface>; /// A smart pointer to a interface in the singleton scope. pub type SingletonPtr<Interface> = Rc<Interface>; +/// A threadsafe smart pointer to a interface in the singleton scope. +pub type ThreadsafeSingletonPtr<Interface> = Arc<Interface>; + /// A smart pointer to a factory. #[cfg(feature = "factory")] pub type FactoryPtr<FactoryInterface> = Rc<FactoryInterface>; -/// Some smart pointer. -#[derive(strum_macros::IntoStaticStr)] -pub enum SomePtr<Interface> -where - Interface: 'static + ?Sized, -{ - /// A smart pointer to a interface in the transient scope. - Transient(TransientPtr<Interface>), - - /// A smart pointer to a interface in the singleton scope. - Singleton(SingletonPtr<Interface>), - - /// A smart pointer to a factory. - #[cfg(feature = "factory")] - Factory(FactoryPtr<Interface>), -} +/// A threadsafe smart pointer to a factory. +#[cfg(feature = "factory")] +pub type ThreadsafeFactoryPtr<FactoryInterface> = Arc<FactoryInterface>; macro_rules! create_as_variant_fn { - ($variant: ident) => { + ($enum: ident, $variant: ident) => { paste! { #[doc = - "Returns as " [<$variant:lower>] ".\n" + "Returns as the `" [<$variant>] "` variant.\n" "\n" "# Errors\n" - "Will return Err if it's not a " [<$variant:lower>] "." + "Will return Err if it's not the `" [<$variant>] "` variant." ] - pub fn [<$variant:lower>](self) -> Result<[<$variant Ptr>]<Interface>, SomePtrError> + pub fn [<$variant:snake>](self) -> Result<[<$variant Ptr>]<Interface>, [<$enum Error>]> { - if let SomePtr::$variant(ptr) = self { + if let $enum::$variant(ptr) = self { return Ok(ptr); } - Err(SomePtrError::WrongPtrType { + Err([<$enum Error>]::WrongPtrType { expected: stringify!($variant), found: self.into() }) @@ -59,14 +50,60 @@ macro_rules! create_as_variant_fn { }; } +/// Some smart pointer. +#[derive(strum_macros::IntoStaticStr)] +pub enum SomePtr<Interface> +where + Interface: 'static + ?Sized, +{ + /// A smart pointer to a interface in the transient scope. + Transient(TransientPtr<Interface>), + + /// A smart pointer to a interface in the singleton scope. + Singleton(SingletonPtr<Interface>), + + /// A smart pointer to a factory. + #[cfg(feature = "factory")] + Factory(FactoryPtr<Interface>), +} + impl<Interface> SomePtr<Interface> where Interface: 'static + ?Sized, { - create_as_variant_fn!(Transient); + create_as_variant_fn!(SomePtr, Transient); + + create_as_variant_fn!(SomePtr, Singleton); + + #[cfg(feature = "factory")] + create_as_variant_fn!(SomePtr, Factory); +} + +/// Some threadsafe smart pointer. +#[derive(strum_macros::IntoStaticStr)] +pub enum SomeThreadsafePtr<Interface> +where + Interface: 'static + ?Sized, +{ + /// A smart pointer to a interface in the transient scope. + Transient(TransientPtr<Interface>), + + /// A smart pointer to a interface in the singleton scope. + ThreadsafeSingleton(ThreadsafeSingletonPtr<Interface>), + + /// A smart pointer to a factory. + #[cfg(feature = "factory")] + ThreadsafeFactory(ThreadsafeFactoryPtr<Interface>), +} + +impl<Interface> SomeThreadsafePtr<Interface> +where + Interface: 'static + ?Sized, +{ + create_as_variant_fn!(SomeThreadsafePtr, Transient); - create_as_variant_fn!(Singleton); + create_as_variant_fn!(SomeThreadsafePtr, ThreadsafeSingleton); #[cfg(feature = "factory")] - create_as_variant_fn!(Factory); + create_as_variant_fn!(SomeThreadsafePtr, ThreadsafeFactory); } |