aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-09-17 18:33:43 +0200
committerHampusM <hampus@hampusmat.com>2022-09-17 18:33:43 +0200
commit7de7f73963a266cceff85d6ab71c3256e5d382ec (patch)
tree67575870945b7ed0a5eeb99ccba79327598b3e02 /src
parent8651f84f205da7a89f2fc7333d1dd8de0d80a22b (diff)
feat!: allow factories to access async DI container
BREAKING CHANGE: The to_factory & to_default_factory methods of AsyncBindingBuilder now expects a function returning a factory function
Diffstat (limited to 'src')
-rw-r--r--src/async_di_container.rs210
-rw-r--r--src/provider/async.rs20
2 files changed, 152 insertions, 78 deletions
diff --git a/src/async_di_container.rs b/src/async_di_container.rs
index 0cd92a5..ef0a540 100644
--- a/src/async_di_container.rs
+++ b/src/async_di_container.rs
@@ -39,7 +39,8 @@
//!
//! di_container
//! .bind::<dyn IDatabaseService>()
-//! .to::<DatabaseService>()?;
+//! .to::<DatabaseService>()
+//! .await?;
//!
//! let database_service = di_container
//! .get::<dyn IDatabaseService>()
@@ -259,16 +260,17 @@ where
/// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
/// the interface.
#[cfg(feature = "factory")]
- pub async fn to_factory<Args, Return>(
+ pub async fn to_factory<Args, Return, FactoryFunc>(
&self,
- factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>
- + Send
- + Sync),
+ factory_func: &'static FactoryFunc,
) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
where
Args: 'static,
Return: 'static + ?Sized,
- Interface: crate::interfaces::factory::IFactory<Args, Return>,
+ Interface: Fn<Args, Output = Return>,
+ FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = Box<(dyn Fn<Args, Output = Return>)>>
+ + Send
+ + Sync,
{
let mut bindings_lock = self.di_container.bindings.lock().await;
@@ -285,6 +287,7 @@ where
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ false,
)),
);
@@ -300,14 +303,15 @@ where
/// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
/// the interface.
#[cfg(feature = "factory")]
- pub async fn to_default_factory<Return>(
+ pub async fn to_default_factory<Return, FactoryFunc>(
&self,
- factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>
- + Send
- + Sync),
+ factory_func: &'static FactoryFunc,
) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
where
Return: 'static + ?Sized,
+ FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = crate::ptr::TransientPtr<Return>>
+ + Send
+ + Sync,
{
let mut bindings_lock = self.di_container.bindings.lock().await;
@@ -324,6 +328,7 @@ where
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ true,
)),
);
@@ -402,10 +407,11 @@ impl AsyncDIContainer
.get_binding_providable::<Interface>(name, dependency_history)
.await?;
- Self::handle_binding_providable(binding_providable)
+ self.handle_binding_providable(binding_providable)
}
fn handle_binding_providable<Interface>(
+ self: &Arc<Self>,
binding_providable: AsyncProvidable,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
@@ -444,37 +450,49 @@ impl AsyncDIContainer
}
#[cfg(feature = "factory")]
AsyncProvidable::Factory(factory_binding) => {
- match factory_binding.clone().cast::<Interface>() {
- Ok(factory) => Ok(SomeThreadsafePtr::ThreadsafeFactory(factory)),
- Err(first_err) => {
- use crate::interfaces::factory::IFactory;
-
- if let CastError::NotArcCastable(_) = first_err {
- return Err(AsyncDIContainerError::InterfaceNotAsync(
+ let factory = factory_binding
+ .cast::<dyn crate::interfaces::factory::IFactory<
+ (Arc<AsyncDIContainer>,),
+ Interface,
+ >>()
+ .map_err(|err| match err {
+ CastError::NotArcCastable(_) => {
+ AsyncDIContainerError::InterfaceNotAsync(
type_name::<Interface>(),
- ));
+ )
+ }
+ CastError::CastFailed { from: _, to: _ } => {
+ AsyncDIContainerError::CastFailed {
+ interface: type_name::<Interface>(),
+ binding_kind: "factory",
+ }
}
+ })?;
- let default_factory = factory_binding
- .cast::<dyn IFactory<(), Interface>>()
- .map_err(|err| match err {
- CastError::NotArcCastable(_) => {
- AsyncDIContainerError::InterfaceNotAsync(type_name::<
- Interface,
- >(
- ))
- }
- CastError::CastFailed { from: _, to: _ } => {
- AsyncDIContainerError::CastFailed {
- interface: type_name::<Interface>(),
- binding_kind: "factory",
- }
- }
- })?;
+ Ok(SomeThreadsafePtr::ThreadsafeFactory(
+ factory(self.clone()).into(),
+ ))
+ }
+ AsyncProvidable::DefaultFactory(default_factory_binding) => {
+ use crate::interfaces::factory::IFactory;
+
+ let default_factory = default_factory_binding
+ .cast::<dyn IFactory<(Arc<AsyncDIContainer>,), Interface>>()
+ .map_err(|err| match err {
+ CastError::NotArcCastable(_) => {
+ AsyncDIContainerError::InterfaceNotAsync(
+ type_name::<Interface>(),
+ )
+ }
+ CastError::CastFailed { from: _, to: _ } => {
+ AsyncDIContainerError::CastFailed {
+ interface: type_name::<Interface>(),
+ binding_kind: "default factory",
+ }
+ }
+ })?;
- Ok(SomeThreadsafePtr::Transient(default_factory()))
- }
- }
+ Ok(SomeThreadsafePtr::Transient(default_factory(self.clone())))
}
}
}
@@ -788,11 +806,13 @@ mod tests
di_container
.bind::<IUserManagerFactory>()
- .to_factory(&|| {
- let user_manager: TransientPtr<dyn subjects::IUserManager> =
- TransientPtr::new(subjects::UserManager::new());
+ .to_factory(&|_| {
+ Box::new(|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
- user_manager
+ user_manager
+ })
})
.await?;
@@ -818,11 +838,13 @@ mod tests
di_container
.bind::<IUserManagerFactory>()
- .to_factory(&|| {
- let user_manager: TransientPtr<dyn subjects::IUserManager> =
- TransientPtr::new(subjects::UserManager::new());
+ .to_factory(&|_| {
+ Box::new(|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
- user_manager
+ user_manager
+ })
})
.await?
.when_named("awesome")
@@ -1111,8 +1133,7 @@ mod tests
use crate as syrette;
#[crate::factory(threadsafe = true)]
- type IUserManagerFactory =
- dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+ type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager>;
mock! {
Provider {}
@@ -1130,21 +1151,37 @@ mod tests
}
}
- let mut di_container = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
- mock_provider.expect_provide().returning(|_, _| {
- Ok(AsyncProvidable::Factory(
- crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
- &|users| {
- let user_manager: TransientPtr<dyn IUserManager> =
- TransientPtr::new(UserManager::new(users));
+ mock_provider.expect_do_clone().returning(|| {
+ type FactoryFunc = Box<
+ (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>)
+ >;
+
+ let mut inner_mock_provider = MockProvider::new();
- user_manager
- },
- )),
- ))
+ let factory_func: &'static (dyn Fn<
+ (Arc<AsyncDIContainer>,),
+ Output = FactoryFunc> + Send + Sync) = &|_| {
+ Box::new(|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ })
+ };
+
+ inner_mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(
+ ThreadsafeCastableFactory::new(factory_func),
+ ),
+ ))
+ });
+
+ Box::new(inner_mock_provider)
});
{
@@ -1206,8 +1243,7 @@ mod tests
use crate as syrette;
#[crate::factory(threadsafe = true)]
- type IUserManagerFactory =
- dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+ type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager>;
mock! {
Provider {}
@@ -1217,32 +1253,54 @@ mod tests
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
}
}
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
- mock_provider.expect_provide().returning(|_, _| {
- Ok(AsyncProvidable::Factory(
- crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
- &|users| {
- let user_manager: TransientPtr<dyn IUserManager> =
- TransientPtr::new(UserManager::new(users));
+ mock_provider.expect_do_clone().returning(|| {
+ type FactoryFunc = Box<
+ (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>)
+ >;
- user_manager
- },
- )),
- ))
+ let mut inner_mock_provider = MockProvider::new();
+
+ let factory_func: &'static (dyn Fn<
+ (Arc<AsyncDIContainer>,),
+ Output = FactoryFunc> + Send + Sync) = &|_| {
+ Box::new(|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ })
+ };
+
+ inner_mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(
+ ThreadsafeCastableFactory::new(factory_func),
+ ),
+ ))
+ });
+
+ Box::new(inner_mock_provider)
});
- di_container
- .bindings
- .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider));
+ {
+ di_container
+ .bindings
+ .lock()
+ .await
+ .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider));
+ }
di_container
.get_named::<IUserManagerFactory>("special")
diff --git a/src/provider/async.rs b/src/provider/async.rs
index 1ddb614..df96b27 100644
--- a/src/provider/async.rs
+++ b/src/provider/async.rs
@@ -20,6 +20,12 @@ pub enum AsyncProvidable
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
),
+ #[cfg(feature = "factory")]
+ DefaultFactory(
+ crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+ ),
}
#[async_trait]
@@ -150,6 +156,7 @@ pub struct AsyncFactoryProvider
factory: crate::ptr::ThreadsafeFactoryPtr<
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
+ is_default_factory: bool,
}
#[cfg(feature = "factory")]
@@ -159,9 +166,13 @@ impl AsyncFactoryProvider
factory: crate::ptr::ThreadsafeFactoryPtr<
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
+ is_default_factory: bool,
) -> Self
{
- Self { factory }
+ Self {
+ factory,
+ is_default_factory,
+ }
}
}
@@ -175,7 +186,11 @@ impl IAsyncProvider for AsyncFactoryProvider
_dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>
{
- Ok(AsyncProvidable::Factory(self.factory.clone()))
+ Ok(if self.is_default_factory {
+ AsyncProvidable::DefaultFactory(self.factory.clone())
+ } else {
+ AsyncProvidable::Factory(self.factory.clone())
+ })
}
fn do_clone(&self) -> Box<dyn IAsyncProvider>
@@ -191,6 +206,7 @@ impl Clone for AsyncFactoryProvider
{
Self {
factory: self.factory.clone(),
+ is_default_factory: self.is_default_factory.clone(),
}
}
}