From 7de7f73963a266cceff85d6ab71c3256e5d382ec Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 17 Sep 2022 18:33:43 +0200 Subject: feat!: allow factories to access async DI container BREAKING CHANGE: The to_factory & to_default_factory methods of AsyncBindingBuilder now expects a function returning a factory function --- src/async_di_container.rs | 210 +++++++++++++++++++++++++++++----------------- 1 file changed, 134 insertions(+), 76 deletions(-) (limited to 'src/async_di_container.rs') diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 0cd92a5..ef0a540 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -39,7 +39,8 @@ //! //! di_container //! .bind::() -//! .to::()?; +//! .to::() +//! .await?; //! //! let database_service = di_container //! .get::() @@ -259,16 +260,17 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub async fn to_factory( + pub async fn to_factory( &self, - factory_func: &'static (dyn Fn> - + Send - + Sync), + factory_func: &'static FactoryFunc, ) -> Result, AsyncBindingBuilderError> where Args: 'static, Return: 'static + ?Sized, - Interface: crate::interfaces::factory::IFactory, + Interface: Fn, + FactoryFunc: Fn<(Arc,), Output = Box<(dyn Fn)>> + + Send + + Sync, { let mut bindings_lock = self.di_container.bindings.lock().await; @@ -285,6 +287,7 @@ where None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), + false, )), ); @@ -300,14 +303,15 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub async fn to_default_factory( + pub async fn to_default_factory( &self, - factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr> - + Send - + Sync), + factory_func: &'static FactoryFunc, ) -> Result, AsyncBindingBuilderError> where Return: 'static + ?Sized, + FactoryFunc: Fn<(Arc,), Output = crate::ptr::TransientPtr> + + Send + + Sync, { let mut bindings_lock = self.di_container.bindings.lock().await; @@ -324,6 +328,7 @@ where None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), + true, )), ); @@ -402,10 +407,11 @@ impl AsyncDIContainer .get_binding_providable::(name, dependency_history) .await?; - Self::handle_binding_providable(binding_providable) + self.handle_binding_providable(binding_providable) } fn handle_binding_providable( + self: &Arc, binding_providable: AsyncProvidable, ) -> Result, AsyncDIContainerError> where @@ -444,37 +450,49 @@ impl AsyncDIContainer } #[cfg(feature = "factory")] AsyncProvidable::Factory(factory_binding) => { - match factory_binding.clone().cast::() { - Ok(factory) => Ok(SomeThreadsafePtr::ThreadsafeFactory(factory)), - Err(first_err) => { - use crate::interfaces::factory::IFactory; - - if let CastError::NotArcCastable(_) = first_err { - return Err(AsyncDIContainerError::InterfaceNotAsync( + let factory = factory_binding + .cast::,), + Interface, + >>() + .map_err(|err| match err { + CastError::NotArcCastable(_) => { + AsyncDIContainerError::InterfaceNotAsync( type_name::(), - )); + ) + } + CastError::CastFailed { from: _, to: _ } => { + AsyncDIContainerError::CastFailed { + interface: type_name::(), + binding_kind: "factory", + } } + })?; - let default_factory = factory_binding - .cast::>() - .map_err(|err| match err { - CastError::NotArcCastable(_) => { - AsyncDIContainerError::InterfaceNotAsync(type_name::< - Interface, - >( - )) - } - CastError::CastFailed { from: _, to: _ } => { - AsyncDIContainerError::CastFailed { - interface: type_name::(), - binding_kind: "factory", - } - } - })?; + Ok(SomeThreadsafePtr::ThreadsafeFactory( + factory(self.clone()).into(), + )) + } + AsyncProvidable::DefaultFactory(default_factory_binding) => { + use crate::interfaces::factory::IFactory; + + let default_factory = default_factory_binding + .cast::,), Interface>>() + .map_err(|err| match err { + CastError::NotArcCastable(_) => { + AsyncDIContainerError::InterfaceNotAsync( + type_name::(), + ) + } + CastError::CastFailed { from: _, to: _ } => { + AsyncDIContainerError::CastFailed { + interface: type_name::(), + binding_kind: "default factory", + } + } + })?; - Ok(SomeThreadsafePtr::Transient(default_factory())) - } - } + Ok(SomeThreadsafePtr::Transient(default_factory(self.clone()))) } } } @@ -788,11 +806,13 @@ mod tests di_container .bind::() - .to_factory(&|| { - let user_manager: TransientPtr = - TransientPtr::new(subjects::UserManager::new()); + .to_factory(&|_| { + Box::new(|| { + let user_manager: TransientPtr = + TransientPtr::new(subjects::UserManager::new()); - user_manager + user_manager + }) }) .await?; @@ -818,11 +838,13 @@ mod tests di_container .bind::() - .to_factory(&|| { - let user_manager: TransientPtr = - TransientPtr::new(subjects::UserManager::new()); + .to_factory(&|_| { + Box::new(|| { + let user_manager: TransientPtr = + TransientPtr::new(subjects::UserManager::new()); - user_manager + user_manager + }) }) .await? .when_named("awesome") @@ -1111,8 +1133,7 @@ mod tests use crate as syrette; #[crate::factory(threadsafe = true)] - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(Vec,), dyn IUserManager>; + type IUserManagerFactory = dyn Fn(Vec) -> TransientPtr; mock! { Provider {} @@ -1130,21 +1151,37 @@ mod tests } } - let mut di_container = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Factory( - crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new( - &|users| { - let user_manager: TransientPtr = - TransientPtr::new(UserManager::new(users)); + mock_provider.expect_do_clone().returning(|| { + type FactoryFunc = Box< + (dyn Fn<(Vec,), Output = TransientPtr>) + >; + + let mut inner_mock_provider = MockProvider::new(); - user_manager - }, - )), - )) + let factory_func: &'static (dyn Fn< + (Arc,), + Output = FactoryFunc> + Send + Sync) = &|_| { + Box::new(|users| { + let user_manager: TransientPtr = + TransientPtr::new(UserManager::new(users)); + + user_manager + }) + }; + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Factory( + crate::ptr::ThreadsafeFactoryPtr::new( + ThreadsafeCastableFactory::new(factory_func), + ), + )) + }); + + Box::new(inner_mock_provider) }); { @@ -1206,8 +1243,7 @@ mod tests use crate as syrette; #[crate::factory(threadsafe = true)] - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(Vec,), dyn IUserManager>; + type IUserManagerFactory = dyn Fn(Vec) -> TransientPtr; mock! { Provider {} @@ -1217,32 +1253,54 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Factory( - crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new( - &|users| { - let user_manager: TransientPtr = - TransientPtr::new(UserManager::new(users)); + mock_provider.expect_do_clone().returning(|| { + type FactoryFunc = Box< + (dyn Fn<(Vec,), Output = TransientPtr>) + >; - user_manager - }, - )), - )) + let mut inner_mock_provider = MockProvider::new(); + + let factory_func: &'static (dyn Fn< + (Arc,), + Output = FactoryFunc> + Send + Sync) = &|_| { + Box::new(|users| { + let user_manager: TransientPtr = + TransientPtr::new(UserManager::new(users)); + + user_manager + }) + }; + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Factory( + crate::ptr::ThreadsafeFactoryPtr::new( + ThreadsafeCastableFactory::new(factory_func), + ), + )) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::(Some("special"), Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::(Some("special"), Box::new(mock_provider)); + } di_container .get_named::("special") -- cgit v1.2.3-18-g5258