diff options
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | examples/async-factory/main.rs | 62 | ||||
-rw-r--r-- | macros/src/factory/build_declare_interfaces.rs | 2 | ||||
-rw-r--r-- | macros/src/factory/declare_default_args.rs | 2 | ||||
-rw-r--r-- | macros/src/lib.rs | 39 | ||||
-rw-r--r-- | src/async_di_container.rs | 139 | ||||
-rw-r--r-- | src/castable_factory/mod.rs | 2 | ||||
-rw-r--r-- | src/castable_factory/threadsafe.rs | 4 | ||||
-rw-r--r-- | src/interfaces/factory.rs | 9 | ||||
-rw-r--r-- | src/lib.rs | 5 | ||||
-rw-r--r-- | src/provider/async.rs | 37 |
11 files changed, 246 insertions, 57 deletions
@@ -49,7 +49,7 @@ tokio = { version = "1.20.1", features = ["sync"], optional = true } mockall = "0.11.1" anyhow = "1.0.62" third-party-lib = { path = "./examples/with-3rd-party/third-party-lib" } -tokio = { version = "1.20.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.20.1", features = ["macros", "rt-multi-thread", "time"] } [workspace] members = [ diff --git a/examples/async-factory/main.rs b/examples/async-factory/main.rs index 74e12c7..715abf5 100644 --- a/examples/async-factory/main.rs +++ b/examples/async-factory/main.rs @@ -2,10 +2,14 @@ #![deny(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] +use std::time::Duration; + use anyhow::Result; -use syrette::{async_closure, factory, AsyncDIContainer}; +use syrette::ptr::TransientPtr; +use syrette::{async_closure, declare_default_factory, factory, AsyncDIContainer}; +use tokio::time::sleep; -trait IFoo +trait IFoo: Send + Sync { fn bar(&self); } @@ -36,6 +40,34 @@ impl IFoo for Foo } } +trait IPerson: Send + Sync +{ + fn name(&self) -> String; +} + +struct Person +{ + name: String, +} + +impl Person +{ + fn new(name: String) -> Self + { + Self { name } + } +} + +impl IPerson for Person +{ + fn name(&self) -> String + { + self.name.clone() + } +} + +declare_default_factory!(dyn IPerson, async = true); + #[tokio::main] async fn main() -> Result<()> { @@ -45,9 +77,23 @@ async fn main() -> Result<()> .bind::<IFooFactory>() .to_async_factory(&|_| { async_closure!(|cnt| { - let foo = Box::new(Foo::new(cnt)); + let foo_ptr = Box::new(Foo::new(cnt)); - foo as Box<dyn IFoo> + foo_ptr as Box<dyn IFoo> + }) + }) + .await?; + + di_container + .bind::<dyn IPerson>() + .to_async_default_factory(&|_| { + async_closure!(|| { + // Do some time demanding thing... + sleep(Duration::from_secs(1)).await; + + let person = TransientPtr::new(Person::new("Bob".to_string())); + + person as TransientPtr<dyn IPerson> }) }) .await?; @@ -57,9 +103,13 @@ async fn main() -> Result<()> .await? .threadsafe_factory()?; - let foo = foo_factory(4).await; + let foo_ptr = foo_factory(4).await; + + foo_ptr.bar(); + + let person = di_container.get::<dyn IPerson>().await?.transient()?; - foo.bar(); + println!("Person name is {}", person.name()); Ok(()) } diff --git a/macros/src/factory/build_declare_interfaces.rs b/macros/src/factory/build_declare_interfaces.rs index ac4ddd6..61e162f 100644 --- a/macros/src/factory/build_declare_interfaces.rs +++ b/macros/src/factory/build_declare_interfaces.rs @@ -14,7 +14,7 @@ pub fn build_declare_factory_interfaces( syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), #factory_interface - > -> syrette::interfaces::factory::IFactory< + > -> syrette::interfaces::factory::IThreadsafeFactory< (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), #factory_interface >, diff --git a/macros/src/factory/declare_default_args.rs b/macros/src/factory/declare_default_args.rs index 6450583..d19eba8 100644 --- a/macros/src/factory/declare_default_args.rs +++ b/macros/src/factory/declare_default_args.rs @@ -5,7 +5,7 @@ use syn::{Token, Type}; use crate::macro_flag::MacroFlag; use crate::util::iterator_ext::IteratorExt; -pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe"]; +pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe", "async"]; pub struct DeclareDefaultFactoryMacroArgs { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 07ee7a5..172a113 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -6,16 +6,20 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse, parse_macro_input, parse_str}; +use syn::{parse, parse_macro_input}; mod declare_interface_args; -mod factory; -mod fn_trait; mod injectable; mod libs; mod macro_flag; mod util; +#[cfg(feature = "factory")] +mod factory; + +#[cfg(feature = "factory")] +mod fn_trait; + use crate::declare_interface_args::DeclareInterfaceArgs; use crate::injectable::implementation::InjectableImpl; use crate::injectable::macro_args::InjectableMacroArgs; @@ -188,7 +192,7 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> TokenStream { use quote::ToTokens; - use syn::Type; + use syn::{parse_str, Type}; use crate::factory::build_declare_interfaces::build_declare_factory_interfaces; use crate::factory::macro_args::FactoryMacroArgs; @@ -266,6 +270,7 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke /// /// # Flags /// - `threadsafe` - Mark as threadsafe. +/// - `async` - Mark as async. Infers the `threadsafe` flag. /// /// # Panics /// If the provided arguments are invalid. @@ -285,20 +290,40 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke #[cfg(feature = "factory")] pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream { + use syn::parse_str; + use crate::factory::build_declare_interfaces::build_declare_factory_interfaces; use crate::factory::declare_default_args::DeclareDefaultFactoryMacroArgs; use crate::fn_trait::FnTrait; let DeclareDefaultFactoryMacroArgs { interface, flags } = parse(args_stream).unwrap(); - let is_threadsafe = flags + let mut is_threadsafe = flags .iter() .find(|flag| flag.flag.to_string().as_str() == "threadsafe") .map_or(false, |flag| flag.is_on.value); + let is_async = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async") + .map_or(false, |flag| flag.is_on.value); + + if is_async { + is_threadsafe = true; + } + let mut factory_interface: FnTrait = parse( - quote! { - dyn Fn() -> syrette::ptr::TransientPtr<#interface> + if is_async { + quote! { + dyn Fn() -> syrette::future::BoxFuture< + 'static, + syrette::ptr::TransientPtr<#interface> + > + } + } else { + quote! { + dyn Fn() -> syrette::ptr::TransientPtr<#interface> + } } .into(), ) diff --git a/src/async_di_container.rs b/src/async_di_container.rs index d90cc0b..894b707 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -69,6 +69,7 @@ use crate::errors::async_di_container::{ AsyncBindingWhenConfiguratorError, AsyncDIContainerError, }; +use crate::future::BoxFuture; use crate::interfaces::async_injectable::AsyncInjectable; use crate::libs::intertrait::cast::error::CastError; use crate::libs::intertrait::cast::{CastArc, CastBox}; @@ -274,6 +275,8 @@ where > + Send + Sync, { + use crate::provider::r#async::AsyncFactoryVariant; + let mut bindings_lock = self.di_container.bindings.lock().await; if bindings_lock.has::<Interface>(None) { @@ -289,7 +292,7 @@ where None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), - false, + AsyncFactoryVariant::Normal, )), ); @@ -313,7 +316,8 @@ where where Args: 'static, Return: 'static + ?Sized, - Interface: Fn<Args, Output = crate::future::BoxFuture<'static, Return>>, + Interface: + Fn<Args, Output = crate::future::BoxFuture<'static, Return>> + Send + Sync, FactoryFunc: Fn< (Arc<AsyncDIContainer>,), Output = Box< @@ -324,6 +328,8 @@ where > + Send + Sync, { + use crate::provider::r#async::AsyncFactoryVariant; + let mut bindings_lock = self.di_container.bindings.lock().await; if bindings_lock.has::<Interface>(None) { @@ -339,7 +345,7 @@ where None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), - false, + AsyncFactoryVariant::Normal, )), ); @@ -369,6 +375,58 @@ where > + Send + Sync, { + use crate::provider::r#async::AsyncFactoryVariant; + + let mut bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::<Interface>(None) { + return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< + Interface, + >( + ))); + } + + let factory_impl = ThreadsafeCastableFactory::new(factory_func); + + bindings_lock.set::<Interface>( + None, + Box::new(crate::provider::r#async::AsyncFactoryProvider::new( + crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), + AsyncFactoryVariant::Default, + )), + ); + + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) + } + + /// Creates a binding of factory type `Interface` to a async factory inside of the + /// associated [`AsyncDIContainer`]. + /// + /// *This function is only available if Syrette is built with the "factory" and + /// "async" features.* + /// + /// # Errors + /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for + /// the interface. + #[cfg(all(feature = "factory", feature = "async"))] + pub async fn to_async_default_factory<Return, FactoryFunc>( + &self, + factory_func: &'static FactoryFunc, + ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError> + where + Return: 'static + ?Sized, + FactoryFunc: Fn< + (Arc<AsyncDIContainer>,), + Output = Box< + (dyn Fn<(), Output = crate::future::BoxFuture<'static, Return>> + + Send + + Sync), + >, + > + Send + + Sync, + { + use crate::provider::r#async::AsyncFactoryVariant; + let mut bindings_lock = self.di_container.bindings.lock().await; if bindings_lock.has::<Interface>(None) { @@ -384,7 +442,7 @@ where None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), - true, + AsyncFactoryVariant::AsyncDefault, )), ); @@ -464,10 +522,10 @@ impl AsyncDIContainer .get_binding_providable::<Interface>(name, dependency_history) .await?; - self.handle_binding_providable(binding_providable) + self.handle_binding_providable(binding_providable).await } - fn handle_binding_providable<Interface>( + async fn handle_binding_providable<Interface>( self: &Arc<Self>, binding_providable: AsyncProvidable, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> @@ -507,11 +565,10 @@ impl AsyncDIContainer } #[cfg(feature = "factory")] AsyncProvidable::Factory(factory_binding) => { + use crate::interfaces::factory::IThreadsafeFactory; + let factory = factory_binding - .cast::<dyn crate::interfaces::factory::IFactory< - (Arc<AsyncDIContainer>,), - Interface, - >>() + .cast::<dyn IThreadsafeFactory<(Arc<AsyncDIContainer>,), Interface>>() .map_err(|err| match err { CastError::NotArcCastable(_) => { AsyncDIContainerError::InterfaceNotAsync( @@ -531,33 +588,59 @@ impl AsyncDIContainer )) } #[cfg(feature = "factory")] - AsyncProvidable::DefaultFactory(default_factory_binding) => { - use crate::interfaces::factory::IFactory; + AsyncProvidable::DefaultFactory(binding) => { + use crate::interfaces::factory::IThreadsafeFactory; - let default_factory = default_factory_binding - .cast::<dyn IFactory< + let default_factory = Self::cast_factory_binding::< + dyn IThreadsafeFactory< (Arc<AsyncDIContainer>,), dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync, - >>() - .map_err(|err| match err { - CastError::NotArcCastable(_) => { - AsyncDIContainerError::InterfaceNotAsync( - type_name::<Interface>(), - ) - } - CastError::CastFailed { from: _, to: _ } => { - AsyncDIContainerError::CastFailed { - interface: type_name::<Interface>(), - binding_kind: "default factory", - } - } - })?; + >, + >(binding, "default factory")?; Ok(SomeThreadsafePtr::Transient(default_factory(self.clone())())) } + #[cfg(feature = "factory")] + AsyncProvidable::AsyncDefaultFactory(binding) => { + use crate::interfaces::factory::IThreadsafeFactory; + + let async_default_factory = Self::cast_factory_binding::< + dyn IThreadsafeFactory< + (Arc<AsyncDIContainer>,), + dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>> + + Send + + Sync, + >, + >( + binding, "async default factory" + )?; + + Ok(SomeThreadsafePtr::Transient( + async_default_factory(self.clone())().await, + )) + } } } + #[cfg(feature = "factory")] + fn cast_factory_binding<Type: 'static + ?Sized>( + factory_binding: Arc<dyn crate::interfaces::any_factory::AnyThreadsafeFactory>, + binding_kind: &'static str, + ) -> Result<Arc<Type>, AsyncDIContainerError> + { + factory_binding.cast::<Type>().map_err(|err| match err { + CastError::NotArcCastable(_) => { + AsyncDIContainerError::InterfaceNotAsync(type_name::<Type>()) + } + CastError::CastFailed { from: _, to: _ } => { + AsyncDIContainerError::CastFailed { + interface: type_name::<Type>(), + binding_kind, + } + } + }) + } + async fn get_binding_providable<Interface>( self: &Arc<Self>, name: Option<&'static str>, diff --git a/src/castable_factory/mod.rs b/src/castable_factory/mod.rs index 530cc82..e81b842 100644 --- a/src/castable_factory/mod.rs +++ b/src/castable_factory/mod.rs @@ -1,2 +1,4 @@ pub mod blocking; + +#[cfg(feature = "async")] pub mod threadsafe; diff --git a/src/castable_factory/threadsafe.rs b/src/castable_factory/threadsafe.rs index 7be055c..b91dceb 100644 --- a/src/castable_factory/threadsafe.rs +++ b/src/castable_factory/threadsafe.rs @@ -1,6 +1,6 @@ #![allow(clippy::module_name_repetitions)] use crate::interfaces::any_factory::{AnyFactory, AnyThreadsafeFactory}; -use crate::interfaces::factory::IFactory; +use crate::interfaces::factory::IThreadsafeFactory; use crate::ptr::TransientPtr; pub struct ThreadsafeCastableFactory<Args, ReturnInterface> @@ -26,7 +26,7 @@ where } } -impl<Args, ReturnInterface> IFactory<Args, ReturnInterface> +impl<Args, ReturnInterface> IThreadsafeFactory<Args, ReturnInterface> for ThreadsafeCastableFactory<Args, ReturnInterface> where Args: 'static, diff --git a/src/interfaces/factory.rs b/src/interfaces/factory.rs index b09db36..de1fca9 100644 --- a/src/interfaces/factory.rs +++ b/src/interfaces/factory.rs @@ -9,3 +9,12 @@ where ReturnInterface: 'static + ?Sized, { } + +/// Interface for a threadsafe factory. +#[cfg(feature = "async")] +pub trait IThreadsafeFactory<Args, ReturnInterface>: + Fn<Args, Output = TransientPtr<ReturnInterface>> + crate::libs::intertrait::CastFromSync +where + ReturnInterface: 'static + ?Sized, +{ +} @@ -125,4 +125,9 @@ macro_rules! async_closure { Box::pin(async move { $($inner)* }) }) }; + (|| { $($inner: stmt);* }) => { + Box::new(|| { + Box::pin(async move { $($inner)* }) + }) + }; } diff --git a/src/provider/async.rs b/src/provider/async.rs index df96b27..c9a5273 100644 --- a/src/provider/async.rs +++ b/src/provider/async.rs @@ -26,6 +26,12 @@ pub enum AsyncProvidable dyn crate::interfaces::any_factory::AnyThreadsafeFactory, >, ), + #[cfg(feature = "factory")] + AsyncDefaultFactory( + crate::ptr::ThreadsafeFactoryPtr< + dyn crate::interfaces::any_factory::AnyThreadsafeFactory, + >, + ), } #[async_trait] @@ -150,13 +156,21 @@ where } } +#[derive(Clone, Copy)] +pub enum AsyncFactoryVariant +{ + Normal, + Default, + AsyncDefault, +} + #[cfg(feature = "factory")] pub struct AsyncFactoryProvider { factory: crate::ptr::ThreadsafeFactoryPtr< dyn crate::interfaces::any_factory::AnyThreadsafeFactory, >, - is_default_factory: bool, + variant: AsyncFactoryVariant, } #[cfg(feature = "factory")] @@ -166,13 +180,10 @@ impl AsyncFactoryProvider factory: crate::ptr::ThreadsafeFactoryPtr< dyn crate::interfaces::any_factory::AnyThreadsafeFactory, >, - is_default_factory: bool, + variant: AsyncFactoryVariant, ) -> Self { - Self { - factory, - is_default_factory, - } + Self { factory, variant } } } @@ -186,10 +197,14 @@ impl IAsyncProvider for AsyncFactoryProvider _dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError> { - Ok(if self.is_default_factory { - AsyncProvidable::DefaultFactory(self.factory.clone()) - } else { - AsyncProvidable::Factory(self.factory.clone()) + Ok(match self.variant { + AsyncFactoryVariant::Normal => AsyncProvidable::Factory(self.factory.clone()), + AsyncFactoryVariant::Default => { + AsyncProvidable::DefaultFactory(self.factory.clone()) + } + AsyncFactoryVariant::AsyncDefault => { + AsyncProvidable::AsyncDefaultFactory(self.factory.clone()) + } }) } @@ -206,7 +221,7 @@ impl Clone for AsyncFactoryProvider { Self { factory: self.factory.clone(), - is_default_factory: self.is_default_factory.clone(), + variant: self.variant, } } } |