aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-08-29 20:52:56 +0200
committerHampusM <hampus@hampusmat.com>2022-08-29 21:01:32 +0200
commit080cc42bb1da09059dbc35049a7ded0649961e0c (patch)
tree307ee564124373616022c1ba2b4d5af80845cd92 /src
parent6e31d8f9e46fece348f329763b39b9c6f2741c07 (diff)
feat: implement async functionality
Diffstat (limited to 'src')
-rw-r--r--src/async_di_container.rs1110
-rw-r--r--src/castable_factory/blocking.rs (renamed from src/castable_factory.rs)0
-rw-r--r--src/castable_factory/mod.rs2
-rw-r--r--src/castable_factory/threadsafe.rs88
-rw-r--r--src/di_container.rs30
-rw-r--r--src/di_container_binding_map.rs38
-rw-r--r--src/errors/async_di_container.rs79
-rw-r--r--src/errors/injectable.rs14
-rw-r--r--src/errors/mod.rs3
-rw-r--r--src/errors/ptr.rs18
-rw-r--r--src/interfaces/any_factory.rs13
-rw-r--r--src/interfaces/async_injectable.rs35
-rw-r--r--src/interfaces/mod.rs3
-rw-r--r--src/lib.rs38
-rw-r--r--src/libs/intertrait/mod.rs7
-rw-r--r--src/libs/mod.rs2
-rw-r--r--src/provider/async.rs135
-rw-r--r--src/provider/blocking.rs (renamed from src/provider.rs)0
-rw-r--r--src/provider/mod.rs4
-rw-r--r--src/ptr.rs89
20 files changed, 1617 insertions, 91 deletions
diff --git a/src/async_di_container.rs b/src/async_di_container.rs
new file mode 100644
index 0000000..374746f
--- /dev/null
+++ b/src/async_di_container.rs
@@ -0,0 +1,1110 @@
+//! Asynchronous dependency injection container.
+//!
+//! # Examples
+//! ```
+//! use std::collections::HashMap;
+//! use std::error::Error;
+//!
+//! use syrette::{injectable, AsyncDIContainer};
+//!
+//! trait IDatabaseService
+//! {
+//! fn get_all_records(&self, table_name: String) -> HashMap<String, String>;
+//! }
+//!
+//! struct DatabaseService {}
+//!
+//! #[injectable(IDatabaseService, { async = true })]
+//! impl DatabaseService
+//! {
+//! fn new() -> Self
+//! {
+//! Self {}
+//! }
+//! }
+//!
+//! impl IDatabaseService for DatabaseService
+//! {
+//! fn get_all_records(&self, table_name: String) -> HashMap<String, String>
+//! {
+//! // Do stuff here
+//! HashMap::<String, String>::new()
+//! }
+//! }
+//!
+//! #[tokio::main]
+//! async fn main() -> Result<(), Box<dyn Error>>
+//! {
+//! let mut di_container = AsyncDIContainer::new();
+//!
+//! di_container
+//! .bind::<dyn IDatabaseService>()
+//! .to::<DatabaseService>()?;
+//!
+//! let database_service = di_container
+//! .get::<dyn IDatabaseService>()
+//! .await?
+//! .transient()?;
+//!
+//! Ok(())
+//! }
+//! ```
+//!
+//! ---
+//!
+//! *This module is only available if Syrette is built with the "async" feature.*
+use std::any::type_name;
+use std::marker::PhantomData;
+
+#[cfg(feature = "factory")]
+use crate::castable_factory::threadsafe::ThreadsafeCastableFactory;
+use crate::di_container_binding_map::DIContainerBindingMap;
+use crate::errors::async_di_container::{
+ AsyncBindingBuilderError,
+ AsyncBindingScopeConfiguratorError,
+ AsyncBindingWhenConfiguratorError,
+ AsyncDIContainerError,
+};
+use crate::interfaces::async_injectable::AsyncInjectable;
+use crate::libs::intertrait::cast::{CastArc, CastBox};
+use crate::provider::r#async::{
+ AsyncProvidable,
+ AsyncSingletonProvider,
+ AsyncTransientTypeProvider,
+ IAsyncProvider,
+};
+use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr};
+
+/// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
+pub struct AsyncBindingWhenConfigurator<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ di_container: &'di_container mut AsyncDIContainer,
+ interface_phantom: PhantomData<Interface>,
+}
+
+impl<'di_container, Interface> AsyncBindingWhenConfigurator<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ {
+ Self {
+ di_container,
+ interface_phantom: PhantomData,
+ }
+ }
+
+ /// Configures the binding to have a name.
+ ///
+ /// # Errors
+ /// Will return Err if no binding for the interface already exists.
+ pub fn when_named(
+ &mut self,
+ name: &'static str,
+ ) -> Result<(), AsyncBindingWhenConfiguratorError>
+ {
+ let binding = self
+ .di_container
+ .bindings
+ .remove::<Interface>(None)
+ .map_or_else(
+ || {
+ Err(AsyncBindingWhenConfiguratorError::BindingNotFound(
+ type_name::<Interface>(),
+ ))
+ },
+ Ok,
+ )?;
+
+ self.di_container
+ .bindings
+ .set::<Interface>(Some(name), binding);
+
+ Ok(())
+ }
+}
+
+/// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
+pub struct AsyncBindingScopeConfigurator<'di_container, Interface, Implementation>
+where
+ Interface: 'static + ?Sized,
+ Implementation: AsyncInjectable,
+{
+ di_container: &'di_container mut AsyncDIContainer,
+ interface_phantom: PhantomData<Interface>,
+ implementation_phantom: PhantomData<Implementation>,
+}
+
+impl<'di_container, Interface, Implementation>
+ AsyncBindingScopeConfigurator<'di_container, Interface, Implementation>
+where
+ Interface: 'static + ?Sized,
+ Implementation: AsyncInjectable,
+{
+ fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ {
+ Self {
+ di_container,
+ interface_phantom: PhantomData,
+ implementation_phantom: PhantomData,
+ }
+ }
+
+ /// Configures the binding to be in a transient scope.
+ ///
+ /// This is the default.
+ pub fn in_transient_scope(&mut self) -> AsyncBindingWhenConfigurator<Interface>
+ {
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(AsyncTransientTypeProvider::<Implementation>::new()),
+ );
+
+ AsyncBindingWhenConfigurator::new(self.di_container)
+ }
+
+ /// Configures the binding to be in a singleton scope.
+ ///
+ /// # Errors
+ /// Will return Err if resolving the implementation fails.
+ pub async fn in_singleton_scope(
+ &mut self,
+ ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingScopeConfiguratorError>
+ {
+ let singleton: ThreadsafeSingletonPtr<Implementation> =
+ ThreadsafeSingletonPtr::from(
+ Implementation::resolve(self.di_container, Vec::new())
+ .await
+ .map_err(
+ AsyncBindingScopeConfiguratorError::SingletonResolveFailed,
+ )?,
+ );
+
+ self.di_container
+ .bindings
+ .set::<Interface>(None, Box::new(AsyncSingletonProvider::new(singleton)));
+
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ }
+}
+
+/// Binding builder for type `Interface` inside a [`AsyncDIContainer`].
+pub struct AsyncBindingBuilder<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ di_container: &'di_container mut AsyncDIContainer,
+ interface_phantom: PhantomData<Interface>,
+}
+
+impl<'di_container, Interface> AsyncBindingBuilder<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ {
+ Self {
+ di_container,
+ interface_phantom: PhantomData,
+ }
+ }
+
+ /// Creates a binding of type `Interface` to type `Implementation` inside of the
+ /// associated [`AsyncDIContainer`].
+ ///
+ /// The scope of the binding is transient. But that can be changed by using the
+ /// returned [`AsyncBindingScopeConfigurator`]
+ ///
+ /// # Errors
+ /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
+ /// the interface.
+ pub fn to<Implementation>(
+ &mut self,
+ ) -> Result<
+ AsyncBindingScopeConfigurator<Interface, Implementation>,
+ AsyncBindingBuilderError,
+ >
+ where
+ Implementation: AsyncInjectable,
+ {
+ if self.di_container.bindings.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
+
+ let mut binding_scope_configurator =
+ AsyncBindingScopeConfigurator::new(self.di_container);
+
+ binding_scope_configurator.in_transient_scope();
+
+ Ok(binding_scope_configurator)
+ }
+
+ /// Creates a binding of factory type `Interface` to a factory inside of the
+ /// associated [`AsyncDIContainer`].
+ ///
+ /// *This function is only available if Syrette is built with the "factory" feature.*
+ ///
+ /// # Errors
+ /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
+ /// the interface.
+ #[cfg(feature = "factory")]
+ pub fn to_factory<Args, Return>(
+ &mut self,
+ factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>
+ + Send
+ + Sync),
+ ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
+ where
+ Args: 'static,
+ Return: 'static + ?Sized,
+ Interface: crate::interfaces::factory::IFactory<Args, Return>,
+ {
+ if self.di_container.bindings.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
+
+ let factory_impl = ThreadsafeCastableFactory::new(factory_func);
+
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
+ crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ )),
+ );
+
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ }
+
+ /// Creates a binding of type `Interface` to a factory that takes no arguments
+ /// inside of the associated [`AsyncDIContainer`].
+ ///
+ /// *This function is only available if Syrette is built with the "factory" feature.*
+ ///
+ /// # Errors
+ /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
+ /// the interface.
+ #[cfg(feature = "factory")]
+ pub fn to_default_factory<Return>(
+ &mut self,
+ factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>
+ + Send
+ + Sync),
+ ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
+ where
+ Return: 'static + ?Sized,
+ {
+ if self.di_container.bindings.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
+
+ let factory_impl = ThreadsafeCastableFactory::new(factory_func);
+
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
+ crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ )),
+ );
+
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ }
+}
+
+/// Dependency injection container.
+pub struct AsyncDIContainer
+{
+ bindings: DIContainerBindingMap<dyn IAsyncProvider>,
+}
+
+impl AsyncDIContainer
+{
+ /// Returns a new `AsyncDIContainer`.
+ #[must_use]
+ pub fn new() -> Self
+ {
+ Self {
+ bindings: DIContainerBindingMap::new(),
+ }
+ }
+
+ /// Returns a new [`AsyncBindingBuilder`] for the given interface.
+ pub fn bind<Interface>(&mut self) -> AsyncBindingBuilder<Interface>
+ where
+ Interface: 'static + ?Sized,
+ {
+ AsyncBindingBuilder::<Interface>::new(self)
+ }
+
+ /// Returns the type bound with `Interface`.
+ ///
+ /// # Errors
+ /// Will return `Err` if:
+ /// - No binding for `Interface` exists
+ /// - Resolving the binding for fails
+ /// - Casting the binding for fails
+ pub async fn get<Interface>(
+ &self,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.get_bound::<Interface>(Vec::new(), None).await
+ }
+
+ /// Returns the type bound with `Interface` and the specified name.
+ ///
+ /// # Errors
+ /// Will return `Err` if:
+ /// - No binding for `Interface` with name `name` exists
+ /// - Resolving the binding fails
+ /// - Casting the binding for fails
+ pub async fn get_named<Interface>(
+ &self,
+ name: &'static str,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.get_bound::<Interface>(Vec::new(), Some(name)).await
+ }
+
+ #[doc(hidden)]
+ pub async fn get_bound<Interface>(
+ &self,
+ dependency_history: Vec<&'static str>,
+ name: Option<&'static str>,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ let binding_providable = self
+ .get_binding_providable::<Interface>(name, dependency_history)
+ .await?;
+
+ Self::handle_binding_providable(binding_providable)
+ }
+
+ fn handle_binding_providable<Interface>(
+ binding_providable: AsyncProvidable,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ match binding_providable {
+ AsyncProvidable::Transient(transient_binding) => {
+ Ok(SomeThreadsafePtr::Transient(
+ transient_binding.cast::<Interface>().map_err(|_| {
+ AsyncDIContainerError::CastFailed(type_name::<Interface>())
+ })?,
+ ))
+ }
+ AsyncProvidable::Singleton(singleton_binding) => {
+ Ok(SomeThreadsafePtr::ThreadsafeSingleton(
+ singleton_binding.cast::<Interface>().map_err(|_| {
+ AsyncDIContainerError::CastFailed(type_name::<Interface>())
+ })?,
+ ))
+ }
+ #[cfg(feature = "factory")]
+ AsyncProvidable::Factory(factory_binding) => {
+ match factory_binding.clone().cast::<Interface>() {
+ Ok(factory) => Ok(SomeThreadsafePtr::ThreadsafeFactory(factory)),
+ Err(_err) => {
+ use crate::interfaces::factory::IFactory;
+
+ let default_factory =
+ factory_binding
+ .cast::<dyn IFactory<(), Interface>>()
+ .map_err(|_| {
+ AsyncDIContainerError::CastFailed(type_name::<
+ Interface,
+ >(
+ ))
+ })?;
+
+ Ok(SomeThreadsafePtr::Transient(default_factory()))
+ }
+ }
+ }
+ }
+ }
+
+ async fn get_binding_providable<Interface>(
+ &self,
+ name: Option<&'static str>,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.bindings
+ .get::<Interface>(name)
+ .map_or_else(
+ || {
+ Err(AsyncDIContainerError::BindingNotFound {
+ interface: type_name::<Interface>(),
+ name,
+ })
+ },
+ Ok,
+ )?
+ .provide(self, dependency_history)
+ .await
+ .map_err(|err| AsyncDIContainerError::BindingResolveFailed {
+ reason: err,
+ interface: type_name::<Interface>(),
+ })
+ }
+}
+
+impl Default for AsyncDIContainer
+{
+ fn default() -> Self
+ {
+ Self::new()
+ }
+}
+
+#[cfg(test)]
+mod tests
+{
+ use std::error::Error;
+
+ use async_trait::async_trait;
+ use mockall::mock;
+
+ use super::*;
+ use crate::errors::injectable::InjectableError;
+ use crate::ptr::TransientPtr;
+
+ mod subjects
+ {
+ //! Test subjects.
+
+ use std::fmt::Debug;
+
+ use async_trait::async_trait;
+ use syrette_macros::declare_interface;
+
+ use super::AsyncDIContainer;
+ use crate::interfaces::async_injectable::AsyncInjectable;
+ use crate::ptr::TransientPtr;
+
+ pub trait IUserManager
+ {
+ fn add_user(&self, user_id: i128);
+
+ fn remove_user(&self, user_id: i128);
+ }
+
+ pub struct UserManager {}
+
+ impl UserManager
+ {
+ pub fn new() -> Self
+ {
+ Self {}
+ }
+ }
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&self, _user_id: i128)
+ {
+ // ...
+ }
+
+ fn remove_user(&self, _user_id: i128)
+ {
+ // ...
+ }
+ }
+
+ use crate as syrette;
+
+ declare_interface!(UserManager -> IUserManager);
+
+ #[async_trait]
+ impl AsyncInjectable for UserManager
+ {
+ async fn resolve(
+ _: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError>
+ where
+ Self: Sized,
+ {
+ Ok(TransientPtr::new(Self::new()))
+ }
+ }
+
+ pub trait INumber
+ {
+ fn get(&self) -> i32;
+
+ fn set(&mut self, number: i32);
+ }
+
+ impl PartialEq for dyn INumber
+ {
+ fn eq(&self, other: &Self) -> bool
+ {
+ self.get() == other.get()
+ }
+ }
+
+ impl Debug for dyn INumber
+ {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ f.write_str(format!("{}", self.get()).as_str())
+ }
+ }
+
+ pub struct Number
+ {
+ pub num: i32,
+ }
+
+ impl Number
+ {
+ pub fn new() -> Self
+ {
+ Self { num: 0 }
+ }
+ }
+
+ impl INumber for Number
+ {
+ fn get(&self) -> i32
+ {
+ self.num
+ }
+
+ fn set(&mut self, number: i32)
+ {
+ self.num = number;
+ }
+ }
+
+ declare_interface!(Number -> INumber, async = true);
+
+ #[async_trait]
+ impl AsyncInjectable for Number
+ {
+ async fn resolve(
+ _: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError>
+ where
+ Self: Sized,
+ {
+ Ok(TransientPtr::new(Self::new()))
+ }
+ }
+ }
+
+ #[test]
+ fn can_bind_to() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ fn can_bind_to_transient() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_transient_scope();
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_transient_scope()
+ .when_named("regular")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_bind_to_singleton() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_singleton_scope()
+ .await?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_bind_to_singleton_when_named() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_singleton_scope()
+ .await?
+ .when_named("cool")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ #[cfg(feature = "factory")]
+ fn can_bind_to_factory() -> Result<(), Box<dyn Error>>
+ {
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container.bind::<IUserManagerFactory>().to_factory(&|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
+
+ user_manager
+ })?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ #[cfg(feature = "factory")]
+ fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>>
+ {
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<IUserManagerFactory>()
+ .to_factory(&|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
+
+ user_manager
+ })?
+ .when_named("awesome")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Transient(TransientPtr::new(
+ subjects::UserManager::new(),
+ )))
+ });
+
+ di_container
+ .bindings
+ .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider));
+
+ di_container
+ .get::<dyn subjects::IUserManager>()
+ .await?
+ .transient()?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get_named() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Transient(TransientPtr::new(
+ subjects::UserManager::new(),
+ )))
+ });
+
+ di_container
+ .bindings
+ .set::<dyn subjects::IUserManager>(Some("special"), Box::new(mock_provider));
+
+ di_container
+ .get_named::<dyn subjects::IUserManager>("special")
+ .await?
+ .transient()?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get_singleton() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ let mut singleton = ThreadsafeSingletonPtr::new(subjects::Number::new());
+
+ ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820;
+
+ mock_provider
+ .expect_provide()
+ .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone())));
+
+ di_container
+ .bindings
+ .set::<dyn subjects::INumber>(None, Box::new(mock_provider));
+
+ let first_number_rc = di_container
+ .get::<dyn subjects::INumber>()
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.get(), 2820);
+
+ let second_number_rc = di_container
+ .get::<dyn subjects::INumber>()
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref());
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get_singleton_named() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ let mut singleton = ThreadsafeSingletonPtr::new(subjects::Number::new());
+
+ ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820;
+
+ mock_provider
+ .expect_provide()
+ .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone())));
+
+ di_container
+ .bindings
+ .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider));
+
+ let first_number_rc = di_container
+ .get_named::<dyn subjects::INumber>("cool")
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.get(), 2820);
+
+ let second_number_rc = di_container
+ .get_named::<dyn subjects::INumber>("cool")
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref());
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ #[cfg(feature = "factory")]
+ async fn can_get_factory() -> Result<(), Box<dyn Error>>
+ {
+ trait IUserManager
+ {
+ fn add_user(&mut self, user_id: i128);
+
+ fn remove_user(&mut self, user_id: i128);
+ }
+
+ struct UserManager
+ {
+ users: Vec<i128>,
+ }
+
+ impl UserManager
+ {
+ fn new(users: Vec<i128>) -> Self
+ {
+ Self { users }
+ }
+ }
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&mut self, user_id: i128)
+ {
+ self.users.push(user_id);
+ }
+
+ fn remove_user(&mut self, user_id: i128)
+ {
+ let user_index =
+ self.users.iter().position(|user| *user == user_id).unwrap();
+
+ self.users.remove(user_index);
+ }
+ }
+
+ use crate as syrette;
+
+ #[crate::factory(async = true)]
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
+ &|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ },
+ )),
+ ))
+ });
+
+ di_container
+ .bindings
+ .set::<IUserManagerFactory>(None, Box::new(mock_provider));
+
+ di_container
+ .get::<IUserManagerFactory>()
+ .await?
+ .threadsafe_factory()?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ #[cfg(feature = "factory")]
+ async fn can_get_factory_named() -> Result<(), Box<dyn Error>>
+ {
+ trait IUserManager
+ {
+ fn add_user(&mut self, user_id: i128);
+
+ fn remove_user(&mut self, user_id: i128);
+ }
+
+ struct UserManager
+ {
+ users: Vec<i128>,
+ }
+
+ impl UserManager
+ {
+ fn new(users: Vec<i128>) -> Self
+ {
+ Self { users }
+ }
+ }
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&mut self, user_id: i128)
+ {
+ self.users.push(user_id);
+ }
+
+ fn remove_user(&mut self, user_id: i128)
+ {
+ let user_index =
+ self.users.iter().position(|user| *user == user_id).unwrap();
+
+ self.users.remove(user_index);
+ }
+ }
+
+ use crate as syrette;
+
+ #[crate::factory(async = true)]
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
+ &|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ },
+ )),
+ ))
+ });
+
+ di_container
+ .bindings
+ .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider));
+
+ di_container
+ .get_named::<IUserManagerFactory>("special")
+ .await?
+ .threadsafe_factory()?;
+
+ Ok(())
+ }
+}
diff --git a/src/castable_factory.rs b/src/castable_factory/blocking.rs
index 5ff4db0..5ff4db0 100644
--- a/src/castable_factory.rs
+++ b/src/castable_factory/blocking.rs
diff --git a/src/castable_factory/mod.rs b/src/castable_factory/mod.rs
new file mode 100644
index 0000000..530cc82
--- /dev/null
+++ b/src/castable_factory/mod.rs
@@ -0,0 +1,2 @@
+pub mod blocking;
+pub mod threadsafe;
diff --git a/src/castable_factory/threadsafe.rs b/src/castable_factory/threadsafe.rs
new file mode 100644
index 0000000..7be055c
--- /dev/null
+++ b/src/castable_factory/threadsafe.rs
@@ -0,0 +1,88 @@
+#![allow(clippy::module_name_repetitions)]
+use crate::interfaces::any_factory::{AnyFactory, AnyThreadsafeFactory};
+use crate::interfaces::factory::IFactory;
+use crate::ptr::TransientPtr;
+
+pub struct ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + Send + Sync),
+}
+
+impl<Args, ReturnInterface> ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ pub fn new(
+ func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>>
+ + Send
+ + Sync),
+ ) -> Self
+ {
+ Self { func }
+ }
+}
+
+impl<Args, ReturnInterface> IFactory<Args, ReturnInterface>
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+}
+
+impl<Args, ReturnInterface> Fn<Args> for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ extern "rust-call" fn call(&self, args: Args) -> Self::Output
+ {
+ self.func.call(args)
+ }
+}
+
+impl<Args, ReturnInterface> FnMut<Args>
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output
+ {
+ self.call(args)
+ }
+}
+
+impl<Args, ReturnInterface> FnOnce<Args>
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ type Output = TransientPtr<ReturnInterface>;
+
+ extern "rust-call" fn call_once(self, args: Args) -> Self::Output
+ {
+ self.call(args)
+ }
+}
+
+impl<Args, ReturnInterface> AnyFactory
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+}
+
+impl<Args, ReturnInterface> AnyThreadsafeFactory
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+}
diff --git a/src/di_container.rs b/src/di_container.rs
index e42175b..b0e5af1 100644
--- a/src/di_container.rs
+++ b/src/di_container.rs
@@ -1,4 +1,4 @@
-//! Dependency injection container and other related utilities.
+//! Dependency injection container.
//!
//! # Examples
//! ```
@@ -53,7 +53,7 @@ use std::any::type_name;
use std::marker::PhantomData;
#[cfg(feature = "factory")]
-use crate::castable_factory::CastableFactory;
+use crate::castable_factory::blocking::CastableFactory;
use crate::di_container_binding_map::DIContainerBindingMap;
use crate::errors::di_container::{
BindingBuilderError,
@@ -63,7 +63,12 @@ use crate::errors::di_container::{
};
use crate::interfaces::injectable::Injectable;
use crate::libs::intertrait::cast::{CastBox, CastRc};
-use crate::provider::{Providable, SingletonProvider, TransientTypeProvider};
+use crate::provider::blocking::{
+ IProvider,
+ Providable,
+ SingletonProvider,
+ TransientTypeProvider,
+};
use crate::ptr::{SingletonPtr, SomePtr};
/// When configurator for a binding for type 'Interface' inside a [`DIContainer`].
@@ -256,7 +261,7 @@ where
self.di_container.bindings.set::<Interface>(
None,
- Box::new(crate::provider::FactoryProvider::new(
+ Box::new(crate::provider::blocking::FactoryProvider::new(
crate::ptr::FactoryPtr::new(factory_impl),
)),
);
@@ -290,7 +295,7 @@ where
self.di_container.bindings.set::<Interface>(
None,
- Box::new(crate::provider::FactoryProvider::new(
+ Box::new(crate::provider::blocking::FactoryProvider::new(
crate::ptr::FactoryPtr::new(factory_impl),
)),
);
@@ -302,7 +307,7 @@ where
/// Dependency injection container.
pub struct DIContainer
{
- bindings: DIContainerBindingMap,
+ bindings: DIContainerBindingMap<dyn IProvider>,
}
impl DIContainer
@@ -416,7 +421,16 @@ impl DIContainer
Interface: 'static + ?Sized,
{
self.bindings
- .get::<Interface>(name)?
+ .get::<Interface>(name)
+ .map_or_else(
+ || {
+ Err(DIContainerError::BindingNotFound {
+ interface: type_name::<Interface>(),
+ name,
+ })
+ },
+ Ok,
+ )?
.provide(self, dependency_history)
.map_err(|err| DIContainerError::BindingResolveFailed {
reason: err,
@@ -442,7 +456,7 @@ mod tests
use super::*;
use crate::errors::injectable::InjectableError;
- use crate::provider::IProvider;
+ use crate::provider::blocking::IProvider;
use crate::ptr::TransientPtr;
mod subjects
diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs
index 4df889d..4aa246e 100644
--- a/src/di_container_binding_map.rs
+++ b/src/di_container_binding_map.rs
@@ -1,10 +1,7 @@
-use std::any::{type_name, TypeId};
+use std::any::TypeId;
use ahash::AHashMap;
-use crate::errors::di_container::DIContainerError;
-use crate::provider::IProvider;
-
#[derive(Debug, PartialEq, Eq, Hash)]
struct DIContainerBindingKey
{
@@ -12,12 +9,16 @@ struct DIContainerBindingKey
name: Option<&'static str>,
}
-pub struct DIContainerBindingMap
+pub struct DIContainerBindingMap<Provider>
+where
+ Provider: 'static + ?Sized,
{
- bindings: AHashMap<DIContainerBindingKey, Box<dyn IProvider>>,
+ bindings: AHashMap<DIContainerBindingKey, Box<Provider>>,
}
-impl DIContainerBindingMap
+impl<Provider> DIContainerBindingMap<Provider>
+where
+ Provider: 'static + ?Sized,
{
pub fn new() -> Self
{
@@ -26,33 +27,22 @@ impl DIContainerBindingMap
}
}
- pub fn get<Interface>(
- &self,
- name: Option<&'static str>,
- ) -> Result<&dyn IProvider, DIContainerError>
+ pub fn get<Interface>(&self, name: Option<&'static str>) -> Option<&Provider>
where
Interface: 'static + ?Sized,
{
let interface_typeid = TypeId::of::<Interface>();
- Ok(self
- .bindings
+ self.bindings
.get(&DIContainerBindingKey {
type_id: interface_typeid,
name,
})
- .ok_or_else(|| DIContainerError::BindingNotFound {
- interface: type_name::<Interface>(),
- name,
- })?
- .as_ref())
+ .map(|provider| provider.as_ref())
}
- pub fn set<Interface>(
- &mut self,
- name: Option<&'static str>,
- provider: Box<dyn IProvider>,
- ) where
+ pub fn set<Interface>(&mut self, name: Option<&'static str>, provider: Box<Provider>)
+ where
Interface: 'static + ?Sized,
{
let interface_typeid = TypeId::of::<Interface>();
@@ -69,7 +59,7 @@ impl DIContainerBindingMap
pub fn remove<Interface>(
&mut self,
name: Option<&'static str>,
- ) -> Option<Box<dyn IProvider>>
+ ) -> Option<Box<Provider>>
where
Interface: 'static + ?Sized,
{
diff --git a/src/errors/async_di_container.rs b/src/errors/async_di_container.rs
new file mode 100644
index 0000000..4f5e50a
--- /dev/null
+++ b/src/errors/async_di_container.rs
@@ -0,0 +1,79 @@
+//! Error types for [`AsyncDIContainer`] and it's related structs.
+//!
+//! ---
+//!
+//! *This module is only available if Syrette is built with the "async" feature.*
+//!
+//! [`AsyncDIContainer`]: crate::async_di_container::AsyncDIContainer
+
+use crate::errors::injectable::InjectableError;
+
+/// Error type for [`AsyncDIContainer`].
+///
+/// [`AsyncDIContainer`]: crate::async_di_container::AsyncDIContainer
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncDIContainerError
+{
+ /// Unable to cast a binding for a interface.
+ #[error("Unable to cast binding for interface '{0}'")]
+ CastFailed(&'static str),
+
+ /// Failed to resolve a binding for a interface.
+ #[error("Failed to resolve binding for interface '{interface}'")]
+ BindingResolveFailed
+ {
+ /// The reason for the problem.
+ #[source]
+ reason: InjectableError,
+
+ /// The affected bound interface.
+ interface: &'static str,
+ },
+
+ /// No binding exists for a interface (and optionally a name).
+ #[error(
+ "No binding exists for interface '{interface}' {}",
+ .name.map_or_else(String::new, |name| format!("with name '{}'", name))
+ )]
+ BindingNotFound
+ {
+ /// The interface that doesn't have a binding.
+ interface: &'static str,
+
+ /// The name of the binding if one exists.
+ name: Option<&'static str>,
+ },
+}
+
+/// Error type for [`AsyncBindingBuilder`].
+///
+/// [`AsyncBindingBuilder`]: crate::async_di_container::AsyncBindingBuilder
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncBindingBuilderError
+{
+ /// A binding already exists for a interface.
+ #[error("Binding already exists for interface '{0}'")]
+ BindingAlreadyExists(&'static str),
+}
+
+/// Error type for [`AsyncBindingScopeConfigurator`].
+///
+/// [`AsyncBindingScopeConfigurator`]: crate::async_di_container::AsyncBindingScopeConfigurator
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncBindingScopeConfiguratorError
+{
+ /// Resolving a singleton failed.
+ #[error("Resolving the given singleton failed")]
+ SingletonResolveFailed(#[from] InjectableError),
+}
+
+/// Error type for [`AsyncBindingWhenConfigurator`].
+///
+/// [`AsyncBindingWhenConfigurator`]: crate::async_di_container::AsyncBindingWhenConfigurator
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncBindingWhenConfiguratorError
+{
+ /// A binding for a interface wasn't found.
+ #[error("A binding for interface '{0}' wasn't found'")]
+ BindingNotFound(&'static str),
+}
diff --git a/src/errors/injectable.rs b/src/errors/injectable.rs
index 4b9af96..ed161cb 100644
--- a/src/errors/injectable.rs
+++ b/src/errors/injectable.rs
@@ -3,7 +3,7 @@
//!
//! [`Injectable`]: crate::interfaces::injectable::Injectable
-use super::di_container::DIContainerError;
+use crate::errors::di_container::DIContainerError;
/// Error type for structs that implement [`Injectable`].
///
@@ -23,6 +23,18 @@ pub enum InjectableError
affected: &'static str,
},
+ /// Failed to resolve dependencies.
+ #[cfg(feature = "async")]
+ #[error("Failed to resolve a dependency of '{affected}'")]
+ AsyncResolveFailed
+ {
+ /// The reason for the problem.
+ #[source]
+ reason: Box<crate::errors::async_di_container::AsyncDIContainerError>,
+
+ /// The affected injectable type.
+ affected: &'static str,
+ },
/// Detected circular dependencies.
#[error("Detected circular dependencies. {dependency_trace}")]
DetectedCircular
diff --git a/src/errors/mod.rs b/src/errors/mod.rs
index 7d66ddf..c3930b0 100644
--- a/src/errors/mod.rs
+++ b/src/errors/mod.rs
@@ -3,3 +3,6 @@
pub mod di_container;
pub mod injectable;
pub mod ptr;
+
+#[cfg(feature = "async")]
+pub mod async_di_container;
diff --git a/src/errors/ptr.rs b/src/errors/ptr.rs
index e0c3d05..56621c1 100644
--- a/src/errors/ptr.rs
+++ b/src/errors/ptr.rs
@@ -17,3 +17,21 @@ pub enum SomePtrError
found: &'static str,
},
}
+
+/// Error type for [`SomeThreadsafePtr`].
+///
+/// [`SomeThreadsafePtr`]: crate::ptr::SomeThreadsafePtr
+#[derive(thiserror::Error, Debug)]
+pub enum SomeThreadsafePtrError
+{
+ /// Tried to get as a wrong threadsafe smart pointer type.
+ #[error("Wrong threadsafe smart pointer type. Expected {expected}, found {found}")]
+ WrongPtrType
+ {
+ /// The expected smart pointer type.
+ expected: &'static str,
+
+ /// The found smart pointer type.
+ found: &'static str,
+ },
+}
diff --git a/src/interfaces/any_factory.rs b/src/interfaces/any_factory.rs
index 887bb61..1bf9208 100644
--- a/src/interfaces/any_factory.rs
+++ b/src/interfaces/any_factory.rs
@@ -2,7 +2,7 @@
use std::fmt::Debug;
-use crate::libs::intertrait::CastFrom;
+use crate::libs::intertrait::{CastFrom, CastFromSync};
/// Interface for any factory to ever exist.
pub trait AnyFactory: CastFrom {}
@@ -14,3 +14,14 @@ impl Debug for dyn AnyFactory
f.write_str("{}")
}
}
+
+/// Interface for any threadsafe factory to ever exist.
+pub trait AnyThreadsafeFactory: CastFromSync {}
+
+impl Debug for dyn AnyThreadsafeFactory
+{
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ f.write_str("{}")
+ }
+}
diff --git a/src/interfaces/async_injectable.rs b/src/interfaces/async_injectable.rs
new file mode 100644
index 0000000..badc3c5
--- /dev/null
+++ b/src/interfaces/async_injectable.rs
@@ -0,0 +1,35 @@
+//! Interface for structs that can be injected into or be injected to.
+//!
+//! *This module is only available if Syrette is built with the "async" feature.*
+use std::fmt::Debug;
+
+use async_trait::async_trait;
+
+use crate::async_di_container::AsyncDIContainer;
+use crate::errors::injectable::InjectableError;
+use crate::libs::intertrait::CastFromSync;
+use crate::ptr::TransientPtr;
+
+/// Interface for structs that can be injected into or be injected to.
+#[async_trait]
+pub trait AsyncInjectable: CastFromSync
+{
+ /// Resolves the dependencies of the injectable.
+ ///
+ /// # Errors
+ /// Will return `Err` if resolving the dependencies fails.
+ async fn resolve(
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<TransientPtr<Self>, InjectableError>
+ where
+ Self: Sized;
+}
+
+impl Debug for dyn AsyncInjectable
+{
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ f.write_str("{}")
+ }
+}
diff --git a/src/interfaces/mod.rs b/src/interfaces/mod.rs
index 73dde04..ddb3bba 100644
--- a/src/interfaces/mod.rs
+++ b/src/interfaces/mod.rs
@@ -8,3 +8,6 @@ pub mod any_factory;
#[cfg(feature = "factory")]
pub mod factory;
+
+#[cfg(feature = "async")]
+pub mod async_injectable;
diff --git a/src/lib.rs b/src/lib.rs
index 8908143..9fdfa0f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,6 +12,11 @@ pub mod errors;
pub mod interfaces;
pub mod ptr;
+#[cfg(feature = "async")]
+pub mod async_di_container;
+
+#[cfg(feature = "async")]
+pub use async_di_container::AsyncDIContainer;
pub use di_container::DIContainer;
pub use syrette_macros::*;
@@ -75,9 +80,8 @@ macro_rules! di_container_bind {
///
/// A default factory is a factory that doesn't take any arguments.
///
-/// More tedious ways to accomplish what this macro does would either be by using
-/// the [`factory`] macro or by manually declaring the interfaces
-/// with the [`declare_interface`] macro.
+/// The more tedious way to accomplish what this macro does would be by using
+/// the [`factory`] macro.
///
/// *This macro is only available if Syrette is built with the "factory" feature.*
///
@@ -95,43 +99,19 @@ macro_rules! di_container_bind {
///
/// declare_default_factory!(dyn IParser);
/// ```
-///
-/// The expanded equivelent of this would be
-///
-/// ```
-/// # use syrette::declare_default_factory;
-/// #
-/// trait IParser {
-/// // Methods and etc here...
-/// }
-///
-/// syrette::declare_interface!(
-/// syrette::castable_factory::CastableFactory<
-/// (),
-/// dyn IParser,
-/// > -> syrette::interfaces::factory::IFactory<(), dyn IParser>
-/// );
-///
-/// syrette::declare_interface!(
-/// syrette::castable_factory::CastableFactory<
-/// (),
-/// dyn IParser,
-/// > -> syrette::interfaces::any_factory::AnyFactory
-/// );
-/// ```
#[macro_export]
#[cfg(feature = "factory")]
macro_rules! declare_default_factory {
($interface: ty) => {
syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
+ syrette::castable_factory::blocking::CastableFactory<
(),
$interface,
> -> syrette::interfaces::factory::IFactory<(), $interface>
);
syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
+ syrette::castable_factory::blocking::CastableFactory<
(),
$interface,
> -> syrette::interfaces::any_factory::AnyFactory
diff --git a/src/libs/intertrait/mod.rs b/src/libs/intertrait/mod.rs
index 2d62871..bdae4c7 100644
--- a/src/libs/intertrait/mod.rs
+++ b/src/libs/intertrait/mod.rs
@@ -23,7 +23,7 @@
//! MIT license (LICENSE-MIT or <http://opensource.org/licenses/MIT>)
//!
//! at your option.
-use std::any::{Any, TypeId};
+use std::any::{type_name, Any, TypeId};
use std::rc::Rc;
use std::sync::Arc;
@@ -60,7 +60,10 @@ static CASTER_MAP: Lazy<AHashMap<(TypeId, TypeId), BoxedCaster>> = Lazy::new(||
fn cast_arc_panic<Trait: ?Sized + 'static>(_: Arc<dyn Any + Sync + Send>) -> Arc<Trait>
{
- panic!("Prepend [sync] to the list of target traits for Sync + Send types")
+ panic!(
+ "Interface trait '{}' has not been marked async",
+ type_name::<Trait>()
+ )
}
/// A `Caster` knows how to cast a reference to or `Box` of a trait object for `Any`
diff --git a/src/libs/mod.rs b/src/libs/mod.rs
index 8d5583d..b1c7a74 100644
--- a/src/libs/mod.rs
+++ b/src/libs/mod.rs
@@ -1,3 +1,5 @@
pub mod intertrait;
+#[cfg(feature = "async")]
+pub extern crate async_trait;
pub extern crate linkme;
diff --git a/src/provider/async.rs b/src/provider/async.rs
new file mode 100644
index 0000000..93ae03a
--- /dev/null
+++ b/src/provider/async.rs
@@ -0,0 +1,135 @@
+#![allow(clippy::module_name_repetitions)]
+use std::marker::PhantomData;
+
+use async_trait::async_trait;
+
+use crate::async_di_container::AsyncDIContainer;
+use crate::errors::injectable::InjectableError;
+use crate::interfaces::async_injectable::AsyncInjectable;
+use crate::ptr::{ThreadsafeSingletonPtr, TransientPtr};
+
+#[derive(strum_macros::Display, Debug)]
+pub enum AsyncProvidable
+{
+ Transient(TransientPtr<dyn AsyncInjectable>),
+ Singleton(ThreadsafeSingletonPtr<dyn AsyncInjectable>),
+ #[cfg(feature = "factory")]
+ Factory(
+ crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+ ),
+}
+
+#[async_trait]
+pub trait IAsyncProvider: Send + Sync
+{
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+}
+
+pub struct AsyncTransientTypeProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ injectable_phantom: PhantomData<InjectableType>,
+}
+
+impl<InjectableType> AsyncTransientTypeProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ pub fn new() -> Self
+ {
+ Self {
+ injectable_phantom: PhantomData,
+ }
+ }
+}
+
+#[async_trait]
+impl<InjectableType> IAsyncProvider for AsyncTransientTypeProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>
+ {
+ Ok(AsyncProvidable::Transient(
+ InjectableType::resolve(di_container, dependency_history).await?,
+ ))
+ }
+}
+
+pub struct AsyncSingletonProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ singleton: ThreadsafeSingletonPtr<InjectableType>,
+}
+
+impl<InjectableType> AsyncSingletonProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ pub fn new(singleton: ThreadsafeSingletonPtr<InjectableType>) -> Self
+ {
+ Self { singleton }
+ }
+}
+
+#[async_trait]
+impl<InjectableType> IAsyncProvider for AsyncSingletonProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ async fn provide(
+ &self,
+ _di_container: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>
+ {
+ Ok(AsyncProvidable::Singleton(self.singleton.clone()))
+ }
+}
+
+#[cfg(feature = "factory")]
+pub struct AsyncFactoryProvider
+{
+ factory: crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+}
+
+#[cfg(feature = "factory")]
+impl AsyncFactoryProvider
+{
+ pub fn new(
+ factory: crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+ ) -> Self
+ {
+ Self { factory }
+ }
+}
+
+#[cfg(feature = "factory")]
+#[async_trait]
+impl IAsyncProvider for AsyncFactoryProvider
+{
+ async fn provide(
+ &self,
+ _di_container: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>
+ {
+ Ok(AsyncProvidable::Factory(self.factory.clone()))
+ }
+}
diff --git a/src/provider.rs b/src/provider/blocking.rs
index 13674b9..13674b9 100644
--- a/src/provider.rs
+++ b/src/provider/blocking.rs
diff --git a/src/provider/mod.rs b/src/provider/mod.rs
new file mode 100644
index 0000000..7fb96bb
--- /dev/null
+++ b/src/provider/mod.rs
@@ -0,0 +1,4 @@
+pub mod blocking;
+
+#[cfg(feature = "async")]
+pub mod r#async;
diff --git a/src/ptr.rs b/src/ptr.rs
index 44fc15c..33f8a95 100644
--- a/src/ptr.rs
+++ b/src/ptr.rs
@@ -2,10 +2,11 @@
//! Smart pointer type aliases.
use std::rc::Rc;
+use std::sync::Arc;
use paste::paste;
-use crate::errors::ptr::SomePtrError;
+use crate::errors::ptr::{SomePtrError, SomeThreadsafePtrError};
/// A smart pointer for a interface in the transient scope.
pub type TransientPtr<Interface> = Box<Interface>;
@@ -13,44 +14,34 @@ pub type TransientPtr<Interface> = Box<Interface>;
/// A smart pointer to a interface in the singleton scope.
pub type SingletonPtr<Interface> = Rc<Interface>;
+/// A threadsafe smart pointer to a interface in the singleton scope.
+pub type ThreadsafeSingletonPtr<Interface> = Arc<Interface>;
+
/// A smart pointer to a factory.
#[cfg(feature = "factory")]
pub type FactoryPtr<FactoryInterface> = Rc<FactoryInterface>;
-/// Some smart pointer.
-#[derive(strum_macros::IntoStaticStr)]
-pub enum SomePtr<Interface>
-where
- Interface: 'static + ?Sized,
-{
- /// A smart pointer to a interface in the transient scope.
- Transient(TransientPtr<Interface>),
-
- /// A smart pointer to a interface in the singleton scope.
- Singleton(SingletonPtr<Interface>),
-
- /// A smart pointer to a factory.
- #[cfg(feature = "factory")]
- Factory(FactoryPtr<Interface>),
-}
+/// A threadsafe smart pointer to a factory.
+#[cfg(feature = "factory")]
+pub type ThreadsafeFactoryPtr<FactoryInterface> = Arc<FactoryInterface>;
macro_rules! create_as_variant_fn {
- ($variant: ident) => {
+ ($enum: ident, $variant: ident) => {
paste! {
#[doc =
- "Returns as " [<$variant:lower>] ".\n"
+ "Returns as the `" [<$variant>] "` variant.\n"
"\n"
"# Errors\n"
- "Will return Err if it's not a " [<$variant:lower>] "."
+ "Will return Err if it's not the `" [<$variant>] "` variant."
]
- pub fn [<$variant:lower>](self) -> Result<[<$variant Ptr>]<Interface>, SomePtrError>
+ pub fn [<$variant:snake>](self) -> Result<[<$variant Ptr>]<Interface>, [<$enum Error>]>
{
- if let SomePtr::$variant(ptr) = self {
+ if let $enum::$variant(ptr) = self {
return Ok(ptr);
}
- Err(SomePtrError::WrongPtrType {
+ Err([<$enum Error>]::WrongPtrType {
expected: stringify!($variant),
found: self.into()
})
@@ -59,14 +50,60 @@ macro_rules! create_as_variant_fn {
};
}
+/// Some smart pointer.
+#[derive(strum_macros::IntoStaticStr)]
+pub enum SomePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ /// A smart pointer to a interface in the transient scope.
+ Transient(TransientPtr<Interface>),
+
+ /// A smart pointer to a interface in the singleton scope.
+ Singleton(SingletonPtr<Interface>),
+
+ /// A smart pointer to a factory.
+ #[cfg(feature = "factory")]
+ Factory(FactoryPtr<Interface>),
+}
+
impl<Interface> SomePtr<Interface>
where
Interface: 'static + ?Sized,
{
- create_as_variant_fn!(Transient);
+ create_as_variant_fn!(SomePtr, Transient);
+
+ create_as_variant_fn!(SomePtr, Singleton);
+
+ #[cfg(feature = "factory")]
+ create_as_variant_fn!(SomePtr, Factory);
+}
+
+/// Some threadsafe smart pointer.
+#[derive(strum_macros::IntoStaticStr)]
+pub enum SomeThreadsafePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ /// A smart pointer to a interface in the transient scope.
+ Transient(TransientPtr<Interface>),
+
+ /// A smart pointer to a interface in the singleton scope.
+ ThreadsafeSingleton(ThreadsafeSingletonPtr<Interface>),
+
+ /// A smart pointer to a factory.
+ #[cfg(feature = "factory")]
+ ThreadsafeFactory(ThreadsafeFactoryPtr<Interface>),
+}
+
+impl<Interface> SomeThreadsafePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ create_as_variant_fn!(SomeThreadsafePtr, Transient);
- create_as_variant_fn!(Singleton);
+ create_as_variant_fn!(SomeThreadsafePtr, ThreadsafeSingleton);
#[cfg(feature = "factory")]
- create_as_variant_fn!(Factory);
+ create_as_variant_fn!(SomeThreadsafePtr, ThreadsafeFactory);
}