aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-09-17 18:33:43 +0200
committerHampusM <hampus@hampusmat.com>2022-09-17 18:33:43 +0200
commit7de7f73963a266cceff85d6ab71c3256e5d382ec (patch)
tree67575870945b7ed0a5eeb99ccba79327598b3e02
parent8651f84f205da7a89f2fc7333d1dd8de0d80a22b (diff)
feat!: allow factories to access async DI container
BREAKING CHANGE: The to_factory & to_default_factory methods of AsyncBindingBuilder now expects a function returning a factory function
-rw-r--r--Cargo.toml2
-rw-r--r--examples/async/bootstrap.rs30
-rw-r--r--examples/async/interfaces/mod.rs1
-rw-r--r--examples/async/main.rs12
-rw-r--r--macros/src/lib.rs32
-rw-r--r--src/async_di_container.rs210
-rw-r--r--src/provider/async.rs20
7 files changed, 211 insertions, 96 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 6b5c37d..b9e406a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,7 +27,7 @@ required-features = ["factory"]
[[example]]
name = "async"
-required-features = ["async"]
+required-features = ["async", "factory"]
[dependencies]
syrette_macros = { path = "./macros", version = "0.3.0" }
diff --git a/examples/async/bootstrap.rs b/examples/async/bootstrap.rs
index 7e1d2cd..51af067 100644
--- a/examples/async/bootstrap.rs
+++ b/examples/async/bootstrap.rs
@@ -2,17 +2,20 @@ use std::sync::Arc;
use anyhow::Result;
use syrette::async_di_container::AsyncDIContainer;
+use syrette::declare_default_factory;
+use syrette::ptr::TransientPtr;
-// Concrete implementations
use crate::animals::cat::Cat;
use crate::animals::dog::Dog;
use crate::animals::human::Human;
-//
-// Interfaces
+use crate::food::Food;
use crate::interfaces::cat::ICat;
use crate::interfaces::dog::IDog;
+use crate::interfaces::food::{IFood, IFoodFactory};
use crate::interfaces::human::IHuman;
+declare_default_factory!(dyn ICat, threadsafe = true);
+
pub async fn bootstrap() -> Result<Arc<AsyncDIContainer>>
{
let mut di_container = AsyncDIContainer::new();
@@ -24,8 +27,27 @@ pub async fn bootstrap() -> Result<Arc<AsyncDIContainer>>
.in_singleton_scope()
.await?;
- di_container.bind::<dyn ICat>().to::<Cat>().await?;
+ di_container
+ .bind::<dyn ICat>()
+ .to_default_factory(&|_| {
+ let cat: TransientPtr<dyn ICat> = TransientPtr::new(Cat::new());
+
+ cat
+ })
+ .await?;
+
di_container.bind::<dyn IHuman>().to::<Human>().await?;
+ di_container
+ .bind::<IFoodFactory>()
+ .to_factory(&|_| {
+ Box::new(|| {
+ let food: Box<dyn IFood> = Box::new(Food::new());
+
+ food
+ })
+ })
+ .await?;
+
Ok(di_container)
}
diff --git a/examples/async/interfaces/mod.rs b/examples/async/interfaces/mod.rs
index 5444978..ea0a26d 100644
--- a/examples/async/interfaces/mod.rs
+++ b/examples/async/interfaces/mod.rs
@@ -1,3 +1,4 @@
pub mod cat;
pub mod dog;
+pub mod food;
pub mod human;
diff --git a/examples/async/main.rs b/examples/async/main.rs
index 3c884fe..03e36e1 100644
--- a/examples/async/main.rs
+++ b/examples/async/main.rs
@@ -7,12 +7,15 @@ use tokio::spawn;
mod animals;
mod bootstrap;
+mod food;
mod interfaces;
use bootstrap::bootstrap;
use interfaces::dog::IDog;
use interfaces::human::IHuman;
+use crate::interfaces::food::IFoodFactory;
+
#[tokio::main]
async fn main() -> Result<()>
{
@@ -29,6 +32,15 @@ async fn main() -> Result<()>
dog.woof();
}
+ let food_factory = di_container
+ .get::<IFoodFactory>()
+ .await?
+ .threadsafe_factory()?;
+
+ let food = food_factory();
+
+ food.eat();
+
spawn(async move {
let human = di_container.get::<dyn IHuman>().await?.transient()?;
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 4c815db..79b1a1b 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -194,7 +194,7 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke
let FactoryMacroArgs { flags } = parse(args_stream).unwrap();
- let is_async = flags
+ let is_threadsafe = flags
.iter()
.find(|flag| flag.flag.to_string().as_str() == "threadsafe")
.map_or(false, |flag| flag.is_on.value);
@@ -202,24 +202,27 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke
let factory_type_alias::FactoryTypeAlias {
type_alias,
factory_interface,
- arg_types,
- return_type,
+ arg_types: _,
+ return_type: _,
} = parse(type_alias_stream).unwrap();
- let decl_interfaces = if is_async {
+ let decl_interfaces = if is_threadsafe {
quote! {
syrette::declare_interface!(
syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
- #arg_types,
- #return_type
- > -> #factory_interface,
+ (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
+ #factory_interface
+ > -> syrette::interfaces::factory::IFactory<
+ (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
+ #factory_interface
+ >,
async = true
);
syrette::declare_interface!(
syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
- #arg_types,
- #return_type
+ (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
+ #factory_interface
> -> syrette::interfaces::any_factory::AnyThreadsafeFactory,
async = true
);
@@ -300,17 +303,20 @@ pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream
return quote! {
syrette::declare_interface!(
syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
- (),
+ (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
#interface,
- > -> syrette::interfaces::factory::IFactory<(), #interface>,
+ > -> syrette::interfaces::factory::IFactory<
+ (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
+ #interface
+ >,
async = true
);
syrette::declare_interface!(
syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
- (),
+ (std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
#interface,
- > -> syrette::interfaces::any_factory::AnyFactory,
+ > -> syrette::interfaces::any_factory::AnyThreadsafeFactory,
async = true
);
}
diff --git a/src/async_di_container.rs b/src/async_di_container.rs
index 0cd92a5..ef0a540 100644
--- a/src/async_di_container.rs
+++ b/src/async_di_container.rs
@@ -39,7 +39,8 @@
//!
//! di_container
//! .bind::<dyn IDatabaseService>()
-//! .to::<DatabaseService>()?;
+//! .to::<DatabaseService>()
+//! .await?;
//!
//! let database_service = di_container
//! .get::<dyn IDatabaseService>()
@@ -259,16 +260,17 @@ where
/// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
/// the interface.
#[cfg(feature = "factory")]
- pub async fn to_factory<Args, Return>(
+ pub async fn to_factory<Args, Return, FactoryFunc>(
&self,
- factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>
- + Send
- + Sync),
+ factory_func: &'static FactoryFunc,
) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
where
Args: 'static,
Return: 'static + ?Sized,
- Interface: crate::interfaces::factory::IFactory<Args, Return>,
+ Interface: Fn<Args, Output = Return>,
+ FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = Box<(dyn Fn<Args, Output = Return>)>>
+ + Send
+ + Sync,
{
let mut bindings_lock = self.di_container.bindings.lock().await;
@@ -285,6 +287,7 @@ where
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ false,
)),
);
@@ -300,14 +303,15 @@ where
/// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
/// the interface.
#[cfg(feature = "factory")]
- pub async fn to_default_factory<Return>(
+ pub async fn to_default_factory<Return, FactoryFunc>(
&self,
- factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>
- + Send
- + Sync),
+ factory_func: &'static FactoryFunc,
) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
where
Return: 'static + ?Sized,
+ FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = crate::ptr::TransientPtr<Return>>
+ + Send
+ + Sync,
{
let mut bindings_lock = self.di_container.bindings.lock().await;
@@ -324,6 +328,7 @@ where
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ true,
)),
);
@@ -402,10 +407,11 @@ impl AsyncDIContainer
.get_binding_providable::<Interface>(name, dependency_history)
.await?;
- Self::handle_binding_providable(binding_providable)
+ self.handle_binding_providable(binding_providable)
}
fn handle_binding_providable<Interface>(
+ self: &Arc<Self>,
binding_providable: AsyncProvidable,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
@@ -444,37 +450,49 @@ impl AsyncDIContainer
}
#[cfg(feature = "factory")]
AsyncProvidable::Factory(factory_binding) => {
- match factory_binding.clone().cast::<Interface>() {
- Ok(factory) => Ok(SomeThreadsafePtr::ThreadsafeFactory(factory)),
- Err(first_err) => {
- use crate::interfaces::factory::IFactory;
-
- if let CastError::NotArcCastable(_) = first_err {
- return Err(AsyncDIContainerError::InterfaceNotAsync(
+ let factory = factory_binding
+ .cast::<dyn crate::interfaces::factory::IFactory<
+ (Arc<AsyncDIContainer>,),
+ Interface,
+ >>()
+ .map_err(|err| match err {
+ CastError::NotArcCastable(_) => {
+ AsyncDIContainerError::InterfaceNotAsync(
type_name::<Interface>(),
- ));
+ )
+ }
+ CastError::CastFailed { from: _, to: _ } => {
+ AsyncDIContainerError::CastFailed {
+ interface: type_name::<Interface>(),
+ binding_kind: "factory",
+ }
}
+ })?;
- let default_factory = factory_binding
- .cast::<dyn IFactory<(), Interface>>()
- .map_err(|err| match err {
- CastError::NotArcCastable(_) => {
- AsyncDIContainerError::InterfaceNotAsync(type_name::<
- Interface,
- >(
- ))
- }
- CastError::CastFailed { from: _, to: _ } => {
- AsyncDIContainerError::CastFailed {
- interface: type_name::<Interface>(),
- binding_kind: "factory",
- }
- }
- })?;
+ Ok(SomeThreadsafePtr::ThreadsafeFactory(
+ factory(self.clone()).into(),
+ ))
+ }
+ AsyncProvidable::DefaultFactory(default_factory_binding) => {
+ use crate::interfaces::factory::IFactory;
+
+ let default_factory = default_factory_binding
+ .cast::<dyn IFactory<(Arc<AsyncDIContainer>,), Interface>>()
+ .map_err(|err| match err {
+ CastError::NotArcCastable(_) => {
+ AsyncDIContainerError::InterfaceNotAsync(
+ type_name::<Interface>(),
+ )
+ }
+ CastError::CastFailed { from: _, to: _ } => {
+ AsyncDIContainerError::CastFailed {
+ interface: type_name::<Interface>(),
+ binding_kind: "default factory",
+ }
+ }
+ })?;
- Ok(SomeThreadsafePtr::Transient(default_factory()))
- }
- }
+ Ok(SomeThreadsafePtr::Transient(default_factory(self.clone())))
}
}
}
@@ -788,11 +806,13 @@ mod tests
di_container
.bind::<IUserManagerFactory>()
- .to_factory(&|| {
- let user_manager: TransientPtr<dyn subjects::IUserManager> =
- TransientPtr::new(subjects::UserManager::new());
+ .to_factory(&|_| {
+ Box::new(|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
- user_manager
+ user_manager
+ })
})
.await?;
@@ -818,11 +838,13 @@ mod tests
di_container
.bind::<IUserManagerFactory>()
- .to_factory(&|| {
- let user_manager: TransientPtr<dyn subjects::IUserManager> =
- TransientPtr::new(subjects::UserManager::new());
+ .to_factory(&|_| {
+ Box::new(|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
- user_manager
+ user_manager
+ })
})
.await?
.when_named("awesome")
@@ -1111,8 +1133,7 @@ mod tests
use crate as syrette;
#[crate::factory(threadsafe = true)]
- type IUserManagerFactory =
- dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+ type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager>;
mock! {
Provider {}
@@ -1130,21 +1151,37 @@ mod tests
}
}
- let mut di_container = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
- mock_provider.expect_provide().returning(|_, _| {
- Ok(AsyncProvidable::Factory(
- crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
- &|users| {
- let user_manager: TransientPtr<dyn IUserManager> =
- TransientPtr::new(UserManager::new(users));
+ mock_provider.expect_do_clone().returning(|| {
+ type FactoryFunc = Box<
+ (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>)
+ >;
+
+ let mut inner_mock_provider = MockProvider::new();
- user_manager
- },
- )),
- ))
+ let factory_func: &'static (dyn Fn<
+ (Arc<AsyncDIContainer>,),
+ Output = FactoryFunc> + Send + Sync) = &|_| {
+ Box::new(|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ })
+ };
+
+ inner_mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(
+ ThreadsafeCastableFactory::new(factory_func),
+ ),
+ ))
+ });
+
+ Box::new(inner_mock_provider)
});
{
@@ -1206,8 +1243,7 @@ mod tests
use crate as syrette;
#[crate::factory(threadsafe = true)]
- type IUserManagerFactory =
- dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+ type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager>;
mock! {
Provider {}
@@ -1217,32 +1253,54 @@ mod tests
{
async fn provide(
&self,
- di_container: &AsyncDIContainer,
+ di_container: &Arc<AsyncDIContainer>,
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>;
+
+ fn do_clone(&self) -> Box<dyn IAsyncProvider>;
}
}
- let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+ let di_container = AsyncDIContainer::new();
let mut mock_provider = MockProvider::new();
- mock_provider.expect_provide().returning(|_, _| {
- Ok(AsyncProvidable::Factory(
- crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
- &|users| {
- let user_manager: TransientPtr<dyn IUserManager> =
- TransientPtr::new(UserManager::new(users));
+ mock_provider.expect_do_clone().returning(|| {
+ type FactoryFunc = Box<
+ (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>)
+ >;
- user_manager
- },
- )),
- ))
+ let mut inner_mock_provider = MockProvider::new();
+
+ let factory_func: &'static (dyn Fn<
+ (Arc<AsyncDIContainer>,),
+ Output = FactoryFunc> + Send + Sync) = &|_| {
+ Box::new(|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ })
+ };
+
+ inner_mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(
+ ThreadsafeCastableFactory::new(factory_func),
+ ),
+ ))
+ });
+
+ Box::new(inner_mock_provider)
});
- di_container
- .bindings
- .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider));
+ {
+ di_container
+ .bindings
+ .lock()
+ .await
+ .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider));
+ }
di_container
.get_named::<IUserManagerFactory>("special")
diff --git a/src/provider/async.rs b/src/provider/async.rs
index 1ddb614..df96b27 100644
--- a/src/provider/async.rs
+++ b/src/provider/async.rs
@@ -20,6 +20,12 @@ pub enum AsyncProvidable
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
),
+ #[cfg(feature = "factory")]
+ DefaultFactory(
+ crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+ ),
}
#[async_trait]
@@ -150,6 +156,7 @@ pub struct AsyncFactoryProvider
factory: crate::ptr::ThreadsafeFactoryPtr<
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
+ is_default_factory: bool,
}
#[cfg(feature = "factory")]
@@ -159,9 +166,13 @@ impl AsyncFactoryProvider
factory: crate::ptr::ThreadsafeFactoryPtr<
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
+ is_default_factory: bool,
) -> Self
{
- Self { factory }
+ Self {
+ factory,
+ is_default_factory,
+ }
}
}
@@ -175,7 +186,11 @@ impl IAsyncProvider for AsyncFactoryProvider
_dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>
{
- Ok(AsyncProvidable::Factory(self.factory.clone()))
+ Ok(if self.is_default_factory {
+ AsyncProvidable::DefaultFactory(self.factory.clone())
+ } else {
+ AsyncProvidable::Factory(self.factory.clone())
+ })
}
fn do_clone(&self) -> Box<dyn IAsyncProvider>
@@ -191,6 +206,7 @@ impl Clone for AsyncFactoryProvider
{
Self {
factory: self.factory.clone(),
+ is_default_factory: self.is_default_factory.clone(),
}
}
}