From 8651f84f205da7a89f2fc7333d1dd8de0d80a22b Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 17 Sep 2022 16:12:45 +0200 Subject: refactor!: make async DI container be used inside of a Arc BREAKING CHANGE: The async DI container is to be used inside of a Arc & it also no longer implements Default --- Cargo.toml | 1 + examples/async/bootstrap.rs | 11 +- examples/async/main.rs | 14 +- macros/src/injectable_impl.rs | 2 +- src/async_di_container.rs | 466 +++++++++++++++++++++++-------------- src/di_container_binding_map.rs | 13 +- src/interfaces/async_injectable.rs | 3 +- src/provider/async.rs | 69 +++++- 8 files changed, 374 insertions(+), 205 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e9ccb15..6b5c37d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ strum = "0.24.1" strum_macros = "0.24.3" paste = "1.0.8" async-trait = { version = "0.1.57", optional = true } +tokio = "1.20.1" [dev_dependencies] mockall = "0.11.1" diff --git a/examples/async/bootstrap.rs b/examples/async/bootstrap.rs index b640712..7e1d2cd 100644 --- a/examples/async/bootstrap.rs +++ b/examples/async/bootstrap.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::Result; use syrette::async_di_container::AsyncDIContainer; @@ -11,18 +13,19 @@ use crate::interfaces::cat::ICat; use crate::interfaces::dog::IDog; use crate::interfaces::human::IHuman; -pub async fn bootstrap() -> Result +pub async fn bootstrap() -> Result> { let mut di_container = AsyncDIContainer::new(); di_container .bind::() - .to::()? + .to::() + .await? .in_singleton_scope() .await?; - di_container.bind::().to::()?; - di_container.bind::().to::()?; + di_container.bind::().to::().await?; + di_container.bind::().to::().await?; Ok(di_container) } diff --git a/examples/async/main.rs b/examples/async/main.rs index f72ff39..3c884fe 100644 --- a/examples/async/main.rs +++ b/examples/async/main.rs @@ -2,11 +2,8 @@ #![deny(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] -use std::sync::Arc; - use anyhow::Result; use tokio::spawn; -use tokio::sync::Mutex; mod animals; mod bootstrap; @@ -21,12 +18,10 @@ async fn main() -> Result<()> { println!("Hello, world!"); - let di_container = Arc::new(Mutex::new(bootstrap().await?)); + let di_container = bootstrap().await?; { let dog = di_container - .lock() - .await .get::() .await? .threadsafe_singleton()?; @@ -35,12 +30,7 @@ async fn main() -> Result<()> } spawn(async move { - let human = di_container - .lock() - .await - .get::() - .await? - .transient()?; + let human = di_container.get::().await?.transient()?; human.make_pets_make_sounds(); diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index af1fa68..bf5c96c 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -91,7 +91,7 @@ impl InjectableImpl impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type { async fn resolve( - #di_container_var: &syrette::async_di_container::AsyncDIContainer, + #di_container_var: &std::sync::Arc, mut #dependency_history_var: Vec<&'static str>, ) -> Result< syrette::ptr::TransientPtr, diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 7913c5a..0cd92a5 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -55,6 +55,9 @@ //! *This module is only available if Syrette is built with the "async" feature.* use std::any::type_name; use std::marker::PhantomData; +use std::sync::Arc; + +use tokio::sync::Mutex; #[cfg(feature = "factory")] use crate::castable_factory::threadsafe::ThreadsafeCastableFactory; @@ -77,19 +80,19 @@ use crate::provider::r#async::{ use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr}; /// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. -pub struct AsyncBindingWhenConfigurator<'di_container, Interface> +pub struct AsyncBindingWhenConfigurator where Interface: 'static + ?Sized, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc, interface_phantom: PhantomData, } -impl<'di_container, Interface> AsyncBindingWhenConfigurator<'di_container, Interface> +impl AsyncBindingWhenConfigurator where Interface: 'static + ?Sized, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc) -> Self { Self { di_container, @@ -101,50 +104,45 @@ where /// /// # Errors /// Will return Err if no binding for the interface already exists. - pub fn when_named( - &mut self, + pub async fn when_named( + &self, name: &'static str, ) -> Result<(), AsyncBindingWhenConfiguratorError> { - let binding = self - .di_container - .bindings - .remove::(None) - .map_or_else( - || { - Err(AsyncBindingWhenConfiguratorError::BindingNotFound( - type_name::(), - )) - }, - Ok, - )?; - - self.di_container - .bindings - .set::(Some(name), binding); + let mut bindings_lock = self.di_container.bindings.lock().await; + + let binding = bindings_lock.remove::(None).map_or_else( + || { + Err(AsyncBindingWhenConfiguratorError::BindingNotFound( + type_name::(), + )) + }, + Ok, + )?; + + bindings_lock.set::(Some(name), binding); Ok(()) } } /// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. -pub struct AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +pub struct AsyncBindingScopeConfigurator where Interface: 'static + ?Sized, Implementation: AsyncInjectable, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc, interface_phantom: PhantomData, implementation_phantom: PhantomData, } -impl<'di_container, Interface, Implementation> - AsyncBindingScopeConfigurator<'di_container, Interface, Implementation> +impl AsyncBindingScopeConfigurator where Interface: 'static + ?Sized, Implementation: AsyncInjectable, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc) -> Self { Self { di_container, @@ -156,14 +154,16 @@ where /// Configures the binding to be in a transient scope. /// /// This is the default. - pub fn in_transient_scope(&mut self) -> AsyncBindingWhenConfigurator + pub async fn in_transient_scope(&self) -> AsyncBindingWhenConfigurator { - self.di_container.bindings.set::( + let mut bindings_lock = self.di_container.bindings.lock().await; + + bindings_lock.set::( None, Box::new(AsyncTransientTypeProvider::::new()), ); - AsyncBindingWhenConfigurator::new(self.di_container) + AsyncBindingWhenConfigurator::new(self.di_container.clone()) } /// Configures the binding to be in a singleton scope. @@ -171,40 +171,41 @@ where /// # Errors /// Will return Err if resolving the implementation fails. pub async fn in_singleton_scope( - &mut self, + &self, ) -> Result, AsyncBindingScopeConfiguratorError> { let singleton: ThreadsafeSingletonPtr = ThreadsafeSingletonPtr::from( - Implementation::resolve(self.di_container, Vec::new()) + Implementation::resolve(&self.di_container, Vec::new()) .await .map_err( AsyncBindingScopeConfiguratorError::SingletonResolveFailed, )?, ); - self.di_container - .bindings + let mut bindings_lock = self.di_container.bindings.lock().await; + + bindings_lock .set::(None, Box::new(AsyncSingletonProvider::new(singleton))); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } } /// Binding builder for type `Interface` inside a [`AsyncDIContainer`]. -pub struct AsyncBindingBuilder<'di_container, Interface> +pub struct AsyncBindingBuilder where Interface: 'static + ?Sized, { - di_container: &'di_container mut AsyncDIContainer, + di_container: Arc, interface_phantom: PhantomData, } -impl<'di_container, Interface> AsyncBindingBuilder<'di_container, Interface> +impl AsyncBindingBuilder where Interface: 'static + ?Sized, { - fn new(di_container: &'di_container mut AsyncDIContainer) -> Self + fn new(di_container: Arc) -> Self { Self { di_container, @@ -221,8 +222,8 @@ where /// # Errors /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. - pub fn to( - &mut self, + pub async fn to( + &self, ) -> Result< AsyncBindingScopeConfigurator, AsyncBindingBuilderError, @@ -230,17 +231,21 @@ where where Implementation: AsyncInjectable, { - if self.di_container.bindings.has::(None) { - return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< - Interface, - >( - ))); + { + let bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::(None) { + return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< + Interface, + >( + ))); + } } - let mut binding_scope_configurator = - AsyncBindingScopeConfigurator::new(self.di_container); + let binding_scope_configurator = + AsyncBindingScopeConfigurator::new(self.di_container.clone()); - binding_scope_configurator.in_transient_scope(); + binding_scope_configurator.in_transient_scope().await; Ok(binding_scope_configurator) } @@ -254,8 +259,8 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub fn to_factory( - &mut self, + pub async fn to_factory( + &self, factory_func: &'static (dyn Fn> + Send + Sync), @@ -265,7 +270,9 @@ where Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory, { - if self.di_container.bindings.has::(None) { + let mut bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::(None) { return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< Interface, >( @@ -274,14 +281,14 @@ where let factory_impl = ThreadsafeCastableFactory::new(factory_func); - self.di_container.bindings.set::( + bindings_lock.set::( None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), )), ); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } /// Creates a binding of type `Interface` to a factory that takes no arguments @@ -293,8 +300,8 @@ where /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for /// the interface. #[cfg(feature = "factory")] - pub fn to_default_factory( - &mut self, + pub async fn to_default_factory( + &self, factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr> + Send + Sync), @@ -302,7 +309,9 @@ where where Return: 'static + ?Sized, { - if self.di_container.bindings.has::(None) { + let mut bindings_lock = self.di_container.bindings.lock().await; + + if bindings_lock.has::(None) { return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::< Interface, >( @@ -311,40 +320,40 @@ where let factory_impl = ThreadsafeCastableFactory::new(factory_func); - self.di_container.bindings.set::( + bindings_lock.set::( None, Box::new(crate::provider::r#async::AsyncFactoryProvider::new( crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), )), ); - Ok(AsyncBindingWhenConfigurator::new(self.di_container)) + Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone())) } } /// Dependency injection container. pub struct AsyncDIContainer { - bindings: DIContainerBindingMap, + bindings: Mutex>, } impl AsyncDIContainer { /// Returns a new `AsyncDIContainer`. #[must_use] - pub fn new() -> Self + pub fn new() -> Arc { - Self { - bindings: DIContainerBindingMap::new(), - } + Arc::new(Self { + bindings: Mutex::new(DIContainerBindingMap::new()), + }) } /// Returns a new [`AsyncBindingBuilder`] for the given interface. - pub fn bind(&mut self) -> AsyncBindingBuilder + pub fn bind(self: &mut Arc) -> AsyncBindingBuilder where Interface: 'static + ?Sized, { - AsyncBindingBuilder::::new(self) + AsyncBindingBuilder::::new(self.clone()) } /// Returns the type bound with `Interface`. @@ -355,7 +364,7 @@ impl AsyncDIContainer /// - Resolving the binding for fails /// - Casting the binding for fails pub async fn get( - &self, + self: &Arc, ) -> Result, AsyncDIContainerError> where Interface: 'static + ?Sized, @@ -371,7 +380,7 @@ impl AsyncDIContainer /// - Resolving the binding fails /// - Casting the binding for fails pub async fn get_named( - &self, + self: &Arc, name: &'static str, ) -> Result, AsyncDIContainerError> where @@ -382,7 +391,7 @@ impl AsyncDIContainer #[doc(hidden)] pub async fn get_bound( - &self, + self: &Arc, dependency_history: Vec<&'static str>, name: Option<&'static str>, ) -> Result, AsyncDIContainerError> @@ -471,24 +480,33 @@ impl AsyncDIContainer } async fn get_binding_providable( - &self, + self: &Arc, name: Option<&'static str>, dependency_history: Vec<&'static str>, ) -> Result where Interface: 'static + ?Sized, { - self.bindings - .get::(name) - .map_or_else( - || { - Err(AsyncDIContainerError::BindingNotFound { - interface: type_name::(), - name, - }) - }, - Ok, - )? + let provider; + + { + let bindings_lock = self.bindings.lock().await; + + provider = bindings_lock + .get::(name) + .map_or_else( + || { + Err(AsyncDIContainerError::BindingNotFound { + interface: type_name::(), + name, + }) + }, + Ok, + )? + .clone(); + } + + provider .provide(self, dependency_history) .await .map_err(|err| AsyncDIContainerError::BindingResolveFailed { @@ -498,14 +516,6 @@ impl AsyncDIContainer } } -impl Default for AsyncDIContainer -{ - fn default() -> Self - { - Self::new() - } -} - #[cfg(test)] mod tests { @@ -523,6 +533,7 @@ mod tests //! Test subjects. use std::fmt::Debug; + use std::sync::Arc; use async_trait::async_trait; use syrette_macros::declare_interface; @@ -569,7 +580,7 @@ mod tests impl AsyncInjectable for UserManager { async fn resolve( - _: &AsyncDIContainer, + _: &Arc, _dependency_history: Vec<&'static str>, ) -> Result, crate::errors::injectable::InjectableError> where @@ -634,7 +645,7 @@ mod tests impl AsyncInjectable for Number { async fn resolve( - _: &AsyncDIContainer, + _: &Arc, _dependency_history: Vec<&'static str>, ) -> Result, crate::errors::injectable::InjectableError> where @@ -645,53 +656,71 @@ mod tests } } - #[test] - fn can_bind_to() -> Result<(), Box> + #[tokio::test] + async fn can_bind_to() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()?; + .to::() + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] - fn can_bind_to_transient() -> Result<(), Box> + #[tokio::test] + async fn can_bind_to_transient() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? - .in_transient_scope(); + .to::() + .await? + .in_transient_scope() + .await; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] - fn can_bind_to_transient_when_named() -> Result<(), Box> + #[tokio::test] + async fn can_bind_to_transient_when_named() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? + .to::() + .await? .in_transient_scope() - .when_named("regular")?; + .await + .when_named("regular") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -699,17 +728,22 @@ mod tests #[tokio::test] async fn can_bind_to_singleton() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? + .to::() + .await? .in_singleton_scope() .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -717,55 +751,70 @@ mod tests #[tokio::test] async fn can_bind_to_singleton_when_named() -> Result<(), Box> { - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() - .to::()? + .to::() + .await? .in_singleton_scope() .await? - .when_named("cool")?; + .when_named("cool") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] + #[tokio::test] #[cfg(feature = "factory")] - fn can_bind_to_factory() -> Result<(), Box> + async fn can_bind_to_factory() -> Result<(), Box> { type IUserManagerFactory = dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } - di_container.bind::().to_factory(&|| { - let user_manager: TransientPtr = - TransientPtr::new(subjects::UserManager::new()); + di_container + .bind::() + .to_factory(&|| { + let user_manager: TransientPtr = + TransientPtr::new(subjects::UserManager::new()); - user_manager - })?; + user_manager + }) + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } - #[test] + #[tokio::test] #[cfg(feature = "factory")] - fn can_bind_to_factory_when_named() -> Result<(), Box> + async fn can_bind_to_factory_when_named() -> Result<(), Box> { type IUserManagerFactory = dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); - assert_eq!(di_container.bindings.count(), 0); + { + assert_eq!(di_container.bindings.lock().await.count(), 0); + } di_container .bind::() @@ -774,10 +823,14 @@ mod tests TransientPtr::new(subjects::UserManager::new()); user_manager - })? - .when_named("awesome")?; + }) + .await? + .when_named("awesome") + .await?; - assert_eq!(di_container.bindings.count(), 1); + { + assert_eq!(di_container.bindings.lock().await.count(), 1); + } Ok(()) } @@ -793,25 +846,37 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Transient(TransientPtr::new( - subjects::UserManager::new(), - ))) + mock_provider.expect_do_clone().returning(|| { + let mut inner_mock_provider = MockProvider::new(); + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::(None, Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::(None, Box::new(mock_provider)); + } di_container .get::() @@ -832,25 +897,40 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Transient(TransientPtr::new( - subjects::UserManager::new(), - ))) + mock_provider.expect_do_clone().returning(|| { + let mut inner_mock_provider = MockProvider::new(); + + inner_mock_provider.expect_provide().returning(|_, _| { + Ok(AsyncProvidable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + Box::new(inner_mock_provider) }); - di_container - .bindings - .set::(Some("special"), Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::( + Some("special"), + Box::new(mock_provider), + ); + } di_container .get_named::("special") @@ -871,13 +951,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -885,13 +967,25 @@ mod tests ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; - mock_provider - .expect_provide() - .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + mock_provider.expect_do_clone().returning(move || { + let mut inner_mock_provider = MockProvider::new(); - di_container - .bindings - .set::(None, Box::new(mock_provider)); + let singleton_clone = singleton.clone(); + + inner_mock_provider.expect_provide().returning(move |_, _| { + Ok(AsyncProvidable::Singleton(singleton_clone.clone())) + }); + + Box::new(inner_mock_provider) + }); + + { + di_container + .bindings + .lock() + .await + .set::(None, Box::new(mock_provider)); + } let first_number_rc = di_container .get::() @@ -921,13 +1015,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -935,13 +1031,25 @@ mod tests ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; - mock_provider - .expect_provide() - .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone()))); + mock_provider.expect_do_clone().returning(move || { + let mut inner_mock_provider = MockProvider::new(); - di_container - .bindings - .set::(Some("cool"), Box::new(mock_provider)); + let singleton_clone = singleton.clone(); + + inner_mock_provider.expect_provide().returning(move |_, _| { + Ok(AsyncProvidable::Singleton(singleton_clone.clone())) + }); + + Box::new(inner_mock_provider) + }); + + { + di_container + .bindings + .lock() + .await + .set::(Some("cool"), Box::new(mock_provider)); + } let first_number_rc = di_container .get_named::("cool") @@ -1014,13 +1122,15 @@ mod tests { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; } } - let mut di_container: AsyncDIContainer = AsyncDIContainer::new(); + let mut di_container = AsyncDIContainer::new(); let mut mock_provider = MockProvider::new(); @@ -1037,9 +1147,13 @@ mod tests )) }); - di_container - .bindings - .set::(None, Box::new(mock_provider)); + { + di_container + .bindings + .lock() + .await + .set::(None, Box::new(mock_provider)); + } di_container .get::() diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs index 4aa246e..eb71ff7 100644 --- a/src/di_container_binding_map.rs +++ b/src/di_container_binding_map.rs @@ -27,18 +27,17 @@ where } } - pub fn get(&self, name: Option<&'static str>) -> Option<&Provider> + #[allow(clippy::borrowed_box)] + pub fn get(&self, name: Option<&'static str>) -> Option<&Box> where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::(); - self.bindings - .get(&DIContainerBindingKey { - type_id: interface_typeid, - name, - }) - .map(|provider| provider.as_ref()) + self.bindings.get(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } pub fn set(&mut self, name: Option<&'static str>, provider: Box) diff --git a/src/interfaces/async_injectable.rs b/src/interfaces/async_injectable.rs index badc3c5..fb5452b 100644 --- a/src/interfaces/async_injectable.rs +++ b/src/interfaces/async_injectable.rs @@ -2,6 +2,7 @@ //! //! *This module is only available if Syrette is built with the "async" feature.* use std::fmt::Debug; +use std::sync::Arc; use async_trait::async_trait; @@ -19,7 +20,7 @@ pub trait AsyncInjectable: CastFromSync /// # Errors /// Will return `Err` if resolving the dependencies fails. async fn resolve( - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result, InjectableError> where diff --git a/src/provider/async.rs b/src/provider/async.rs index 93ae03a..1ddb614 100644 --- a/src/provider/async.rs +++ b/src/provider/async.rs @@ -1,5 +1,6 @@ #![allow(clippy::module_name_repetitions)] use std::marker::PhantomData; +use std::sync::Arc; use async_trait::async_trait; @@ -26,9 +27,19 @@ pub trait IAsyncProvider: Send + Sync { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result; + + fn do_clone(&self) -> Box; +} + +impl Clone for Box +{ + fn clone(&self) -> Self + { + self.do_clone() + } } pub struct AsyncTransientTypeProvider @@ -57,7 +68,7 @@ where { async fn provide( &self, - di_container: &AsyncDIContainer, + di_container: &Arc, dependency_history: Vec<&'static str>, ) -> Result { @@ -65,6 +76,23 @@ where InjectableType::resolve(di_container, dependency_history).await?, )) } + + fn do_clone(&self) -> Box + { + Box::new(self.clone()) + } +} + +impl Clone for AsyncTransientTypeProvider +where + InjectableType: AsyncInjectable, +{ + fn clone(&self) -> Self + { + Self { + injectable_phantom: self.injectable_phantom, + } + } } pub struct AsyncSingletonProvider @@ -91,12 +119,29 @@ where { async fn provide( &self, - _di_container: &AsyncDIContainer, + _di_container: &Arc, _dependency_history: Vec<&'static str>, ) -> Result { Ok(AsyncProvidable::Singleton(self.singleton.clone())) } + + fn do_clone(&self) -> Box + { + Box::new(self.clone()) + } +} + +impl Clone for AsyncSingletonProvider +where + InjectableType: AsyncInjectable, +{ + fn clone(&self) -> Self + { + Self { + singleton: self.singleton.clone(), + } + } } #[cfg(feature = "factory")] @@ -126,10 +171,26 @@ impl IAsyncProvider for AsyncFactoryProvider { async fn provide( &self, - _di_container: &AsyncDIContainer, + _di_container: &Arc, _dependency_history: Vec<&'static str>, ) -> Result { Ok(AsyncProvidable::Factory(self.factory.clone())) } + + fn do_clone(&self) -> Box + { + Box::new(self.clone()) + } +} + +#[cfg(feature = "factory")] +impl Clone for AsyncFactoryProvider +{ + fn clone(&self) -> Self + { + Self { + factory: self.factory.clone(), + } + } } -- cgit v1.2.3-18-g5258