From 8651f84f205da7a89f2fc7333d1dd8de0d80a22b Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 17 Sep 2022 16:12:45 +0200 Subject: refactor!: make async DI container be used inside of a Arc BREAKING CHANGE: The async DI container is to be used inside of a Arc & it also no longer implements Default --- src/async_di_container.rs | 466 +++++++++++++++++++++++++++++----------------- 1 file changed, 290 insertions(+), 176 deletions(-) (limited to 'src/async_di_container.rs') diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 7913c5a..0cd92a5 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -55,6 +55,9 @@ //! *This module is only available if Syrette is built with the "async" feature.* use std::any::type_name; use std::marker::PhantomData; +use std::sync::Arc; + +use tokio::sync::Mutex; #[cfg(feature = "factory")] use crate::castable_factory::threadsafe::ThreadsafeCastableFactory; @@ -77,19 +80,19 @@ use crate::provider::r#async::{ use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr}; /// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. -pub struct AsyncBindingWhenConfigurator<'di_container, Interface> +pub struct AsyncBindingWhenConfigurator where Interface: 'static + ?Sized, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc, interface_phantom: PhantomData, } -impl<'di_container, Interface> AsyncBindingWhenConfigurator<'di_container, Interface> +impl AsyncBindingWhenConfigurator where Interface: 'static + ?Sized, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc) -> Self { Self { di_container, @@ -101,50 +104,45 @@ where /// /// # Errors /// Will return Err if no binding for the interface already exists. - pub fn when_named( - &mut self, + pub async fn when_named( + &self, name: &'static str, ) -> Result<(), AsyncBindingWhenConfiguratorError> { - let binding = self - .di_container - .bindings - .remove::(None) - .map_or_else( - || { - Err(AsyncBindingWhenConfiguratorError::BindingNotFound( - type_name::(), - )) - }, - Ok, - )?; - - self.di_container - .bindings - .set::(Some(name), binding); + let mut bindings_lock = self.di_container.bindings.lock().await; + + let binding = bindings_lock.remove::(None).map_or_else( + || { + Err(AsyncBindingWhenConfiguratorError::BindingNotFound( + type_name::(), + )) + }, + Ok, + )?; + + bindings_lock.set::(Some(name), binding); Ok(()) } } /// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. -pub struct AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +pub struct AsyncBindingScopeConfigurator where Interface: 'static + ?Sized, Implementation: AsyncInjectable, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc, interface_phantom: PhantomData, implementation_phantom: PhantomData, } -impl<'di_container, Interface, Implementation> - AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +impl AsyncBindingScopeConfigurator where Interface: 'static + ?Sized, Implementation: AsyncInjectable, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc) -> Self { Self { di_container, @@ -156,14 +154,16 @@ where /// Configures the binding to be in a transient scope. /// /// This is the default. - pub fn in_transient_scope(&mut self) -> AsyncBindingWhenConfigurator + pub async fn in_transient_scope(&self) -> AsyncBindingWhenConfigurator { - self.di_container.bindings.set::( + let mut bindings_lock = self.di_container.bindings.lock().await; + + bindings_lock.set::( None, Box::new(AsyncTransientTypeProvider::::new()), ); - AsyncBindingWhenConfigurator::new(self.di_container) + AsyncBindingWhenConfigurator::new(self.di_container.clone()) } /// Configures the binding to be in a singleton scope. @@ -171,40 +171,41 @@ where /// # Errors /// Will return Err if resolving the implementation fails. pub async fn in_singleton_scope( - &mut self, + &self, ) -> Result, AsyncBindingScopeConfiguratorError> { let singleton: ThreadsafeSingletonPtr = ThreadsafeSingletonPtr::from( - Implementation::resolve(self.di_container, Vec::new()) + Implementation::resolve(&self.di_container, Vec::new()) .await .map_err( AsyncBindingScopeConfiguratorError::SingletonResolveFailed, )?, ); - self.di_container - .bindings + let mut bindings_lock = self.di_container.bindings.lock().await; + + bindings_lock .set::(None, Box::new(AsyncSingletonProvider::new(singleton))); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } } /// Binding builder for type `Interface` inside a [`AsyncDIContainer`]. -pub struct AsyncBindingBuilder<'di_container, Interface> +pub struct AsyncBindingBuilder where Interface: 'static + ?Sized, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc, interface_phantom: PhantomData, } -impl<'di_container, Interface> AsyncBindingBuilder<'di_container, Interface> +impl AsyncBindingBuilder where Interface: 'static + ?Sized, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc) -> Self { Self { di_container, @@ -221,8 +222,8 @@ where /// # Errors /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. - pub fn to( - &mut self, + pub async fn to( + &self, ) -> Result< AsyncBindingScopeConfigurator, AsyncBindingBuilderError, @@ -230,17 +231,21 @@ where where Implementation: AsyncInjectable, { - if self.di_container.bindings.has::(None) { - return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< - Interface, - >( - ))); + { + let bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::(None) { + return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< + Interface, + >( + ))); + } } - let mut binding_scope_configurator = - AsyncBindingScopeConfigurator::new(self.di_container); + let binding_scope_configurator = + AsyncBindingScopeConfigurator::new(self.di_container.clone()); - binding_scope_configurator.in_transient_scope(); + binding_scope_configurator.in_transient_scope().await; Ok(binding_scope_configurator) } @@ -254,8 +259,8 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub fn to_factory( - &mut self, + pub async fn to_factory( + &self, factory_func: &'static (dyn Fn> + Send + Sync), @@ -265,7 +270,9 @@ where Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory, { - if self.di_container.bindings.has::(None) { + let mut bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::(None) { return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< Interface, >( @@ -274,14 +281,14 @@ where let factory_impl = ThreadsafeCastableFactory::new(factory_func); - self.di_container.bindings.set::( + bindings_lock.set::( None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), )), ); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } /// Creates a binding of type `Interface` to a factory that takes no arguments @@ -293,8 +300,8 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub fn to_default_factory( - &mut self, + pub async fn to_default_factory( + &self, factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr> + Send + Sync), @@ -302,7 +309,9 @@ where where Return: 'static + ?Sized, { - if self.di_container.bindings.has::(None) { + let mut bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::(None) { return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< Interface, >( @@ -311,40 +320,40 @@ where let factory_impl = ThreadsafeCastableFactory::new(factory_func); - self.di_container.bindings.set::( + bindings_lock.set::( None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), )), ); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } } /// Dependency injection container. pub struct AsyncDIContainer { - bindings: DIContainerBindingMap, + bindings: Mutex>, } impl AsyncDIContainer { /// Returns a new `AsyncDIContainer`. #[must_use] - pub fn new() -> Self + pub fn new() -> Arc { - Self { - bindings: DIContainerBindingMap::new(), - } + Arc::new(Self { + bindings: Mutex::new(DIContainerBindingMap::new()), + }) } /// Returns a new [`AsyncBindingBuilder`] for the given interface. - pub fn bind(&mut self) -> AsyncBindingBuilder + pub fn bind(self: &mut Arc) -> AsyncBindingBuilder where Interface: 'static + ?Sized, { - AsyncBindingBuilder::::new(self) + AsyncBindingBuilder::::new(self.clone()) } /// Returns the type bound with `Interface`. @@ -355,7 +364,7 @@ impl AsyncDIContainer /// - Resolving the binding for fails /// - Casting the binding for fails pub async fn get( - &self, + self: &Arc, ) -> Result, AsyncDIContainerError> where Interface: 'static + ?Sized, @@ -371,7 +380,7 @@ impl AsyncDIContainer /// - Resolving the binding fails /// - Casting the binding for fails pub async fn get_named( - &self, + self: &Arc, name: &'static str, ) -> Result, AsyncDIContainerError> where @@ -382,7 +391,7 @@ impl AsyncDIContainer #[doc(hidden)] pub async fn get_bound( - &self, + self: &Arc, dependency_history: Vec<&'static str>, name: Option<&'static str>, ) -> Result, AsyncDIContainerError> @@ -471,24 +480,33 @@ impl AsyncDIContainer } async fn get_binding_providable( - &self, + self: &Arc, name: Option<&'static str>, dependency_history: Vec<&'static str>, ) -> Result where Interface: 'static + ?Sized, { - self.bindings - .get::(name) - .map_or_else( - || { - Err(AsyncDIContainerError::BindingNotFound { - interface: type_name::(), - name, - }) - }, - Ok, - )? + let provider; + + { + let bindings_lock = self.bindings.lock().await; + + provider = bindings_lock + .get::(name) + .map_or_else( + || { + Err(AsyncDIContainerError::BindingNotFound { + interface: type_name::(), + name, + }) + }, + Ok, + )? + .clone(); + } + + provider .provide(self, dependency_history) .await .map_err(|err| AsyncDIContainerError::BindingResolveFailed { @@ -498,14 +516,6 @@ impl AsyncDIContainer } } -impl Default for AsyncDIContainer -{ - fn default() -> Self - { - Self::new() - } -} - #[cfg(test)] mod tests { @@ -523,6 +533,7 @@ mod tests //! Test subjects. use std::fmt::Debug; + use std::sync::Arc; use async_trait::async_trait; use syrette_macros::declare_interface; @@ -569,7 +580,7 @@ mod tests impl AsyncInjectable for UserManager { async fn resolve( - _: &AsyncDIContainer, + _: &Arc, _dependency_history: Vec<&'static str>, ) -> Result, crate::errors::injectable::InjectableError> where @@ -634,7 +645,7 @@ mod tests impl AsyncInjectable for Number { async fn resolve( - _: &AsyncDIContainer, + _: &Arc, _dependency_history: Vec<&'static str>, ) -> Result, crate::errors::injectable::InjectableError> where @@ -645,53 +656,71 @@ mod tests } } - #[test] - fn can_bind_to() -> Result<(), Box> + #[tokio::test] + async fn can_bind_to() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()?; + .to::() + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] - fn can_bind_to_transient() -> Result<(), Box> + #[tokio::test] + async fn can_bind_to_transient() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? - .in_transient_scope(); + .to::() + .await? + .in_transient_scope() + .await; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] - fn can_bind_to_transient_when_named() -> Result<(), Box> + #[tokio::test] + async fn can_bind_to_transient_when_named() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? + .to::() + .await? .in_transient_scope() - .when_named("regular")?; + .await + .when_named("regular") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -699,17 +728,22 @@ mod tests #[tokio::test] async fn can_bind_to_singleton() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? + .to::() + .await? .in_singleton_scope() .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -717,55 +751,70 @@ mod tests #[tokio::test] async fn can_bind_to_singleton_when_named() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? + .to::() + .await? .in_singleton_scope() .await? - .when_named("cool")?; + .when_named("cool") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] + #[tokio::test] #[cfg(feature = "factory")] - fn can_bind_to_factory() -> Result<(), Box> + async fn can_bind_to_factory() -> Result<(), Box> { type IUserManagerFactory = dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } - di_container.bind::().to_factory(&|| { - let user_manager: TransientPtr = - TransientPtr::new(subjects::UserManager::new()); + di_container + .bind::() + .to_factory(&|| { + let user_manager: TransientPtr = + TransientPtr::new(subjects::UserManager::new()); - user_manager - })?; + user_manager + }) + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] + #[tokio::test] #[cfg(feature = "factory")] - fn can_bind_to_factory_when_named() -> Result<(), Box> + async fn can_bind_to_factory_when_named() -> Result<(), Box> { type IUserManagerFactory = dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() @@ -774,10 +823,14 @@ mod tests TransientPtr::new(subjects::UserManager::new()); user_manager - })? - .when_named("awesome")?; + }) + .await? + .when_named("awesome") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -793,25 +846,37 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Transient(TransientPtr::new( - subjects::UserManager::new(), - ))) + mock_provider.expect_do_clone().returning(|| { + let mut inner_mock_provider = MockProvider::new(); + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::(None, Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::(None, Box::new(mock_provider)); + } di_container .get::() @@ -832,25 +897,40 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Transient(TransientPtr::new( - subjects::UserManager::new(), - ))) + mock_provider.expect_do_clone().returning(|| { + let mut inner_mock_provider = MockProvider::new(); + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::(Some("special"), Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::( + Some("special"), + Box::new(mock_provider), + ); + } di_container .get_named::("special") @@ -871,13 +951,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -885,13 +967,25 @@ mod tests ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; - mock_provider - .expect_provide() - .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + mock_provider.expect_do_clone().returning(move || { + let mut inner_mock_provider = MockProvider::new(); - di_container - .bindings - .set::(None, Box::new(mock_provider)); + let singleton_clone = singleton.clone(); + + inner_mock_provider.expect_provide().returning(move |_, _| { + Ok(AsyncProvidable::Singleton(singleton_clone.clone())) + }); + + Box::new(inner_mock_provider) + }); + + { + di_container + .bindings + .lock() + .await + .set::(None, Box::new(mock_provider)); + } let first_number_rc = di_container .get::() @@ -921,13 +1015,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -935,13 +1031,25 @@ mod tests ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; - mock_provider - .expect_provide() - .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + mock_provider.expect_do_clone().returning(move || { + let mut inner_mock_provider = MockProvider::new(); - di_container - .bindings - .set::(Some("cool"), Box::new(mock_provider)); + let singleton_clone = singleton.clone(); + + inner_mock_provider.expect_provide().returning(move |_, _| { + Ok(AsyncProvidable::Singleton(singleton_clone.clone())) + }); + + Box::new(inner_mock_provider) + }); + + { + di_container + .bindings + .lock() + .await + .set::(Some("cool"), Box::new(mock_provider)); + } let first_number_rc = di_container .get_named::("cool") @@ -1014,13 +1122,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -1037,9 +1147,13 @@ mod tests )) }); - di_container - .bindings - .set::(None, Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::(None, Box::new(mock_provider)); + } di_container .get::() -- cgit v1.2.3-18-g5258