aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--examples/async-factory/main.rs62
-rw-r--r--macros/src/factory/build_declare_interfaces.rs2
-rw-r--r--macros/src/factory/declare_default_args.rs2
-rw-r--r--macros/src/lib.rs39
-rw-r--r--src/async_di_container.rs139
-rw-r--r--src/castable_factory/mod.rs2
-rw-r--r--src/castable_factory/threadsafe.rs4
-rw-r--r--src/interfaces/factory.rs9
-rw-r--r--src/lib.rs5
-rw-r--r--src/provider/async.rs37
11 files changed, 246 insertions, 57 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 360043a..54630aa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -49,7 +49,7 @@ tokio = { version = "1.20.1", features = ["sync"], optional = true }
mockall = "0.11.1"
anyhow = "1.0.62"
third-party-lib = { path = "./examples/with-3rd-party/third-party-lib" }
-tokio = { version = "1.20.1", features = ["macros", "rt-multi-thread"] }
+tokio = { version = "1.20.1", features = ["macros", "rt-multi-thread", "time"] }
[workspace]
members = [
diff --git a/examples/async-factory/main.rs b/examples/async-factory/main.rs
index 74e12c7..715abf5 100644
--- a/examples/async-factory/main.rs
+++ b/examples/async-factory/main.rs
@@ -2,10 +2,14 @@
#![deny(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
+use std::time::Duration;
+
use anyhow::Result;
-use syrette::{async_closure, factory, AsyncDIContainer};
+use syrette::ptr::TransientPtr;
+use syrette::{async_closure, declare_default_factory, factory, AsyncDIContainer};
+use tokio::time::sleep;
-trait IFoo
+trait IFoo: Send + Sync
{
fn bar(&self);
}
@@ -36,6 +40,34 @@ impl IFoo for Foo
}
}
+trait IPerson: Send + Sync
+{
+ fn name(&self) -> String;
+}
+
+struct Person
+{
+ name: String,
+}
+
+impl Person
+{
+ fn new(name: String) -> Self
+ {
+ Self { name }
+ }
+}
+
+impl IPerson for Person
+{
+ fn name(&self) -> String
+ {
+ self.name.clone()
+ }
+}
+
+declare_default_factory!(dyn IPerson, async = true);
+
#[tokio::main]
async fn main() -> Result<()>
{
@@ -45,9 +77,23 @@ async fn main() -> Result<()>
.bind::<IFooFactory>()
.to_async_factory(&|_| {
async_closure!(|cnt| {
- let foo = Box::new(Foo::new(cnt));
+ let foo_ptr = Box::new(Foo::new(cnt));
- foo as Box<dyn IFoo>
+ foo_ptr as Box<dyn IFoo>
+ })
+ })
+ .await?;
+
+ di_container
+ .bind::<dyn IPerson>()
+ .to_async_default_factory(&|_| {
+ async_closure!(|| {
+ // Do some time demanding thing...
+ sleep(Duration::from_secs(1)).await;
+
+ let person = TransientPtr::new(Person::new("Bob".to_string()));
+
+ person as TransientPtr<dyn IPerson>
})
})
.await?;
@@ -57,9 +103,13 @@ async fn main() -> Result<()>
.await?
.threadsafe_factory()?;
- let foo = foo_factory(4).await;
+ let foo_ptr = foo_factory(4).await;
+
+ foo_ptr.bar();
+
+ let person = di_container.get::<dyn IPerson>().await?.transient()?;
- foo.bar();
+ println!("Person name is {}", person.name());
Ok(())
}
diff --git a/macros/src/factory/build_declare_interfaces.rs b/macros/src/factory/build_declare_interfaces.rs
index ac4ddd6..61e162f 100644
--- a/macros/src/factory/build_declare_interfaces.rs
+++ b/macros/src/factory/build_declare_interfaces.rs
@@ -14,7 +14,7 @@ pub fn build_declare_factory_interfaces(
syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
(std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
#factory_interface
- > -> syrette::interfaces::factory::IFactory<
+ > -> syrette::interfaces::factory::IThreadsafeFactory<
(std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,),
#factory_interface
>,
diff --git a/macros/src/factory/declare_default_args.rs b/macros/src/factory/declare_default_args.rs
index 6450583..d19eba8 100644
--- a/macros/src/factory/declare_default_args.rs
+++ b/macros/src/factory/declare_default_args.rs
@@ -5,7 +5,7 @@ use syn::{Token, Type};
use crate::macro_flag::MacroFlag;
use crate::util::iterator_ext::IteratorExt;
-pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe"];
+pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe", "async"];
pub struct DeclareDefaultFactoryMacroArgs
{
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 07ee7a5..172a113 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -6,16 +6,20 @@
use proc_macro::TokenStream;
use quote::quote;
-use syn::{parse, parse_macro_input, parse_str};
+use syn::{parse, parse_macro_input};
mod declare_interface_args;
-mod factory;
-mod fn_trait;
mod injectable;
mod libs;
mod macro_flag;
mod util;
+#[cfg(feature = "factory")]
+mod factory;
+
+#[cfg(feature = "factory")]
+mod fn_trait;
+
use crate::declare_interface_args::DeclareInterfaceArgs;
use crate::injectable::implementation::InjectableImpl;
use crate::injectable::macro_args::InjectableMacroArgs;
@@ -188,7 +192,7 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> TokenStream
{
use quote::ToTokens;
- use syn::Type;
+ use syn::{parse_str, Type};
use crate::factory::build_declare_interfaces::build_declare_factory_interfaces;
use crate::factory::macro_args::FactoryMacroArgs;
@@ -266,6 +270,7 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke
///
/// # Flags
/// - `threadsafe` - Mark as threadsafe.
+/// - `async` - Mark as async. Infers the `threadsafe` flag.
///
/// # Panics
/// If the provided arguments are invalid.
@@ -285,20 +290,40 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke
#[cfg(feature = "factory")]
pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream
{
+ use syn::parse_str;
+
use crate::factory::build_declare_interfaces::build_declare_factory_interfaces;
use crate::factory::declare_default_args::DeclareDefaultFactoryMacroArgs;
use crate::fn_trait::FnTrait;
let DeclareDefaultFactoryMacroArgs { interface, flags } = parse(args_stream).unwrap();
- let is_threadsafe = flags
+ let mut is_threadsafe = flags
.iter()
.find(|flag| flag.flag.to_string().as_str() == "threadsafe")
.map_or(false, |flag| flag.is_on.value);
+ let is_async = flags
+ .iter()
+ .find(|flag| flag.flag.to_string().as_str() == "async")
+ .map_or(false, |flag| flag.is_on.value);
+
+ if is_async {
+ is_threadsafe = true;
+ }
+
let mut factory_interface: FnTrait = parse(
- quote! {
- dyn Fn() -> syrette::ptr::TransientPtr<#interface>
+ if is_async {
+ quote! {
+ dyn Fn() -> syrette::future::BoxFuture<
+ 'static,
+ syrette::ptr::TransientPtr<#interface>
+ >
+ }
+ } else {
+ quote! {
+ dyn Fn() -> syrette::ptr::TransientPtr<#interface>
+ }
}
.into(),
)
diff --git a/src/async_di_container.rs b/src/async_di_container.rs
index d90cc0b..894b707 100644
--- a/src/async_di_container.rs
+++ b/src/async_di_container.rs
@@ -69,6 +69,7 @@ use crate::errors::async_di_container::{
AsyncBindingWhenConfiguratorError,
AsyncDIContainerError,
};
+use crate::future::BoxFuture;
use crate::interfaces::async_injectable::AsyncInjectable;
use crate::libs::intertrait::cast::error::CastError;
use crate::libs::intertrait::cast::{CastArc, CastBox};
@@ -274,6 +275,8 @@ where
> + Send
+ Sync,
{
+ use crate::provider::r#async::AsyncFactoryVariant;
+
let mut bindings_lock = self.di_container.bindings.lock().await;
if bindings_lock.has::<Interface>(None) {
@@ -289,7 +292,7 @@ where
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
- false,
+ AsyncFactoryVariant::Normal,
)),
);
@@ -313,7 +316,8 @@ where
where
Args: 'static,
Return: 'static + ?Sized,
- Interface: Fn<Args, Output = crate::future::BoxFuture<'static, Return>>,
+ Interface:
+ Fn<Args, Output = crate::future::BoxFuture<'static, Return>> + Send + Sync,
FactoryFunc: Fn<
(Arc<AsyncDIContainer>,),
Output = Box<
@@ -324,6 +328,8 @@ where
> + Send
+ Sync,
{
+ use crate::provider::r#async::AsyncFactoryVariant;
+
let mut bindings_lock = self.di_container.bindings.lock().await;
if bindings_lock.has::<Interface>(None) {
@@ -339,7 +345,7 @@ where
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
- false,
+ AsyncFactoryVariant::Normal,
)),
);
@@ -369,6 +375,58 @@ where
> + Send
+ Sync,
{
+ use crate::provider::r#async::AsyncFactoryVariant;
+
+ let mut bindings_lock = self.di_container.bindings.lock().await;
+
+ if bindings_lock.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
+
+ let factory_impl = ThreadsafeCastableFactory::new(factory_func);
+
+ bindings_lock.set::<Interface>(
+ None,
+ Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
+ crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ AsyncFactoryVariant::Default,
+ )),
+ );
+
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container.clone()))
+ }
+
+ /// Creates a binding of factory type `Interface` to a async factory inside of the
+ /// associated [`AsyncDIContainer`].
+ ///
+ /// *This function is only available if Syrette is built with the "factory" and
+ /// "async" features.*
+ ///
+ /// # Errors
+ /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
+ /// the interface.
+ #[cfg(all(feature = "factory", feature = "async"))]
+ pub async fn to_async_default_factory<Return, FactoryFunc>(
+ &self,
+ factory_func: &'static FactoryFunc,
+ ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
+ where
+ Return: 'static + ?Sized,
+ FactoryFunc: Fn<
+ (Arc<AsyncDIContainer>,),
+ Output = Box<
+ (dyn Fn<(), Output = crate::future::BoxFuture<'static, Return>>
+ + Send
+ + Sync),
+ >,
+ > + Send
+ + Sync,
+ {
+ use crate::provider::r#async::AsyncFactoryVariant;
+
let mut bindings_lock = self.di_container.bindings.lock().await;
if bindings_lock.has::<Interface>(None) {
@@ -384,7 +442,7 @@ where
None,
Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
- true,
+ AsyncFactoryVariant::AsyncDefault,
)),
);
@@ -464,10 +522,10 @@ impl AsyncDIContainer
.get_binding_providable::<Interface>(name, dependency_history)
.await?;
- self.handle_binding_providable(binding_providable)
+ self.handle_binding_providable(binding_providable).await
}
- fn handle_binding_providable<Interface>(
+ async fn handle_binding_providable<Interface>(
self: &Arc<Self>,
binding_providable: AsyncProvidable,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
@@ -507,11 +565,10 @@ impl AsyncDIContainer
}
#[cfg(feature = "factory")]
AsyncProvidable::Factory(factory_binding) => {
+ use crate::interfaces::factory::IThreadsafeFactory;
+
let factory = factory_binding
- .cast::<dyn crate::interfaces::factory::IFactory<
- (Arc<AsyncDIContainer>,),
- Interface,
- >>()
+ .cast::<dyn IThreadsafeFactory<(Arc<AsyncDIContainer>,), Interface>>()
.map_err(|err| match err {
CastError::NotArcCastable(_) => {
AsyncDIContainerError::InterfaceNotAsync(
@@ -531,33 +588,59 @@ impl AsyncDIContainer
))
}
#[cfg(feature = "factory")]
- AsyncProvidable::DefaultFactory(default_factory_binding) => {
- use crate::interfaces::factory::IFactory;
+ AsyncProvidable::DefaultFactory(binding) => {
+ use crate::interfaces::factory::IThreadsafeFactory;
- let default_factory = default_factory_binding
- .cast::<dyn IFactory<
+ let default_factory = Self::cast_factory_binding::<
+ dyn IThreadsafeFactory<
(Arc<AsyncDIContainer>,),
dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync,
- >>()
- .map_err(|err| match err {
- CastError::NotArcCastable(_) => {
- AsyncDIContainerError::InterfaceNotAsync(
- type_name::<Interface>(),
- )
- }
- CastError::CastFailed { from: _, to: _ } => {
- AsyncDIContainerError::CastFailed {
- interface: type_name::<Interface>(),
- binding_kind: "default factory",
- }
- }
- })?;
+ >,
+ >(binding, "default factory")?;
Ok(SomeThreadsafePtr::Transient(default_factory(self.clone())()))
}
+ #[cfg(feature = "factory")]
+ AsyncProvidable::AsyncDefaultFactory(binding) => {
+ use crate::interfaces::factory::IThreadsafeFactory;
+
+ let async_default_factory = Self::cast_factory_binding::<
+ dyn IThreadsafeFactory<
+ (Arc<AsyncDIContainer>,),
+ dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>>
+ + Send
+ + Sync,
+ >,
+ >(
+ binding, "async default factory"
+ )?;
+
+ Ok(SomeThreadsafePtr::Transient(
+ async_default_factory(self.clone())().await,
+ ))
+ }
}
}
+ #[cfg(feature = "factory")]
+ fn cast_factory_binding<Type: 'static + ?Sized>(
+ factory_binding: Arc<dyn crate::interfaces::any_factory::AnyThreadsafeFactory>,
+ binding_kind: &'static str,
+ ) -> Result<Arc<Type>, AsyncDIContainerError>
+ {
+ factory_binding.cast::<Type>().map_err(|err| match err {
+ CastError::NotArcCastable(_) => {
+ AsyncDIContainerError::InterfaceNotAsync(type_name::<Type>())
+ }
+ CastError::CastFailed { from: _, to: _ } => {
+ AsyncDIContainerError::CastFailed {
+ interface: type_name::<Type>(),
+ binding_kind,
+ }
+ }
+ })
+ }
+
async fn get_binding_providable<Interface>(
self: &Arc<Self>,
name: Option<&'static str>,
diff --git a/src/castable_factory/mod.rs b/src/castable_factory/mod.rs
index 530cc82..e81b842 100644
--- a/src/castable_factory/mod.rs
+++ b/src/castable_factory/mod.rs
@@ -1,2 +1,4 @@
pub mod blocking;
+
+#[cfg(feature = "async")]
pub mod threadsafe;
diff --git a/src/castable_factory/threadsafe.rs b/src/castable_factory/threadsafe.rs
index 7be055c..b91dceb 100644
--- a/src/castable_factory/threadsafe.rs
+++ b/src/castable_factory/threadsafe.rs
@@ -1,6 +1,6 @@
#![allow(clippy::module_name_repetitions)]
use crate::interfaces::any_factory::{AnyFactory, AnyThreadsafeFactory};
-use crate::interfaces::factory::IFactory;
+use crate::interfaces::factory::IThreadsafeFactory;
use crate::ptr::TransientPtr;
pub struct ThreadsafeCastableFactory<Args, ReturnInterface>
@@ -26,7 +26,7 @@ where
}
}
-impl<Args, ReturnInterface> IFactory<Args, ReturnInterface>
+impl<Args, ReturnInterface> IThreadsafeFactory<Args, ReturnInterface>
for ThreadsafeCastableFactory<Args, ReturnInterface>
where
Args: 'static,
diff --git a/src/interfaces/factory.rs b/src/interfaces/factory.rs
index b09db36..de1fca9 100644
--- a/src/interfaces/factory.rs
+++ b/src/interfaces/factory.rs
@@ -9,3 +9,12 @@ where
ReturnInterface: 'static + ?Sized,
{
}
+
+/// Interface for a threadsafe factory.
+#[cfg(feature = "async")]
+pub trait IThreadsafeFactory<Args, ReturnInterface>:
+ Fn<Args, Output = TransientPtr<ReturnInterface>> + crate::libs::intertrait::CastFromSync
+where
+ ReturnInterface: 'static + ?Sized,
+{
+}
diff --git a/src/lib.rs b/src/lib.rs
index 247f907..082a93d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -125,4 +125,9 @@ macro_rules! async_closure {
Box::pin(async move { $($inner)* })
})
};
+ (|| { $($inner: stmt);* }) => {
+ Box::new(|| {
+ Box::pin(async move { $($inner)* })
+ })
+ };
}
diff --git a/src/provider/async.rs b/src/provider/async.rs
index df96b27..c9a5273 100644
--- a/src/provider/async.rs
+++ b/src/provider/async.rs
@@ -26,6 +26,12 @@ pub enum AsyncProvidable
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
),
+ #[cfg(feature = "factory")]
+ AsyncDefaultFactory(
+ crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+ ),
}
#[async_trait]
@@ -150,13 +156,21 @@ where
}
}
+#[derive(Clone, Copy)]
+pub enum AsyncFactoryVariant
+{
+ Normal,
+ Default,
+ AsyncDefault,
+}
+
#[cfg(feature = "factory")]
pub struct AsyncFactoryProvider
{
factory: crate::ptr::ThreadsafeFactoryPtr<
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
- is_default_factory: bool,
+ variant: AsyncFactoryVariant,
}
#[cfg(feature = "factory")]
@@ -166,13 +180,10 @@ impl AsyncFactoryProvider
factory: crate::ptr::ThreadsafeFactoryPtr<
dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
>,
- is_default_factory: bool,
+ variant: AsyncFactoryVariant,
) -> Self
{
- Self {
- factory,
- is_default_factory,
- }
+ Self { factory, variant }
}
}
@@ -186,10 +197,14 @@ impl IAsyncProvider for AsyncFactoryProvider
_dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, InjectableError>
{
- Ok(if self.is_default_factory {
- AsyncProvidable::DefaultFactory(self.factory.clone())
- } else {
- AsyncProvidable::Factory(self.factory.clone())
+ Ok(match self.variant {
+ AsyncFactoryVariant::Normal => AsyncProvidable::Factory(self.factory.clone()),
+ AsyncFactoryVariant::Default => {
+ AsyncProvidable::DefaultFactory(self.factory.clone())
+ }
+ AsyncFactoryVariant::AsyncDefault => {
+ AsyncProvidable::AsyncDefaultFactory(self.factory.clone())
+ }
})
}
@@ -206,7 +221,7 @@ impl Clone for AsyncFactoryProvider
{
Self {
factory: self.factory.clone(),
- is_default_factory: self.is_default_factory.clone(),
+ variant: self.variant,
}
}
}