diff options
author | HampusM <hampus@hampusmat.com> | 2022-09-17 16:12:45 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-09-17 16:12:45 +0200 |
commit | 8651f84f205da7a89f2fc7333d1dd8de0d80a22b (patch) | |
tree | a178427abb442e897d21f654db71cc8135236920 /src | |
parent | c1e682c25c24be3174d44ceb95b0537c38299d0c (diff) |
refactor!: make async DI container be used inside of a Arc
BREAKING CHANGE: The async DI container is to be used inside of a Arc & it also no longer implements Default
Diffstat (limited to 'src')
-rw-r--r-- | src/async_di_container.rs | 466 | ||||
-rw-r--r-- | src/di_container_binding_map.rs | 13 | ||||
-rw-r--r-- | src/interfaces/async_injectable.rs | 3 | ||||
-rw-r--r-- | src/provider/async.rs | 69 |
4 files changed, 363 insertions, 188 deletions
diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 7913c5a..0cd92a5 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -55,6 +55,9 @@ //! *This module is only available if Syrette is built with the "async" feature.* use std::any::type_name; use std::marker::PhantomData; +use std::sync::Arc; + +use tokio::sync::Mutex; #[cfg(feature = "factory")] use crate::castable_factory::threadsafe::ThreadsafeCastableFactory; @@ -77,19 +80,19 @@ use crate::provider::r#async::{ use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr}; /// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. -pub struct AsyncBindingWhenConfigurator<'di_container, Interface> +pub struct AsyncBindingWhenConfigurator<Interface> where Interface: 'static + ?Sized, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc<AsyncDIContainer>, interface_phantom: PhantomData<Interface>, } -impl<'di_container, Interface> AsyncBindingWhenConfigurator<'di_container, Interface> +impl<Interface> AsyncBindingWhenConfigurator<Interface> where Interface: 'static + ?Sized, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc<AsyncDIContainer>) -> Self { Self { di_container, @@ -101,50 +104,45 @@ where /// /// # Errors /// Will return Err if no binding for the interface already exists. - pub fn when_named( - &mut self, + pub async fn when_named( + &self, name: &'static str, ) -> Result<(), AsyncBindingWhenConfiguratorError> { - let binding = self - .di_container - .bindings - .remove::<Interface>(None) - .map_or_else( - || { - Err(AsyncBindingWhenConfiguratorError::BindingNotFound( - type_name::<Interface>(), - )) - }, - Ok, - )?; - - self.di_container - .bindings - .set::<Interface>(Some(name), binding); + let mut bindings_lock = self.di_container.bindings.lock().await; + + let binding = bindings_lock.remove::<Interface>(None).map_or_else( + || { + Err(AsyncBindingWhenConfiguratorError::BindingNotFound( + type_name::<Interface>(), + )) + }, + Ok, + )?; + + bindings_lock.set::<Interface>(Some(name), binding); Ok(()) } } /// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. -pub struct AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +pub struct AsyncBindingScopeConfigurator<Interface, Implementation> where Interface: 'static + ?Sized, Implementation: AsyncInjectable, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc<AsyncDIContainer>, interface_phantom: PhantomData<Interface>, implementation_phantom: PhantomData<Implementation>, } -impl<'di_container, Interface, Implementation> - AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +impl<Interface, Implementation> AsyncBindingScopeConfigurator<Interface, Implementation> where Interface: 'static + ?Sized, Implementation: AsyncInjectable, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc<AsyncDIContainer>) -> Self { Self { di_container, @@ -156,14 +154,16 @@ where /// Configures the binding to be in a transient scope. /// /// This is the default. - pub fn in_transient_scope(&mut self) -> AsyncBindingWhenConfigurator<Interface> + pub async fn in_transient_scope(&self) -> AsyncBindingWhenConfigurator<Interface> { - self.di_container.bindings.set::<Interface>( + let mut bindings_lock = self.di_container.bindings.lock().await; + + bindings_lock.set::<Interface>( None, Box::new(AsyncTransientTypeProvider::<Implementation>::new()), ); - AsyncBindingWhenConfigurator::new(self.di_container) + AsyncBindingWhenConfigurator::new(self.di_container.clone()) } /// Configures the binding to be in a singleton scope. @@ -171,40 +171,41 @@ where /// # Errors /// Will return Err if resolving the implementation fails. pub async fn in_singleton_scope( - &mut self, + &self, ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingScopeConfiguratorError> { let singleton: ThreadsafeSingletonPtr<Implementation> = ThreadsafeSingletonPtr::from( - Implementation::resolve(self.di_container, Vec::new()) + Implementation::resolve(&self.di_container, Vec::new()) .await .map_err( AsyncBindingScopeConfiguratorError::SingletonResolveFailed, )?, ); - self.di_container - .bindings + let mut bindings_lock = self.di_container.bindings.lock().await; + + bindings_lock .set::<Interface>(None, Box::new(AsyncSingletonProvider::new(singleton))); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } } /// Binding builder for type `Interface` inside a [`AsyncDIContainer`]. -pub struct AsyncBindingBuilder<'di_container, Interface> +pub struct AsyncBindingBuilder<Interface> where Interface: 'static + ?Sized, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc<AsyncDIContainer>, interface_phantom: PhantomData<Interface>, } -impl<'di_container, Interface> AsyncBindingBuilder<'di_container, Interface> +impl<Interface> AsyncBindingBuilder<Interface> where Interface: 'static + ?Sized, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc<AsyncDIContainer>) -> Self { Self { di_container, @@ -221,8 +222,8 @@ where /// # Errors /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. - pub fn to<Implementation>( - &mut self, + pub async fn to<Implementation>( + &self, ) -> Result< AsyncBindingScopeConfigurator<Interface, Implementation>, AsyncBindingBuilderError, @@ -230,17 +231,21 @@ where where Implementation: AsyncInjectable, { - if self.di_container.bindings.has::<Interface>(None) { - return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< - Interface, - >( - ))); + { + let bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::<Interface>(None) { + return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< + Interface, + >( + ))); + } } - let mut binding_scope_configurator = - AsyncBindingScopeConfigurator::new(self.di_container); + let binding_scope_configurator = + AsyncBindingScopeConfigurator::new(self.di_container.clone()); - binding_scope_configurator.in_transient_scope(); + binding_scope_configurator.in_transient_scope().await; Ok(binding_scope_configurator) } @@ -254,8 +259,8 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub fn to_factory<Args, Return>( - &mut self, + pub async fn to_factory<Args, Return>( + &self, factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>> + Send + Sync), @@ -265,7 +270,9 @@ where Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory<Args, Return>, { - if self.di_container.bindings.has::<Interface>(None) { + let mut bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::<Interface>(None) { return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< Interface, >( @@ -274,14 +281,14 @@ where let factory_impl = ThreadsafeCastableFactory::new(factory_func); - self.di_container.bindings.set::<Interface>( + bindings_lock.set::<Interface>( None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), )), ); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } /// Creates a binding of type `Interface` to a factory that takes no arguments @@ -293,8 +300,8 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub fn to_default_factory<Return>( - &mut self, + pub async fn to_default_factory<Return>( + &self, factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>> + Send + Sync), @@ -302,7 +309,9 @@ where where Return: 'static + ?Sized, { - if self.di_container.bindings.has::<Interface>(None) { + let mut bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::<Interface>(None) { return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< Interface, >( @@ -311,40 +320,40 @@ where let factory_impl = ThreadsafeCastableFactory::new(factory_func); - self.di_container.bindings.set::<Interface>( + bindings_lock.set::<Interface>( None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), )), ); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } } /// Dependency injection container. pub struct AsyncDIContainer { - bindings: DIContainerBindingMap<dyn IAsyncProvider>, + bindings: Mutex<DIContainerBindingMap<dyn IAsyncProvider>>, } impl AsyncDIContainer { /// Returns a new `AsyncDIContainer`. #[must_use] - pub fn new() -> Self + pub fn new() -> Arc<Self> { - Self { - bindings: DIContainerBindingMap::new(), - } + Arc::new(Self { + bindings: Mutex::new(DIContainerBindingMap::new()), + }) } /// Returns a new [`AsyncBindingBuilder`] for the given interface. - pub fn bind<Interface>(&mut self) -> AsyncBindingBuilder<Interface> + pub fn bind<Interface>(self: &mut Arc<Self>) -> AsyncBindingBuilder<Interface> where Interface: 'static + ?Sized, { - AsyncBindingBuilder::<Interface>::new(self) + AsyncBindingBuilder::<Interface>::new(self.clone()) } /// Returns the type bound with `Interface`. @@ -355,7 +364,7 @@ impl AsyncDIContainer /// - Resolving the binding for fails /// - Casting the binding for fails pub async fn get<Interface>( - &self, + self: &Arc<Self>, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> where Interface: 'static + ?Sized, @@ -371,7 +380,7 @@ impl AsyncDIContainer /// - Resolving the binding fails /// - Casting the binding for fails pub async fn get_named<Interface>( - &self, + self: &Arc<Self>, name: &'static str, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> where @@ -382,7 +391,7 @@ impl AsyncDIContainer #[doc(hidden)] pub async fn get_bound<Interface>( - &self, + self: &Arc<Self>, dependency_history: Vec<&'static str>, name: Option<&'static str>, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> @@ -471,24 +480,33 @@ impl AsyncDIContainer } async fn get_binding_providable<Interface>( - &self, + self: &Arc<Self>, name: Option<&'static str>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, AsyncDIContainerError> where Interface: 'static + ?Sized, { - self.bindings - .get::<Interface>(name) - .map_or_else( - || { - Err(AsyncDIContainerError::BindingNotFound { - interface: type_name::<Interface>(), - name, - }) - }, - Ok, - )? + let provider; + + { + let bindings_lock = self.bindings.lock().await; + + provider = bindings_lock + .get::<Interface>(name) + .map_or_else( + || { + Err(AsyncDIContainerError::BindingNotFound { + interface: type_name::<Interface>(), + name, + }) + }, + Ok, + )? + .clone(); + } + + provider .provide(self, dependency_history) .await .map_err(|err| AsyncDIContainerError::BindingResolveFailed { @@ -498,14 +516,6 @@ impl AsyncDIContainer } } -impl Default for AsyncDIContainer -{ - fn default() -> Self - { - Self::new() - } -} - #[cfg(test)] mod tests { @@ -523,6 +533,7 @@ mod tests //! Test subjects. use std::fmt::Debug; + use std::sync::Arc; use async_trait::async_trait; use syrette_macros::declare_interface; @@ -569,7 +580,7 @@ mod tests impl AsyncInjectable for UserManager { async fn resolve( - _: &AsyncDIContainer, + _: &Arc<AsyncDIContainer>, _dependency_history: Vec<&'static str>, ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError> where @@ -634,7 +645,7 @@ mod tests impl AsyncInjectable for Number { async fn resolve( - _: &AsyncDIContainer, + _: &Arc<AsyncDIContainer>, _dependency_history: Vec<&'static str>, ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError> where @@ -645,53 +656,71 @@ mod tests } } - #[test] - fn can_bind_to() -> Result<(), Box<dyn Error>> + #[tokio::test] + async fn can_bind_to() -> Result<(), Box<dyn Error>> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::<dyn subjects::IUserManager>() - .to::<subjects::UserManager>()?; + .to::<subjects::UserManager>() + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] - fn can_bind_to_transient() -> Result<(), Box<dyn Error>> + #[tokio::test] + async fn can_bind_to_transient() -> Result<(), Box<dyn Error>> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::<dyn subjects::IUserManager>() - .to::<subjects::UserManager>()? - .in_transient_scope(); + .to::<subjects::UserManager>() + .await? + .in_transient_scope() + .await; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] - fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>> + #[tokio::test] + async fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::<dyn subjects::IUserManager>() - .to::<subjects::UserManager>()? + .to::<subjects::UserManager>() + .await? .in_transient_scope() - .when_named("regular")?; + .await + .when_named("regular") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -699,17 +728,22 @@ mod tests #[tokio::test] async fn can_bind_to_singleton() -> Result<(), Box<dyn Error>> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::<dyn subjects::IUserManager>() - .to::<subjects::UserManager>()? + .to::<subjects::UserManager>() + .await? .in_singleton_scope() .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -717,55 +751,70 @@ mod tests #[tokio::test] async fn can_bind_to_singleton_when_named() -> Result<(), Box<dyn Error>> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::<dyn subjects::IUserManager>() - .to::<subjects::UserManager>()? + .to::<subjects::UserManager>() + .await? .in_singleton_scope() .await? - .when_named("cool")?; + .when_named("cool") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] + #[tokio::test] #[cfg(feature = "factory")] - fn can_bind_to_factory() -> Result<(), Box<dyn Error>> + async fn can_bind_to_factory() -> Result<(), Box<dyn Error>> { type IUserManagerFactory = dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } - di_container.bind::<IUserManagerFactory>().to_factory(&|| { - let user_manager: TransientPtr<dyn subjects::IUserManager> = - TransientPtr::new(subjects::UserManager::new()); + di_container + .bind::<IUserManagerFactory>() + .to_factory(&|| { + let user_manager: TransientPtr<dyn subjects::IUserManager> = + TransientPtr::new(subjects::UserManager::new()); - user_manager - })?; + user_manager + }) + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] + #[tokio::test] #[cfg(feature = "factory")] - fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>> + async fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>> { type IUserManagerFactory = dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::<IUserManagerFactory>() @@ -774,10 +823,14 @@ mod tests TransientPtr::new(subjects::UserManager::new()); user_manager - })? - .when_named("awesome")?; + }) + .await? + .when_named("awesome") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -793,25 +846,37 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError>; + + fn do_clone(&self) -> Box<dyn IAsyncProvider>; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Transient(TransientPtr::new( - subjects::UserManager::new(), - ))) + mock_provider.expect_do_clone().returning(|| { + let mut inner_mock_provider = MockProvider::new(); + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider)); + } di_container .get::<dyn subjects::IUserManager>() @@ -832,25 +897,40 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError>; + + fn do_clone(&self) -> Box<dyn IAsyncProvider>; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Transient(TransientPtr::new( - subjects::UserManager::new(), - ))) + mock_provider.expect_do_clone().returning(|| { + let mut inner_mock_provider = MockProvider::new(); + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::<dyn subjects::IUserManager>(Some("special"), Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::<dyn subjects::IUserManager>( + Some("special"), + Box::new(mock_provider), + ); + } di_container .get_named::<dyn subjects::IUserManager>("special") @@ -871,13 +951,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError>; + + fn do_clone(&self) -> Box<dyn IAsyncProvider>; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -885,13 +967,25 @@ mod tests ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; - mock_provider - .expect_provide() - .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + mock_provider.expect_do_clone().returning(move || { + let mut inner_mock_provider = MockProvider::new(); - di_container - .bindings - .set::<dyn subjects::INumber>(None, Box::new(mock_provider)); + let singleton_clone = singleton.clone(); + + inner_mock_provider.expect_provide().returning(move |_, _| { + Ok(AsyncProvidable::Singleton(singleton_clone.clone())) + }); + + Box::new(inner_mock_provider) + }); + + { + di_container + .bindings + .lock() + .await + .set::<dyn subjects::INumber>(None, Box::new(mock_provider)); + } let first_number_rc = di_container .get::<dyn subjects::INumber>() @@ -921,13 +1015,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError>; + + fn do_clone(&self) -> Box<dyn IAsyncProvider>; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -935,13 +1031,25 @@ mod tests ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; - mock_provider - .expect_provide() - .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + mock_provider.expect_do_clone().returning(move || { + let mut inner_mock_provider = MockProvider::new(); - di_container - .bindings - .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider)); + let singleton_clone = singleton.clone(); + + inner_mock_provider.expect_provide().returning(move |_, _| { + Ok(AsyncProvidable::Singleton(singleton_clone.clone())) + }); + + Box::new(inner_mock_provider) + }); + + { + di_container + .bindings + .lock() + .await + .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider)); + } let first_number_rc = di_container .get_named::<dyn subjects::INumber>("cool") @@ -1014,13 +1122,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError>; + + fn do_clone(&self) -> Box<dyn IAsyncProvider>; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -1037,9 +1147,13 @@ mod tests )) }); - di_container - .bindings - .set::<IUserManagerFactory>(None, Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::<IUserManagerFactory>(None, Box::new(mock_provider)); + } di_container .get::<IUserManagerFactory>() diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs index 4aa246e..eb71ff7 100644 --- a/src/di_container_binding_map.rs +++ b/src/di_container_binding_map.rs @@ -27,18 +27,17 @@ where } } - pub fn get<Interface>(&self, name: Option<&'static str>) -> Option<&Provider> + #[allow(clippy::borrowed_box)] + pub fn get<Interface>(&self, name: Option<&'static str>) -> Option<&Box<Provider>> where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::<Interface>(); - self.bindings - .get(&DIContainerBindingKey { - type_id: interface_typeid, - name, - }) - .map(|provider| provider.as_ref()) + self.bindings.get(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } pub fn set<Interface>(&mut self, name: Option<&'static str>, provider: Box<Provider>) diff --git a/src/interfaces/async_injectable.rs b/src/interfaces/async_injectable.rs index badc3c5..fb5452b 100644 --- a/src/interfaces/async_injectable.rs +++ b/src/interfaces/async_injectable.rs @@ -2,6 +2,7 @@ //! //! *This module is only available if Syrette is built with the "async" feature.* use std::fmt::Debug; +use std::sync::Arc; use async_trait::async_trait; @@ -19,7 +20,7 @@ pub trait AsyncInjectable: CastFromSync /// # Errors /// Will return `Err` if resolving the dependencies fails. async fn resolve( - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<TransientPtr<Self>, InjectableError> where diff --git a/src/provider/async.rs b/src/provider/async.rs index 93ae03a..1ddb614 100644 --- a/src/provider/async.rs +++ b/src/provider/async.rs @@ -1,5 +1,6 @@ #![allow(clippy::module_name_repetitions)] use std::marker::PhantomData; +use std::sync::Arc; use async_trait::async_trait; @@ -26,9 +27,19 @@ pub trait IAsyncProvider: Send + Sync { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError>; + + fn do_clone(&self) -> Box<dyn IAsyncProvider>; +} + +impl Clone for Box<dyn IAsyncProvider> +{ + fn clone(&self) -> Self + { + self.do_clone() + } } pub struct AsyncTransientTypeProvider<InjectableType> @@ -57,7 +68,7 @@ where { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc<AsyncDIContainer>, dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError> { @@ -65,6 +76,23 @@ where InjectableType::resolve(di_container, dependency_history).await?, )) } + + fn do_clone(&self) -> Box<dyn IAsyncProvider> + { + Box::new(self.clone()) + } +} + +impl<InjectableType> Clone for AsyncTransientTypeProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + fn clone(&self) -> Self + { + Self { + injectable_phantom: self.injectable_phantom, + } + } } pub struct AsyncSingletonProvider<InjectableType> @@ -91,12 +119,29 @@ where { async fn provide( &self, - _di_container: &AsyncDIContainer, + _di_container: &Arc<AsyncDIContainer>, _dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError> { Ok(AsyncProvidable::Singleton(self.singleton.clone())) } + + fn do_clone(&self) -> Box<dyn IAsyncProvider> + { + Box::new(self.clone()) + } +} + +impl<InjectableType> Clone for AsyncSingletonProvider<InjectableType> +where + InjectableType: AsyncInjectable, +{ + fn clone(&self) -> Self + { + Self { + singleton: self.singleton.clone(), + } + } } #[cfg(feature = "factory")] @@ -126,10 +171,26 @@ impl IAsyncProvider for AsyncFactoryProvider { async fn provide( &self, - _di_container: &AsyncDIContainer, + _di_container: &Arc<AsyncDIContainer>, _dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, InjectableError> { Ok(AsyncProvidable::Factory(self.factory.clone())) } + + fn do_clone(&self) -> Box<dyn IAsyncProvider> + { + Box::new(self.clone()) + } +} + +#[cfg(feature = "factory")] +impl Clone for AsyncFactoryProvider +{ + fn clone(&self) -> Self + { + Self { + factory: self.factory.clone(), + } + } } |