diff options
author | HampusM <hampus@hampusmat.com> | 2022-09-17 18:33:43 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-09-17 18:33:43 +0200 |
commit | 7de7f73963a266cceff85d6ab71c3256e5d382ec (patch) | |
tree | 67575870945b7ed0a5eeb99ccba79327598b3e02 | |
parent | 8651f84f205da7a89f2fc7333d1dd8de0d80a22b (diff) |
feat!: allow factories to access async DI container
BREAKING CHANGE: The to_factory & to_default_factory methods of AsyncBindingBuilder now expects a function returning a factory function
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | examples/async/bootstrap.rs | 30 | ||||
-rw-r--r-- | examples/async/interfaces/mod.rs | 1 | ||||
-rw-r--r-- | examples/async/main.rs | 12 | ||||
-rw-r--r-- | macros/src/lib.rs | 32 | ||||
-rw-r--r-- | src/async_di_container.rs | 210 | ||||
-rw-r--r-- | src/provider/async.rs | 20 |
7 files changed, 211 insertions, 96 deletions
@@ -27,7 +27,7 @@ required-features = ["factory"] [[example]] name = "async" -required-features = ["async"] +required-features = ["async", "factory"] [dependencies] syrette_macros = { path = "./macros", version = "0.3.0" } diff --git a/examples/async/bootstrap.rs b/examples/async/bootstrap.rs index 7e1d2cd..51af067 100644 --- a/examples/async/bootstrap.rs +++ b/examples/async/bootstrap.rs @@ -2,17 +2,20 @@ use std::sync::Arc; use anyhow::Result; use syrette::async_di_container::AsyncDIContainer; +use syrette::declare_default_factory; +use syrette::ptr::TransientPtr; -// Concrete implementations use crate::animals::cat::Cat; use crate::animals::dog::Dog; use crate::animals::human::Human; -// -// Interfaces +use crate::food::Food; use crate::interfaces::cat::ICat; use crate::interfaces::dog::IDog; +use crate::interfaces::food::{IFood, IFoodFactory}; use crate::interfaces::human::IHuman; +declare_default_factory!(dyn ICat, threadsafe = true); + pub async fn bootstrap() -> Result<Arc<AsyncDIContainer>> { let mut di_container = AsyncDIContainer::new(); @@ -24,8 +27,27 @@ pub async fn bootstrap() -> Result<Arc<AsyncDIContainer>> .in_singleton_scope() .await?; - di_container.bind::<dyn ICat>().to::<Cat>().await?; + di_container + .bind::<dyn ICat>() + .to_default_factory(&|_| { + let cat: TransientPtr<dyn ICat> = TransientPtr::new(Cat::new()); + + cat + }) + .await?; + di_container.bind::<dyn IHuman>().to::<Human>().await?; + di_container + .bind::<IFoodFactory>() + .to_factory(&|_| { + Box::new(|| { + let food: Box<dyn IFood> = Box::new(Food::new()); + + food + }) + }) + .await?; + Ok(di_container) } diff --git a/examples/async/interfaces/mod.rs b/examples/async/interfaces/mod.rs index 5444978..ea0a26d 100644 --- a/examples/async/interfaces/mod.rs +++ b/examples/async/interfaces/mod.rs @@ -1,3 +1,4 @@ pub mod cat; pub mod dog; +pub mod food; pub mod human; diff --git a/examples/async/main.rs b/examples/async/main.rs index 3c884fe..03e36e1 100644 --- a/examples/async/main.rs +++ b/examples/async/main.rs @@ -7,12 +7,15 @@ use tokio::spawn; mod animals; mod bootstrap; +mod food; mod interfaces; use bootstrap::bootstrap; use interfaces::dog::IDog; use interfaces::human::IHuman; +use crate::interfaces::food::IFoodFactory; + #[tokio::main] async fn main() -> Result<()> { @@ -29,6 +32,15 @@ async fn main() -> Result<()> dog.woof(); } + let food_factory = di_container + .get::<IFoodFactory>() + .await? + .threadsafe_factory()?; + + let food = food_factory(); + + food.eat(); + spawn(async move { let human = di_container.get::<dyn IHuman>().await?.transient()?; diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 4c815db..79b1a1b 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -194,7 +194,7 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke let FactoryMacroArgs { flags } = parse(args_stream).unwrap(); - let is_async = flags + let is_threadsafe = flags .iter() .find(|flag| flag.flag.to_string().as_str() == "threadsafe") .map_or(false, |flag| flag.is_on.value); @@ -202,24 +202,27 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke let factory_type_alias::FactoryTypeAlias { type_alias, factory_interface, - arg_types, - return_type, + arg_types: _, + return_type: _, } = parse(type_alias_stream).unwrap(); - let decl_interfaces = if is_async { + let decl_interfaces = if is_threadsafe { quote! { syrette::declare_interface!( syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< - #arg_types, - #return_type - > -> #factory_interface, + (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), + #factory_interface + > -> syrette::interfaces::factory::IFactory< + (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), + #factory_interface + >, async = true ); syrette::declare_interface!( syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< - #arg_types, - #return_type + (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), + #factory_interface > -> syrette::interfaces::any_factory::AnyThreadsafeFactory, async = true ); @@ -300,17 +303,20 @@ pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream return quote! { syrette::declare_interface!( syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< - (), + (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), #interface, - > -> syrette::interfaces::factory::IFactory<(), #interface>, + > -> syrette::interfaces::factory::IFactory< + (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), + #interface + >, async = true ); syrette::declare_interface!( syrette::castable_factory::threadsafe::ThreadsafeCastableFactory< - (), + (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,), #interface, - > -> syrette::interfaces::any_factory::AnyFactory, + > -> syrette::interfaces::any_factory::AnyThreadsafeFactory, async = true ); } diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 0cd92a5..ef0a540 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -39,7 +39,8 @@ //! //! di_container //! .bind::<dyn IDatabaseService>() -//! .to::<DatabaseService>()?; +//! .to::<DatabaseService>() +//! .await?; //! //! let database_service = di_container //! .get::<dyn IDatabaseService>() @@ -259,16 +260,17 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub async fn to_factory<Args, Return>( + pub async fn to_factory<Args, Return, FactoryFunc>( &self, - factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>> - + Send - + Sync), + factory_func: &'static FactoryFunc, ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError> where Args: 'static, Return: 'static + ?Sized, - Interface: crate::interfaces::factory::IFactory<Args, Return>, + Interface: Fn<Args, Output = Return>, + FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = Box<(dyn Fn<Args, Output = Return>)>> + + Send + + Sync, { let mut bindings_lock = self.di_container.bindings.lock().await; @@ -285,6 +287,7 @@ where None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), + false, )), ); @@ -300,14 +303,15 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub async fn to_default_factory<Return>( + pub async fn to_default_factory<Return, FactoryFunc>( &self, - factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>> - + Send - + Sync), + factory_func: &'static FactoryFunc, ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError> where Return: 'static + ?Sized, + FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = crate::ptr::TransientPtr<Return>> + + Send + + Sync, { let mut bindings_lock = self.di_container.bindings.lock().await; @@ -324,6 +328,7 @@ where None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), + true, )), ); @@ -402,10 +407,11 @@ impl AsyncDIContainer .get_binding_providable::<Interface>(name, dependency_history) .await?; - Self::handle_binding_providable(binding_providable) + self.handle_binding_providable(binding_providable) } fn handle_binding_providable<Interface>( + self: &Arc<Self>, binding_providable: AsyncProvidable, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> where @@ -444,37 +450,49 @@ impl AsyncDIContainer } #[cfg(feature = "factory")] AsyncProvidable::Factory(factory_binding) => { - match factory_binding.clone().cast::<Interface>() { - Ok(factory) => Ok(SomeThreadsafePtr::ThreadsafeFactory(factory)), - Err(first_err) => { - use crate::interfaces::factory::IFactory; - - if let CastError::NotArcCastable(_) = first_err { - return Err(AsyncDIContainerError::InterfaceNotAsync( + let factory = factory_binding + .cast::<dyn crate::interfaces::factory::IFactory< + (Arc<AsyncDIContainer>,), + Interface, + >>() + .map_err(|err| match err { + CastError::NotArcCastable(_) => { + AsyncDIContainerError::InterfaceNotAsync( type_name::<Interface>(), - )); + ) + } + CastError::CastFailed { from: _, to: _ } => { + AsyncDIContainerError::CastFailed { + interface: type_name::<Interface>(), + binding_kind: "factory", + } } + })?; - let default_factory = factory_binding - .cast::<dyn IFactory<(), Interface>>() - .map_err(|err| match err { - CastError::NotArcCastable(_) => { - AsyncDIContainerError::InterfaceNotAsync(type_name::< - Interface, - >( - )) - } - CastError::CastFailed { from: _, to: _ } => { - AsyncDIContainerError::CastFailed { - interface: type_name::<Interface>(), - binding_kind: "factory", - } - } - })?; + Ok(SomeThreadsafePtr::ThreadsafeFactory( + factory(self.clone()).into(), + )) + } + AsyncProvidable::DefaultFactory(default_factory_binding) => { + use crate::interfaces::factory::IFactory; + + let default_factory = default_factory_binding + .cast::<dyn IFactory<(Arc<AsyncDIContainer>,), Interface>>() + .map_err(|err| match err { + CastError::NotArcCastable(_) => { + AsyncDIContainerError::InterfaceNotAsync( + type_name::<Interface>(), + ) + } + CastError::CastFailed { from: _, to: _ } => { + AsyncDIContainerError::CastFailed { + interface: type_name::<Interface>(), + binding_kind: "default factory", + } + } + })?; - Ok(SomeThreadsafePtr::Transient(default_factory())) - } - } + Ok(SomeThreadsafePtr::Transient(default_factory(self.clone()))) } } } @@ -788,11 +806,13 @@ mod tests di_container .bind::<IUserManagerFactory>() - .to_factory(&|| { - let user_manager: TransientPtr<dyn subjects::IUserManager> = - TransientPtr::new(subjects::UserManager::new()); + .to_factory(&|_| { + Box::new(|| { + let user_manager: TransientPtr<dyn subjects::IUserManager> = + TransientPtr::new(subjects::UserManager::new()); - user_manager + user_manager + }) }) .await?; @@ -818,11 +838,13 @@ mod tests di_container .bind::<IUserManagerFactory>() - .to_factory(&|| { - let user_manager: TransientPtr<dyn subjects::IUserManager> = - TransientPtr::new(subjects::UserManager::new()); + .to_factory(&|_| { + Box::new(|| { + let user_manager: TransientPtr<dyn subjects::IUserManager> = + TransientPtr::new(subjects::UserManager::new()); - user_manager + user_manager + }) }) .await? .when_named("awesome") @@ -1111,8 +1133,7 @@ mod tests use crate as syrette; #[crate::factory(threadsafe = true)] - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>; + type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager>; mock! { Provider {} @@ -1130,21 +1151,37 @@ mod tests } } - let mut di_container = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Factory( - crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new( - &|users| { - let user_manager: TransientPtr<dyn IUserManager> = - TransientPtr::new(UserManager::new(users)); + mock_provider.expect_do_clone().returning(|| { + type FactoryFunc = Box< + (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>) + >; + + let mut inner_mock_provider = MockProvider::new(); - user_manager - }, - )), - )) + let factory_func: &'static (dyn Fn< + (Arc<AsyncDIContainer>,), + Output = FactoryFunc> + Send + Sync) = &|_| { + Box::new(|users| { + let user_manager: TransientPtr<dyn IUserManager> = + TransientPtr::new(UserManager::new(users)); + + user_manager + }) + }; + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Factory( + crate::ptr::ThreadsafeFactoryPtr::new( + ThreadsafeCastableFactory::new(factory_func), + ), + )) + }); + + Box::new(inner_mock_provider) }); { @@ -1206,8 +1243,7 @@ mod tests use crate as syrette; #[crate::factory(threadsafe = true)] - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>; + type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager>; mock! { Provider {} @@ -1217,32 +1253,54 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError>; + + fn do_clone(&self) -> Box<dyn IAsyncProvider>; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Factory( - crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new( - &|users| { - let user_manager: TransientPtr<dyn IUserManager> = - TransientPtr::new(UserManager::new(users)); + mock_provider.expect_do_clone().returning(|| { + type FactoryFunc = Box< + (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>) + >; - user_manager - }, - )), - )) + let mut inner_mock_provider = MockProvider::new(); + + let factory_func: &'static (dyn Fn< + (Arc<AsyncDIContainer>,), + Output = FactoryFunc> + Send + Sync) = &|_| { + Box::new(|users| { + let user_manager: TransientPtr<dyn IUserManager> = + TransientPtr::new(UserManager::new(users)); + + user_manager + }) + }; + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Factory( + crate::ptr::ThreadsafeFactoryPtr::new( + ThreadsafeCastableFactory::new(factory_func), + ), + )) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider)); + } di_container .get_named::<IUserManagerFactory>("special") diff --git a/src/provider/async.rs b/src/provider/async.rs index 1ddb614..df96b27 100644 --- a/src/provider/async.rs +++ b/src/provider/async.rs @@ -20,6 +20,12 @@ pub enum AsyncProvidable dyn crate::interfaces::any_factory::AnyThreadsafeFactory, >, ), + #[cfg(feature = "factory")] + DefaultFactory( + crate::ptr::ThreadsafeFactoryPtr< + dyn crate::interfaces::any_factory::AnyThreadsafeFactory, + >, + ), } #[async_trait] @@ -150,6 +156,7 @@ pub struct AsyncFactoryProvider factory: crate::ptr::ThreadsafeFactoryPtr< dyn crate::interfaces::any_factory::AnyThreadsafeFactory, >, + is_default_factory: bool, } #[cfg(feature = "factory")] @@ -159,9 +166,13 @@ impl AsyncFactoryProvider factory: crate::ptr::ThreadsafeFactoryPtr< dyn crate::interfaces::any_factory::AnyThreadsafeFactory, >, + is_default_factory: bool, ) -> Self { - Self { factory } + Self { + factory, + is_default_factory, + } } } @@ -175,7 +186,11 @@ impl IAsyncProvider for AsyncFactoryProvider _dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError> { - Ok(AsyncProvidable::Factory(self.factory.clone())) + Ok(if self.is_default_factory { + AsyncProvidable::DefaultFactory(self.factory.clone()) + } else { + AsyncProvidable::Factory(self.factory.clone()) + }) } fn do_clone(&self) -> Box<dyn IAsyncProvider> @@ -191,6 +206,7 @@ impl Clone for AsyncFactoryProvider { Self { factory: self.factory.clone(), + is_default_factory: self.is_default_factory.clone(), } } } |