aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml1
-rw-r--r--examples/async/bootstrap.rs11
-rw-r--r--examples/async/main.rs14
-rw-r--r--macros/src/injectable_impl.rs2
-rw-r--r--src/async_di_container.rs466
-rw-r--r--src/di_container_binding_map.rs13
-rw-r--r--src/interfaces/async_injectable.rs3
-rw-r--r--src/provider/async.rs69
8 files changed, 374 insertions, 205 deletions
diff --git a/Cargo.toml b/Cargo.toml
index e9ccb15..6b5c37d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,7 @@ strum = "0.24.1"
strum_macros = "0.24.3"
paste = "1.0.8"
async-trait = { version = "0.1.57", optional = true }
+tokio = "1.20.1"
[dev_dependencies]
mockall = "0.11.1"
diff --git a/examples/async/bootstrap.rs b/examples/async/bootstrap.rs
index b640712..7e1d2cd 100644
--- a/examples/async/bootstrap.rs
+++ b/examples/async/bootstrap.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
use anyhow::Result;
use syrette::async_di_container::AsyncDIContainer;
@@ -11,18 +13,19 @@ use crate::interfaces::cat::ICat;
use crate::interfaces::dog::IDog;
use crate::interfaces::human::IHuman;
-pub async fn bootstrap() -> Result<AsyncDIContainer>
+pub async fn bootstrap() -> Result<Arc<AsyncDIContainer>>
{
let mut di_container = AsyncDIContainer::new();
di_container
.bind::<dyn IDog>()
- .to::<Dog>()?
+ .to::<Dog>()
+ .await?
.in_singleton_scope()
.await?;
- di_container.bind::<dyn ICat>().to::<Cat>()?;
- di_container.bind::<dyn IHuman>().to::<Human>()?;
+ di_container.bind::<dyn ICat>().to::<Cat>().await?;
+ di_container.bind::<dyn IHuman>().to::<Human>().await?;
Ok(di_container)
}
diff --git a/examples/async/main.rs b/examples/async/main.rs
index f72ff39..3c884fe 100644
--- a/examples/async/main.rs
+++ b/examples/async/main.rs
@@ -2,11 +2,8 @@
#![deny(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
-use std::sync::Arc;
-
use anyhow::Result;
use tokio::spawn;
-use tokio::sync::Mutex;
mod animals;
mod bootstrap;
@@ -21,12 +18,10 @@ async fn main() -> Result<()>
{
println!("Hello, world!");
- let di_container = Arc::new(Mutex::new(bootstrap().await?));
+ let di_container = bootstrap().await?;
{
let dog = di_container
- .lock()
- .await
.get::<dyn IDog>()
.await?
.threadsafe_singleton()?;
@@ -35,12 +30,7 @@ async fn main() -> Result<()>
}
spawn(async move {
- let human = di_container
- .lock()
- .await
- .get::<dyn IHuman>()
- .await?
- .transient()?;
+ let human = di_container.get::<dyn IHuman>().await?.transient()?;
human.make_pets_make_sounds();
diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs
index af1fa68..bf5c96c 100644
--- a/macros/src/injectable_impl.rs
+++ b/macros/src/injectable_impl.rs
@@ -91,7 +91,7 @@ impl InjectableImpl
impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type
{
async fn resolve(
- #di_container_var: &syrette::async_di_container::AsyncDIContainer,
+ #di_container_var: &std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,
mut #dependency_history_var: Vec<&'static str>,
) -> Result<
syrette::ptr::TransientPtr<Self>,
diff --git a/src/async_di_container.rs b/src/async_di_container.rs
index 7913c5a..0cd92a5 100644
--- a/src/async_di_container.rs
+++ b/src/async_di_container.rs
@@ -55,6 +55,9 @@
//! *This module is only available if Syrette is built with the "async" feature.*
use std::any::type_name;
use std::marker::PhantomData;
+use std::sync::Arc;
+
+use tokio::sync::Mutex;
#[cfg(feature = "factory")]
use crate::castable_factory::threadsafe::ThreadsafeCastableFactory;
@@ -77,19 +80,19 @@ use crate::provider::r#async::{
use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr};
/// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
-pub struct AsyncBindingWhenConfigurator<'di_container, Interface>
+pub struct AsyncBindingWhenConfigurator<Interface>
where
Interface: 'static + ?Sized,
{
- di_container: &'di_container mut AsyncDIContainer,
+ di_container: Arc<AsyncDIContainer>,
interface_phantom: PhantomData<Interface>,
}
-impl<'di_container, Interface> AsyncBindingWhenConfigurator<'di_container, Interface>
+impl<Interface> AsyncBindingWhenConfigurator<Interface>
where
Interface: 'static + ?Sized,
{
- fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ fn new(di_container: Arc<AsyncDIContainer>) -> Self
{
Self {
di_container,
@@ -101,50 +104,45 @@ where
///
/// # Errors
/// Will return Err if no binding for the interface already exists.
- pub fn when_named(
- &mut self,
+ pub async fn when_named(
+ &self,
name: &'static str,
) -> Result<(), AsyncBindingWhenConfiguratorError>
{
- let binding = self
- .di_container
- .bindings
- .remove::<Interface>(None)
- .map_or_else(
- || {
- Err(AsyncBindingWhenConfiguratorError::BindingNotFound(
- type_name::<Interface>(),
- ))
- },
- Ok,
- )?;
-
- self.di_container
- .bindings
- .set::<Interface>(Some(name), binding);
+ let mut bindings_lock = self.di_container.bindings.lock().await;
+
+ let binding = bindings_lock.remove::<Interface>(None).map_or_else(
+ || {
+ Err(AsyncBindingWhenConfiguratorError::BindingNotFound(
+ type_name::<Interface>(),
+ ))
+ },
+ Ok,
+ )?;
+
+ bindings_lock.set::<Interface>(Some(name), binding);
Ok(())
}
}
/// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
-pub struct AsyncBindingScopeConfigurator<'di_container, Interface, Implementation>
+pub struct AsyncBindingScopeConfigurator<Interface, Implementation>
where
Interface: 'static + ?Sized,
Implementation: AsyncInjectable,
{
- di_container: &'di_container mut AsyncDIContainer,
+ di_container: Arc<AsyncDIContainer>,
interface_phantom: PhantomData<Interface>,
implementation_phantom: PhantomData<Implementation>,
}
-impl<'di_container, Interface, Implementation>
- AsyncBindingScopeConfigurator<'di_container, Interface, Implementation>
+impl<Interface, Implementation> AsyncBindingScopeConfigurator<Interface, Implementation>
where
Interface: 'static + ?Sized,
Implementation: AsyncInjectable,
{
- fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ fn new(di_container: Arc<AsyncDIContainer>) -> Self
{
Self {
di_container,
@@ -156,14 +154,16 @@ where
/// Configures the binding to be in a transient scope.
///
/// This is the default.
- pub fn in_transient_scope(&mut self) -> AsyncBindingWhenConfigurator<Interface>
+ pub async fn in_transient_scope(&self) -> AsyncBindingWhenConfigurator<Interface>
{
- self.di_container.bindings.set::<Interface>(
+ let mut bindings_lock = self.di_container.bindings.lock().await;
+
+ bindings_lock.set::<Interface>(
None,
Box::new(AsyncTransientTypeProvider::<Implementation>::new()),
);
- AsyncBindingWhenConfigurator::new(self.di_container)
+ AsyncBindingWhenConfigurator::new(self.di_container.clone())
}
/// Configures the binding to be in a singleton scope.
@@ -171,40 +171,41 @@ where
/// # Errors
/// Will return Err if resolving the implementation fails.
pub async fn in_singleton_scope(
- &mut self,
+ &self,
) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingScopeConfiguratorError>
{
let singleton: ThreadsafeSingletonPtr<Implementation> =
ThreadsafeSingletonPtr::from(
- Implementation::resolve(self.di_container, Vec::new())
+ Implementation::resolve(&self.di_container, Vec::new())
.await
.map_err(
AsyncBindingScopeConfiguratorError::SingletonResolveFailed,
)?,
);
- self.di_container
- .bindings
+ let mut bindings_lock = self.di_container.bindings.lock().await;
+
+ bindings_lock
.set::<Interface>(None, Box::new(AsyncSingletonProvider::new(singleton)));
- Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone()))
}
}
/// Binding builder for type `Interface` inside a [`AsyncDIContainer`].
-pub struct AsyncBindingBuilder<'di_container, Interface>
+pub struct AsyncBindingBuilder<Interface>
where
Interface: 'static + ?Sized,
{
- di_container: &'di_container mut AsyncDIContainer,
+ di_container: Arc<AsyncDIContainer>,
interface_phantom: PhantomData<Interface>,
}
-impl<'di_container, Interface> AsyncBindingBuilder<'di_container, Interface>
+impl<Interface> AsyncBindingBuilder<Interface>
where
Interface: 'static + ?Sized,
{
- fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ fn new(di_container: Arc<AsyncDIContainer>) -> Self
{
Self {
di_container,
@@ -221,8 +222,8 @@ where
/// # Errors
/// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
/// the interface.
- pub fn to<Implementation>(
- &mut self,
+ pub async fn to<Implementation>(
+ &self,
) -> Result<
AsyncBindingScopeConfigurator<Interface, Implementation>,
AsyncBindingBuilderError,
@@ -230,17 +231,21 @@ where
where
Implementation: AsyncInjectable,
{
- if self.di_container.bindings.has::<Interface>(None) {
- return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
- Interface,
- >(
- )));
+ {
+ let bindings_lock = self.di_container.bindings.lock().await;
+
+ if bindings_lock.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
}
- let mut binding_scope_configurator =
- AsyncBindingScopeConfigurator::new(self.di_container);
+ let binding_scope_configurator =
+ AsyncBindingScopeConfigurator::new(self.di_container.clone());
- binding_scope_configurator.in_transient_scope();
+ binding_scope_configurator.in_transient_scope().await;
Ok(binding_scope_configurator)
}
@@ -254,8 +259,8 @@ where
/// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
/// the interface.
#[cfg(feature = "factory")]
- pub fn to_factory<Args, Return>(
- &mut self,
+ pub async fn to_factory<Args, Return>(
+ &self,
factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>
+ Send
+ Sync),
@@ -265,7 +270,9 @@ where
Return: 'static + ?Sized,
Interface: crate::interfaces::factory::IFactory<Args, Return>,
{
- if self.di_container.bindings.has::<Interface>(None) {
+ let mut bindings_lock = self.di_container.bindings.lock().await;
+
+ if bindings_lock.has::<Interface>(None) {
return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
Interface,
>(
@@ -274,14 +281,14 @@ where
let factory_impl = ThreadsafeCastableFactory::new(factory_func);
- self.di_container.bindings.set::<Interface>(
+ bindings_lock.set::<Interface>(
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
)),
);
- Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone()))
}
/// Creates a binding of type `Interface` to a factory that takes no arguments
@@ -293,8 +300,8 @@ where
/// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
/// the interface.
#[cfg(feature = "factory")]
- pub fn to_default_factory<Return>(
- &mut self,
+ pub async fn to_default_factory<Return>(
+ &self,
factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>
+ Send
+ Sync),
@@ -302,7 +309,9 @@ where
where
Return: 'static + ?Sized,
{
- if self.di_container.bindings.has::<Interface>(None) {
+ let mut bindings_lock = self.di_container.bindings.lock().await;
+
+ if bindings_lock.has::<Interface>(None) {
return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
Interface,
>(
@@ -311,40 +320,40 @@ where
let factory_impl = ThreadsafeCastableFactory::new(factory_func);
- self.di_container.bindings.set::<Interface>(
+ bindings_lock.set::<Interface>(
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
)),
);
- Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone()))
}
}
/// Dependency injection container.
pub struct AsyncDIContainer
{
- bindings: DIContainerBindingMap<dyn IAsyncProvider>,
+ bindings: Mutex<DIContainerBindingMap<dyn IAsyncProvider>>,
}
impl AsyncDIContainer
{
/// Returns a new `AsyncDIContainer`.
#[must_use]
- pub fn new() -> Self
+ pub fn new() -> Arc<Self>
{
- Self {
- bindings: DIContainerBindingMap::new(),
- }
+ Arc::new(Self {
+ bindings: Mutex::new(DIContainerBindingMap::new()),
+ })
}
/// Returns a new [`AsyncBindingBuilder`] for the given interface.
- pub fn bind<Interface>(&mut self) -> AsyncBindingBuilder<Interface>
+ pub fn bind<Interface>(self: &mut Arc<Self>) -> AsyncBindingBuilder<Interface>
where
Interface: 'static + ?Sized,
{
- AsyncBindingBuilder::<Interface>::new(self)
+ AsyncBindingBuilder::<Interface>::new(self.clone())
}
/// Returns the type bound with `Interface`.
@@ -355,7 +364,7 @@ impl AsyncDIContainer
/// - Resolving the binding for fails
/// - Casting the binding for fails
pub async fn get<Interface>(
- &self,
+ self: &Arc<Self>,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
Interface: 'static + ?Sized,
@@ -371,7 +380,7 @@ impl AsyncDIContainer
/// - Resolving the binding fails
/// - Casting the binding for fails
pub async fn get_named<Interface>(
- &self,
+ self: &Arc<Self>,
name: &'static str,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
@@ -382,7 +391,7 @@ impl AsyncDIContainer
#[doc(hidden)]
pub async fn get_bound<Interface>(
- &self,
+ self: &Arc<Self>,
dependency_history: Vec<&'static str>,
name: Option<&'static str>,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
@@ -471,24 +480,33 @@ impl AsyncDIContainer
}
async fn get_binding_providable<Interface>(
- &self,
+ self: &Arc<Self>,
name: Option<&'static str>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, AsyncDIContainerError>
where
Interface: 'static + ?Sized,
{
- self.bindings
- .get::<Interface>(name)
- .map_or_else(
- || {
- Err(AsyncDIContainerError::BindingNotFound {
- interface: type_name::<Interface>(),
- name,
- })
- },
- Ok,
- )?
+ let provider;
+
+ {
+ let bindings_lock = self.bindings.lock().await;
+
+ provider = bindings_lock
+ .get::<Interface>(name)
+ .map_or_else(
+ || {
+ Err(AsyncDIContainerError::BindingNotFound {
+ interface: type_name::<Interface>(),
+ name,
+ })
+ },
+ Ok,
+ )?
+ .clone();
+ }
+
+ provider
.provide(self, dependency_history)
.await
.map_err(|err| AsyncDIContainerError::BindingResolveFailed {
@@ -498,14 +516,6 @@ impl AsyncDIContainer
}
}
-impl Default for AsyncDIContainer
-{
- fn default() -> Self
- {
- Self::new()
- }
-}
-
#[cfg(test)]
mod tests
{
@@ -523,6 +533,7 @@ mod tests
//! Test subjects.
use std::fmt::Debug;
+ use std::sync::Arc;
use async_trait::async_trait;
use syrette_macros::declare_interface;
@@ -569,7 +580,7 @@ mod tests
impl AsyncInjectable for UserManager
{
async fn resolve(
- _: &AsyncDIContainer,
+ _: &Arc<AsyncDIContainer>,
_dependency_history: Vec<&'static str>,
) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError>
where
@@ -634,7 +645,7 @@ mod tests
impl AsyncInjectable for Number
{
async fn resolve(
- _: &AsyncDIContainer,
+ _: &Arc<AsyncDIContainer>,
_dependency_history: Vec<&'static str>,
) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError>
where
@@ -645,53 +656,71 @@ mod tests
}
}
- #[test]
- fn can_bind_to() -> Result<(), Box<dyn Error>>
+ #[tokio::test]
+ async fn can_bind_to() -> Result<(), Box<dyn Error>>
{
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
- assert_eq!(di_container.bindings.count(), 0);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 0);
+ }
di_container
.bind::<dyn subjects::IUserManager>()
- .to::<subjects::UserManager>()?;
+ .to::<subjects::UserManager>()
+ .await?;
- assert_eq!(di_container.bindings.count(), 1);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 1);
+ }
Ok(())
}
- #[test]
- fn can_bind_to_transient() -> Result<(), Box<dyn Error>>
+ #[tokio::test]
+ async fn can_bind_to_transient() -> Result<(), Box<dyn Error>>
{
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
- assert_eq!(di_container.bindings.count(), 0);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 0);
+ }
di_container
.bind::<dyn subjects::IUserManager>()
- .to::<subjects::UserManager>()?
- .in_transient_scope();
+ .to::<subjects::UserManager>()
+ .await?
+ .in_transient_scope()
+ .await;
- assert_eq!(di_container.bindings.count(), 1);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 1);
+ }
Ok(())
}
- #[test]
- fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>>
+ #[tokio::test]
+ async fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>>
{
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
- assert_eq!(di_container.bindings.count(), 0);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 0);
+ }
di_container
.bind::<dyn subjects::IUserManager>()
- .to::<subjects::UserManager>()?
+ .to::<subjects::UserManager>()
+ .await?
.in_transient_scope()
- .when_named("regular")?;
+ .await
+ .when_named("regular")
+ .await?;
- assert_eq!(di_container.bindings.count(), 1);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 1);
+ }
Ok(())
}
@@ -699,17 +728,22 @@ mod tests
#[tokio::test]
async fn can_bind_to_singleton() -> Result<(), Box<dyn Error>>
{
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
- assert_eq!(di_container.bindings.count(), 0);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 0);
+ }
di_container
.bind::<dyn subjects::IUserManager>()
- .to::<subjects::UserManager>()?
+ .to::<subjects::UserManager>()
+ .await?
.in_singleton_scope()
.await?;
- assert_eq!(di_container.bindings.count(), 1);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 1);
+ }
Ok(())
}
@@ -717,55 +751,70 @@ mod tests
#[tokio::test]
async fn can_bind_to_singleton_when_named() -> Result<(), Box<dyn Error>>
{
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
- assert_eq!(di_container.bindings.count(), 0);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 0);
+ }
di_container
.bind::<dyn subjects::IUserManager>()
- .to::<subjects::UserManager>()?
+ .to::<subjects::UserManager>()
+ .await?
.in_singleton_scope()
.await?
- .when_named("cool")?;
+ .when_named("cool")
+ .await?;
- assert_eq!(di_container.bindings.count(), 1);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 1);
+ }
Ok(())
}
- #[test]
+ #[tokio::test]
#[cfg(feature = "factory")]
- fn can_bind_to_factory() -> Result<(), Box<dyn Error>>
+ async fn can_bind_to_factory() -> Result<(), Box<dyn Error>>
{
type IUserManagerFactory =
dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
- assert_eq!(di_container.bindings.count(), 0);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 0);
+ }
- di_container.bind::<IUserManagerFactory>().to_factory(&|| {
- let user_manager: TransientPtr<dyn subjects::IUserManager> =
- TransientPtr::new(subjects::UserManager::new());
+ di_container
+ .bind::<IUserManagerFactory>()
+ .to_factory(&|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
- user_manager
- })?;
+ user_manager
+ })
+ .await?;
- assert_eq!(di_container.bindings.count(), 1);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 1);
+ }
Ok(())
}
- #[test]
+ #[tokio::test]
#[cfg(feature = "factory")]
- fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>>
+ async fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>>
{
type IUserManagerFactory =
dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
- assert_eq!(di_container.bindings.count(), 0);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 0);
+ }
di_container
.bind::<IUserManagerFactory>()
@@ -774,10 +823,14 @@ mod tests
TransientPtr::new(subjects::UserManager::new());
user_manager
- })?
- .when_named("awesome")?;
+ })
+ .await?
+ .when_named("awesome")
+ .await?;
- assert_eq!(di_container.bindings.count(), 1);
+ {
+ assert_eq!(di_container.bindings.lock().await.count(), 1);
+ }
Ok(())
}
@@ -793,25 +846,37 @@ mod tests
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
}
}
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
- mock_provider.expect_provide().returning(|_, _| {
- Ok(AsyncProvidable::Transient(TransientPtr::new(
- subjects::UserManager::new(),
- )))
+ mock_provider.expect_do_clone().returning(|| {
+ let mut inner_mock_provider = MockProvider::new();
+
+ inner_mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Transient(TransientPtr::new(
+ subjects::UserManager::new(),
+ )))
+ });
+
+ Box::new(inner_mock_provider)
});
- di_container
- .bindings
- .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider));
+ {
+ di_container
+ .bindings
+ .lock()
+ .await
+ .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider));
+ }
di_container
.get::<dyn subjects::IUserManager>()
@@ -832,25 +897,40 @@ mod tests
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
}
}
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
- mock_provider.expect_provide().returning(|_, _| {
- Ok(AsyncProvidable::Transient(TransientPtr::new(
- subjects::UserManager::new(),
- )))
+ mock_provider.expect_do_clone().returning(|| {
+ let mut inner_mock_provider = MockProvider::new();
+
+ inner_mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Transient(TransientPtr::new(
+ subjects::UserManager::new(),
+ )))
+ });
+
+ Box::new(inner_mock_provider)
});
- di_container
- .bindings
- .set::<dyn subjects::IUserManager>(Some("special"), Box::new(mock_provider));
+ {
+ di_container
+ .bindings
+ .lock()
+ .await
+ .set::<dyn subjects::IUserManager>(
+ Some("special"),
+ Box::new(mock_provider),
+ );
+ }
di_container
.get_named::<dyn subjects::IUserManager>("special")
@@ -871,13 +951,15 @@ mod tests
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
}
}
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
@@ -885,13 +967,25 @@ mod tests
ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820;
- mock_provider
- .expect_provide()
- .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone())));
+ mock_provider.expect_do_clone().returning(move || {
+ let mut inner_mock_provider = MockProvider::new();
- di_container
- .bindings
- .set::<dyn subjects::INumber>(None, Box::new(mock_provider));
+ let singleton_clone = singleton.clone();
+
+ inner_mock_provider.expect_provide().returning(move |_, _| {
+ Ok(AsyncProvidable::Singleton(singleton_clone.clone()))
+ });
+
+ Box::new(inner_mock_provider)
+ });
+
+ {
+ di_container
+ .bindings
+ .lock()
+ .await
+ .set::<dyn subjects::INumber>(None, Box::new(mock_provider));
+ }
let first_number_rc = di_container
.get::<dyn subjects::INumber>()
@@ -921,13 +1015,15 @@ mod tests
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
}
}
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
@@ -935,13 +1031,25 @@ mod tests
ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820;
- mock_provider
- .expect_provide()
- .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone())));
+ mock_provider.expect_do_clone().returning(move || {
+ let mut inner_mock_provider = MockProvider::new();
- di_container
- .bindings
- .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider));
+ let singleton_clone = singleton.clone();
+
+ inner_mock_provider.expect_provide().returning(move |_, _| {
+ Ok(AsyncProvidable::Singleton(singleton_clone.clone()))
+ });
+
+ Box::new(inner_mock_provider)
+ });
+
+ {
+ di_container
+ .bindings
+ .lock()
+ .await
+ .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider));
+ }
let first_number_rc = di_container
.get_named::<dyn subjects::INumber>("cool")
@@ -1014,13 +1122,15 @@ mod tests
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
}
}
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let mut di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
@@ -1037,9 +1147,13 @@ mod tests
))
});
- di_container
- .bindings
- .set::<IUserManagerFactory>(None, Box::new(mock_provider));
+ {
+ di_container
+ .bindings
+ .lock()
+ .await
+ .set::<IUserManagerFactory>(None, Box::new(mock_provider));
+ }
di_container
.get::<IUserManagerFactory>()
diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs
index 4aa246e..eb71ff7 100644
--- a/src/di_container_binding_map.rs
+++ b/src/di_container_binding_map.rs
@@ -27,18 +27,17 @@ where
}
}
- pub fn get<Interface>(&self, name: Option<&'static str>) -> Option<&Provider>
+ #[allow(clippy::borrowed_box)]
+ pub fn get<Interface>(&self, name: Option<&'static str>) -> Option<&Box<Provider>>
where
Interface: 'static + ?Sized,
{
let interface_typeid = TypeId::of::<Interface>();
- self.bindings
- .get(&DIContainerBindingKey {
- type_id: interface_typeid,
- name,
- })
- .map(|provider| provider.as_ref())
+ self.bindings.get(&DIContainerBindingKey {
+ type_id: interface_typeid,
+ name,
+ })
}
pub fn set<Interface>(&mut self, name: Option<&'static str>, provider: Box<Provider>)
diff --git a/src/interfaces/async_injectable.rs b/src/interfaces/async_injectable.rs
index badc3c5..fb5452b 100644
--- a/src/interfaces/async_injectable.rs
+++ b/src/interfaces/async_injectable.rs
@@ -2,6 +2,7 @@
//!
//! *This module is only available if Syrette is built with the "async" feature.*
use std::fmt::Debug;
+use std::sync::Arc;
use async_trait::async_trait;
@@ -19,7 +20,7 @@ pub trait AsyncInjectable: CastFromSync
/// # Errors
/// Will return `Err` if resolving the dependencies fails.
async fn resolve(
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<TransientPtr<Self>, InjectableError>
where
diff --git a/src/provider/async.rs b/src/provider/async.rs
index 93ae03a..1ddb614 100644
--- a/src/provider/async.rs
+++ b/src/provider/async.rs
@@ -1,5 +1,6 @@
#![allow(clippy::module_name_repetitions)]
use std::marker::PhantomData;
+use std::sync::Arc;
use async_trait::async_trait;
@@ -26,9 +27,19 @@ pub trait IAsyncProvider: Send + Sync
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
+}
+
+impl Clone for Box<dyn IAsyncProvider>
+{
+ fn clone(&self) -> Self
+ {
+ self.do_clone()
+ }
}
pub struct AsyncTransientTypeProvider<InjectableType>
@@ -57,7 +68,7 @@ where
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>
{
@@ -65,6 +76,23 @@ where
InjectableType::resolve(di_container, dependency_history).await?,
))
}
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>
+ {
+ Box::new(self.clone())
+ }
+}
+
+impl<InjectableType> Clone for AsyncTransientTypeProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ fn clone(&self) -> Self
+ {
+ Self {
+ injectable_phantom: self.injectable_phantom,
+ }
+ }
}
pub struct AsyncSingletonProvider<InjectableType>
@@ -91,12 +119,29 @@ where
{
async fn provide(
&self,
- _di_container: &AsyncDIContainer,
+ _di_container: &Arc<AsyncDIContainer>,
_dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>
{
Ok(AsyncProvidable::Singleton(self.singleton.clone()))
}
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>
+ {
+ Box::new(self.clone())
+ }
+}
+
+impl<InjectableType> Clone for AsyncSingletonProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ fn clone(&self) -> Self
+ {
+ Self {
+ singleton: self.singleton.clone(),
+ }
+ }
}
#[cfg(feature = "factory")]
@@ -126,10 +171,26 @@ impl IAsyncProvider for AsyncFactoryProvider
{
async fn provide(
&self,
- _di_container: &AsyncDIContainer,
+ _di_container: &Arc<AsyncDIContainer>,
_dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>
{
Ok(AsyncProvidable::Factory(self.factory.clone()))
}
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>
+ {
+ Box::new(self.clone())
+ }
+}
+
+#[cfg(feature = "factory")]
+impl Clone for AsyncFactoryProvider
+{
+ fn clone(&self) -> Self
+ {
+ Self {
+ factory: self.factory.clone(),
+ }
+ }
}