aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml7
-rw-r--r--examples/async/animals/cat.rs22
-rw-r--r--examples/async/animals/dog.rs22
-rw-r--r--examples/async/animals/human.rs36
-rw-r--r--examples/async/animals/mod.rs3
-rw-r--r--examples/async/bootstrap.rs28
-rw-r--r--examples/async/interfaces/cat.rs4
-rw-r--r--examples/async/interfaces/dog.rs4
-rw-r--r--examples/async/interfaces/human.rs4
-rw-r--r--examples/async/interfaces/mod.rs3
-rw-r--r--examples/async/main.rs52
-rw-r--r--macros/Cargo.toml2
-rw-r--r--macros/src/declare_interface_args.rs43
-rw-r--r--macros/src/factory_macro_args.rs44
-rw-r--r--macros/src/injectable_impl.rs102
-rw-r--r--macros/src/injectable_macro_args.rs55
-rw-r--r--macros/src/lib.rs108
-rw-r--r--macros/src/libs/intertrait_macros/gen_caster.rs26
-rw-r--r--macros/src/macro_flag.rs27
-rw-r--r--macros/src/util/mod.rs1
-rw-r--r--macros/src/util/string.rs12
-rw-r--r--src/async_di_container.rs1110
-rw-r--r--src/castable_factory/blocking.rs (renamed from src/castable_factory.rs)0
-rw-r--r--src/castable_factory/mod.rs2
-rw-r--r--src/castable_factory/threadsafe.rs88
-rw-r--r--src/di_container.rs30
-rw-r--r--src/di_container_binding_map.rs38
-rw-r--r--src/errors/async_di_container.rs79
-rw-r--r--src/errors/injectable.rs14
-rw-r--r--src/errors/mod.rs3
-rw-r--r--src/errors/ptr.rs18
-rw-r--r--src/interfaces/any_factory.rs13
-rw-r--r--src/interfaces/async_injectable.rs35
-rw-r--r--src/interfaces/mod.rs3
-rw-r--r--src/lib.rs38
-rw-r--r--src/libs/intertrait/mod.rs7
-rw-r--r--src/libs/mod.rs2
-rw-r--r--src/provider/async.rs135
-rw-r--r--src/provider/blocking.rs (renamed from src/provider.rs)0
-rw-r--r--src/provider/mod.rs4
-rw-r--r--src/ptr.rs89
41 files changed, 2132 insertions, 181 deletions
diff --git a/Cargo.toml b/Cargo.toml
index b3aa027..e9ccb15 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,7 @@ all-features = true
default = ["prevent-circular"]
factory = ["syrette_macros/factory"]
prevent-circular = ["syrette_macros/prevent-circular"]
+async = ["async-trait"]
[[example]]
name = "factory"
@@ -24,6 +25,10 @@ required-features = ["factory"]
name = "with-3rd-party"
required-features = ["factory"]
+[[example]]
+name = "async"
+required-features = ["async"]
+
[dependencies]
syrette_macros = { path = "./macros", version = "0.3.0" }
linkme = "0.3.0"
@@ -33,11 +38,13 @@ thiserror = "1.0.32"
strum = "0.24.1"
strum_macros = "0.24.3"
paste = "1.0.8"
+async-trait = { version = "0.1.57", optional = true }
[dev_dependencies]
mockall = "0.11.1"
anyhow = "1.0.62"
third-party-lib = { path = "./examples/with-3rd-party/third-party-lib" }
+tokio = { version = "1.20.1", features = ["full"] }
[workspace]
members = [
diff --git a/examples/async/animals/cat.rs b/examples/async/animals/cat.rs
new file mode 100644
index 0000000..b1e6f27
--- /dev/null
+++ b/examples/async/animals/cat.rs
@@ -0,0 +1,22 @@
+use syrette::injectable;
+
+use crate::interfaces::cat::ICat;
+
+pub struct Cat {}
+
+#[injectable(ICat, { async = true })]
+impl Cat
+{
+ pub fn new() -> Self
+ {
+ Self {}
+ }
+}
+
+impl ICat for Cat
+{
+ fn meow(&self)
+ {
+ println!("Meow!");
+ }
+}
diff --git a/examples/async/animals/dog.rs b/examples/async/animals/dog.rs
new file mode 100644
index 0000000..d1b33f9
--- /dev/null
+++ b/examples/async/animals/dog.rs
@@ -0,0 +1,22 @@
+use syrette::injectable;
+
+use crate::interfaces::dog::IDog;
+
+pub struct Dog {}
+
+#[injectable(IDog, { async = true })]
+impl Dog
+{
+ pub fn new() -> Self
+ {
+ Self {}
+ }
+}
+
+impl IDog for Dog
+{
+ fn woof(&self)
+ {
+ println!("Woof!");
+ }
+}
diff --git a/examples/async/animals/human.rs b/examples/async/animals/human.rs
new file mode 100644
index 0000000..140f27c
--- /dev/null
+++ b/examples/async/animals/human.rs
@@ -0,0 +1,36 @@
+use syrette::injectable;
+use syrette::ptr::{ThreadsafeSingletonPtr, TransientPtr};
+
+use crate::interfaces::cat::ICat;
+use crate::interfaces::dog::IDog;
+use crate::interfaces::human::IHuman;
+
+pub struct Human
+{
+ dog: ThreadsafeSingletonPtr<dyn IDog>,
+ cat: TransientPtr<dyn ICat>,
+}
+
+#[injectable(IHuman, { async = true })]
+impl Human
+{
+ pub fn new(dog: ThreadsafeSingletonPtr<dyn IDog>, cat: TransientPtr<dyn ICat>)
+ -> Self
+ {
+ Self { dog, cat }
+ }
+}
+
+impl IHuman for Human
+{
+ fn make_pets_make_sounds(&self)
+ {
+ println!("Hi doggy!");
+
+ self.dog.woof();
+
+ println!("Hi kitty!");
+
+ self.cat.meow();
+ }
+}
diff --git a/examples/async/animals/mod.rs b/examples/async/animals/mod.rs
new file mode 100644
index 0000000..5444978
--- /dev/null
+++ b/examples/async/animals/mod.rs
@@ -0,0 +1,3 @@
+pub mod cat;
+pub mod dog;
+pub mod human;
diff --git a/examples/async/bootstrap.rs b/examples/async/bootstrap.rs
new file mode 100644
index 0000000..b640712
--- /dev/null
+++ b/examples/async/bootstrap.rs
@@ -0,0 +1,28 @@
+use anyhow::Result;
+use syrette::async_di_container::AsyncDIContainer;
+
+// Concrete implementations
+use crate::animals::cat::Cat;
+use crate::animals::dog::Dog;
+use crate::animals::human::Human;
+//
+// Interfaces
+use crate::interfaces::cat::ICat;
+use crate::interfaces::dog::IDog;
+use crate::interfaces::human::IHuman;
+
+pub async fn bootstrap() -> Result<AsyncDIContainer>
+{
+ let mut di_container = AsyncDIContainer::new();
+
+ di_container
+ .bind::<dyn IDog>()
+ .to::<Dog>()?
+ .in_singleton_scope()
+ .await?;
+
+ di_container.bind::<dyn ICat>().to::<Cat>()?;
+ di_container.bind::<dyn IHuman>().to::<Human>()?;
+
+ Ok(di_container)
+}
diff --git a/examples/async/interfaces/cat.rs b/examples/async/interfaces/cat.rs
new file mode 100644
index 0000000..478f7e0
--- /dev/null
+++ b/examples/async/interfaces/cat.rs
@@ -0,0 +1,4 @@
+pub trait ICat: Send + Sync
+{
+ fn meow(&self);
+}
diff --git a/examples/async/interfaces/dog.rs b/examples/async/interfaces/dog.rs
new file mode 100644
index 0000000..a6ed111
--- /dev/null
+++ b/examples/async/interfaces/dog.rs
@@ -0,0 +1,4 @@
+pub trait IDog: Send + Sync
+{
+ fn woof(&self);
+}
diff --git a/examples/async/interfaces/human.rs b/examples/async/interfaces/human.rs
new file mode 100644
index 0000000..18f9d63
--- /dev/null
+++ b/examples/async/interfaces/human.rs
@@ -0,0 +1,4 @@
+pub trait IHuman: Send + Sync
+{
+ fn make_pets_make_sounds(&self);
+}
diff --git a/examples/async/interfaces/mod.rs b/examples/async/interfaces/mod.rs
new file mode 100644
index 0000000..5444978
--- /dev/null
+++ b/examples/async/interfaces/mod.rs
@@ -0,0 +1,3 @@
+pub mod cat;
+pub mod dog;
+pub mod human;
diff --git a/examples/async/main.rs b/examples/async/main.rs
new file mode 100644
index 0000000..f72ff39
--- /dev/null
+++ b/examples/async/main.rs
@@ -0,0 +1,52 @@
+#![deny(clippy::all)]
+#![deny(clippy::pedantic)]
+#![allow(clippy::module_name_repetitions)]
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use tokio::spawn;
+use tokio::sync::Mutex;
+
+mod animals;
+mod bootstrap;
+mod interfaces;
+
+use bootstrap::bootstrap;
+use interfaces::dog::IDog;
+use interfaces::human::IHuman;
+
+#[tokio::main]
+async fn main() -> Result<()>
+{
+ println!("Hello, world!");
+
+ let di_container = Arc::new(Mutex::new(bootstrap().await?));
+
+ {
+ let dog = di_container
+ .lock()
+ .await
+ .get::<dyn IDog>()
+ .await?
+ .threadsafe_singleton()?;
+
+ dog.woof();
+ }
+
+ spawn(async move {
+ let human = di_container
+ .lock()
+ .await
+ .get::<dyn IHuman>()
+ .await?
+ .transient()?;
+
+ human.make_pets_make_sounds();
+
+ Ok::<_, anyhow::Error>(())
+ })
+ .await??;
+
+ Ok(())
+}
diff --git a/macros/Cargo.toml b/macros/Cargo.toml
index a929b08..28cb4c0 100644
--- a/macros/Cargo.toml
+++ b/macros/Cargo.toml
@@ -23,6 +23,8 @@ syn = { version = "1.0.96", features = ["full"] }
quote = "1.0.18"
proc-macro2 = "1.0.40"
uuid = { version = "0.8", features = ["v4"] }
+regex = "1.6.0"
+once_cell = "1.13.1"
[dev_dependencies]
syrette = { version = "0.3.0", path = "..", features = ["factory"] }
diff --git a/macros/src/declare_interface_args.rs b/macros/src/declare_interface_args.rs
index b54f458..bd2f24e 100644
--- a/macros/src/declare_interface_args.rs
+++ b/macros/src/declare_interface_args.rs
@@ -1,10 +1,17 @@
use syn::parse::{Parse, ParseStream, Result};
+use syn::punctuated::Punctuated;
use syn::{Path, Token, Type};
+use crate::macro_flag::MacroFlag;
+use crate::util::iterator_ext::IteratorExt;
+
+pub const DECLARE_INTERFACE_FLAGS: &[&str] = &["async"];
+
pub struct DeclareInterfaceArgs
{
pub implementation: Type,
pub interface: Path,
+ pub flags: Punctuated<MacroFlag, Token![,]>,
}
impl Parse for DeclareInterfaceArgs
@@ -15,9 +22,43 @@ impl Parse for DeclareInterfaceArgs
input.parse::<Token![->]>()?;
+ let interface: Path = input.parse()?;
+
+ let flags = if input.peek(Token![,]) {
+ input.parse::<Token![,]>()?;
+
+ let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;
+
+ for flag in &flags {
+ let flag_str = flag.flag.to_string();
+
+ if !DECLARE_INTERFACE_FLAGS.contains(&flag_str.as_str()) {
+ return Err(input.error(format!(
+ "Unknown flag '{}'. Expected one of [ {} ]",
+ flag_str,
+ DECLARE_INTERFACE_FLAGS.join(",")
+ )));
+ }
+ }
+
+ let flag_names = flags
+ .iter()
+ .map(|flag| flag.flag.to_string())
+ .collect::<Vec<_>>();
+
+ if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() {
+ return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name)));
+ }
+
+ flags
+ } else {
+ Punctuated::new()
+ };
+
Ok(Self {
implementation,
- interface: input.parse()?,
+ interface,
+ flags,
})
}
}
diff --git a/macros/src/factory_macro_args.rs b/macros/src/factory_macro_args.rs
new file mode 100644
index 0000000..57517d6
--- /dev/null
+++ b/macros/src/factory_macro_args.rs
@@ -0,0 +1,44 @@
+use syn::parse::Parse;
+use syn::punctuated::Punctuated;
+use syn::Token;
+
+use crate::macro_flag::MacroFlag;
+use crate::util::iterator_ext::IteratorExt;
+
+pub const FACTORY_MACRO_FLAGS: &[&str] = &["async"];
+
+pub struct FactoryMacroArgs
+{
+ pub flags: Punctuated<MacroFlag, Token![,]>,
+}
+
+impl Parse for FactoryMacroArgs
+{
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self>
+ {
+ let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;
+
+ for flag in &flags {
+ let flag_str = flag.flag.to_string();
+
+ if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) {
+ return Err(input.error(format!(
+ "Unknown flag '{}'. Expected one of [ {} ]",
+ flag_str,
+ FACTORY_MACRO_FLAGS.join(",")
+ )));
+ }
+ }
+
+ let flag_names = flags
+ .iter()
+ .map(|flag| flag.flag.to_string())
+ .collect::<Vec<_>>();
+
+ if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() {
+ return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name)));
+ }
+
+ Ok(Self { flags })
+ }
+}
diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs
index 990b148..3565ef9 100644
--- a/macros/src/injectable_impl.rs
+++ b/macros/src/injectable_impl.rs
@@ -6,6 +6,7 @@ use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type};
use crate::dependency::Dependency;
use crate::util::item_impl::find_impl_method_by_name_mut;
+use crate::util::string::camelcase_to_snakecase;
use crate::util::syn_path::syn_path_to_string;
const DI_CONTAINER_VAR_NAME: &str = "di_container";
@@ -39,7 +40,8 @@ impl Parse for InjectableImpl
impl InjectableImpl
{
- pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream
+ pub fn expand(&self, no_doc_hidden: bool, is_async: bool)
+ -> proc_macro2::TokenStream
{
let Self {
dependencies,
@@ -51,8 +53,6 @@ impl InjectableImpl
let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME);
let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME);
- let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies);
-
let maybe_doc_hidden = if no_doc_hidden {
quote! {}
} else {
@@ -81,36 +81,78 @@ impl InjectableImpl
quote! {}
};
- quote! {
- #original_impl
+ let injectable_impl = if is_async {
+ let async_get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, true);
+
+ quote! {
+ #maybe_doc_hidden
+ #[syrette::libs::async_trait::async_trait]
+ impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type
+ {
+ async fn resolve(
+ #di_container_var: &syrette::async_di_container::AsyncDIContainer,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
+
+ use syrette::errors::injectable::InjectableError;
+
+ let self_type_name = type_name::<#self_type>();
+
+ #maybe_prevent_circular_deps
+
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#async_get_dep_method_calls),*
+ )));
+ }
+ }
- #maybe_doc_hidden
- impl #generics syrette::interfaces::injectable::Injectable for #self_type {
- fn resolve(
- #di_container_var: &syrette::DIContainer,
- mut #dependency_history_var: Vec<&'static str>,
- ) -> Result<
- syrette::ptr::TransientPtr<Self>,
- syrette::errors::injectable::InjectableError>
+ }
+ } else {
+ let get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, false);
+
+ quote! {
+ #maybe_doc_hidden
+ impl #generics syrette::interfaces::injectable::Injectable for #self_type
{
- use std::any::type_name;
+ fn resolve(
+ #di_container_var: &syrette::DIContainer,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
- use syrette::errors::injectable::InjectableError;
+ use syrette::errors::injectable::InjectableError;
- let self_type_name = type_name::<#self_type>();
+ let self_type_name = type_name::<#self_type>();
- #maybe_prevent_circular_deps
+ #maybe_prevent_circular_deps
- return Ok(syrette::ptr::TransientPtr::new(Self::new(
- #(#get_dep_method_calls),*
- )));
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#get_dep_method_calls),*
+ )));
+ }
}
}
+ };
+
+ quote! {
+ #original_impl
+
+ #injectable_impl
}
}
fn create_get_dep_method_calls(
dependencies: &[Dependency],
+ is_async: bool,
) -> Vec<proc_macro2::TokenStream>
{
dependencies
@@ -146,11 +188,25 @@ impl InjectableImpl
.map(|(method_call, dep_type)| {
let ptr_name = dep_type.ptr.to_string();
- let to_ptr =
- format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase());
+ let to_ptr = format_ident!(
+ "{}",
+ camelcase_to_snakecase(&ptr_name.replace("Ptr", ""))
+ );
+
+ let do_method_call = if is_async {
+ quote! { #method_call.await }
+ } else {
+ quote! { #method_call }
+ };
+
+ let resolve_failed_error = if is_async {
+ quote! { InjectableError::AsyncResolveFailed }
+ } else {
+ quote! { InjectableError::ResolveFailed }
+ };
quote! {
- #method_call.map_err(|err| InjectableError::ResolveFailed {
+ #do_method_call.map_err(|err| #resolve_failed_error {
reason: Box::new(err),
affected: self_type_name
})?.#to_ptr().unwrap()
diff --git a/macros/src/injectable_macro_args.rs b/macros/src/injectable_macro_args.rs
index 43f8e11..6cc1d7e 100644
--- a/macros/src/injectable_macro_args.rs
+++ b/macros/src/injectable_macro_args.rs
@@ -1,49 +1,16 @@
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
-use syn::{braced, Ident, LitBool, Token, TypePath};
+use syn::{braced, Token, TypePath};
+use crate::macro_flag::MacroFlag;
use crate::util::iterator_ext::IteratorExt;
-pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden"];
-
-pub struct InjectableMacroFlag
-{
- pub flag: Ident,
- pub is_on: LitBool,
-}
-
-impl Parse for InjectableMacroFlag
-{
- fn parse(input: ParseStream) -> syn::Result<Self>
- {
- let input_forked = input.fork();
-
- let flag: Ident = input_forked.parse()?;
-
- let flag_str = flag.to_string();
-
- if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) {
- return Err(input.error(format!(
- "Unknown flag '{}'. Expected one of [ {} ]",
- flag_str,
- INJECTABLE_MACRO_FLAGS.join(",")
- )));
- }
-
- input.parse::<Ident>()?;
-
- input.parse::<Token![=]>()?;
-
- let is_on: LitBool = input.parse()?;
-
- Ok(Self { flag, is_on })
- }
-}
+pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"];
pub struct InjectableMacroArgs
{
pub interface: Option<TypePath>,
- pub flags: Punctuated<InjectableMacroFlag, Token![,]>,
+ pub flags: Punctuated<MacroFlag, Token![,]>,
}
impl Parse for InjectableMacroArgs
@@ -76,7 +43,19 @@ impl Parse for InjectableMacroArgs
braced!(braced_content in input);
- let flags = braced_content.parse_terminated(InjectableMacroFlag::parse)?;
+ let flags = braced_content.parse_terminated(MacroFlag::parse)?;
+
+ for flag in &flags {
+ let flag_str = flag.flag.to_string();
+
+ if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) {
+ return Err(input.error(format!(
+ "Unknown flag '{}'. Expected one of [ {} ]",
+ flag_str,
+ INJECTABLE_MACRO_FLAGS.join(",")
+ )));
+ }
+ }
let flag_names = flags
.iter()
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index eb3a2be..40fbb53 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -2,7 +2,7 @@
#![deny(clippy::pedantic)]
#![deny(missing_docs)]
-//! Macros for the [Syrette](https://crates.io/crates/syrette) crate.
+//! Macros for the [Sy&rette](https://crates.io/crates/syrette) crate.
use proc_macro::TokenStream;
use quote::quote;
@@ -10,10 +10,12 @@ use syn::{parse, parse_macro_input};
mod declare_interface_args;
mod dependency;
+mod factory_macro_args;
mod factory_type_alias;
mod injectable_impl;
mod injectable_macro_args;
mod libs;
+mod macro_flag;
mod named_attr_input;
mod util;
@@ -31,6 +33,7 @@ use libs::intertrait_macros::gen_caster::generate_caster;
/// # Flags
/// - `no_doc_hidden` - Don't hide the impl of the [`Injectable`] trait from
/// documentation.
+/// - `async` - Mark as async.
///
/// # Panics
/// If the attributed item is not a impl.
@@ -107,21 +110,31 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
{
let InjectableMacroArgs { interface, flags } = parse_macro_input!(args_stream);
- let mut flags_iter = flags.iter();
-
- let no_doc_hidden = flags_iter
+ let no_doc_hidden = flags
+ .iter()
.find(|flag| flag.flag.to_string().as_str() == "no_doc_hidden")
.map_or(false, |flag| flag.is_on.value);
+ let is_async = flags
+ .iter()
+ .find(|flag| flag.flag.to_string().as_str() == "async")
+ .map_or(false, |flag| flag.is_on.value);
+
let injectable_impl: InjectableImpl = parse(impl_stream).unwrap();
- let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden);
+ let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden, is_async);
let maybe_decl_interface = if interface.is_some() {
let self_type = &injectable_impl.self_type;
- quote! {
- syrette::declare_interface!(#self_type -> #interface);
+ if is_async {
+ quote! {
+ syrette::declare_interface!(#self_type -> #interface, async = true);
+ }
+ } else {
+ quote! {
+ syrette::declare_interface!(#self_type -> #interface);
+ }
}
} else {
quote! {}
@@ -139,6 +152,12 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
///
/// *This macro is only available if Syrette is built with the "factory" feature.*
///
+/// # Arguments
+/// * (Zero or more) Flags. Like `a = true, b = false`
+///
+/// # Flags
+/// - `async` - Mark as async.
+///
/// # Panics
/// If the attributed item is not a type alias.
///
@@ -166,8 +185,17 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
/// ```
#[proc_macro_attribute]
#[cfg(feature = "factory")]
-pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream
+pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> TokenStream
{
+ use crate::factory_macro_args::FactoryMacroArgs;
+
+ let FactoryMacroArgs { flags } = parse(args_stream).unwrap();
+
+ let is_async = flags
+ .iter()
+ .find(|flag| flag.flag.to_string().as_str() == "async")
+ .map_or(false, |flag| flag.is_on.value);
+
let factory_type_alias::FactoryTypeAlias {
type_alias,
factory_interface,
@@ -175,22 +203,46 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream
return_type,
} = parse(type_alias_stream).unwrap();
+ let decl_interfaces = if is_async {
+ quote! {
+ syrette::declare_interface!(
+ syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
+ #arg_types,
+ #return_type
+ > -> #factory_interface,
+ async = true
+ );
+
+ syrette::declare_interface!(
+ syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
+ #arg_types,
+ #return_type
+ > -> syrette::interfaces::any_factory::AnyThreadsafeFactory,
+ async = true
+ )
+ }
+ } else {
+ quote! {
+ syrette::declare_interface!(
+ syrette::castable_factory::blocking::CastableFactory<
+ #arg_types,
+ #return_type
+ > -> #factory_interface
+ );
+
+ syrette::declare_interface!(
+ syrette::castable_factory::blocking::CastableFactory<
+ #arg_types,
+ #return_type
+ > -> syrette::interfaces::any_factory::AnyFactory
+ );
+ }
+ };
+
quote! {
#type_alias
- syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
- #arg_types,
- #return_type
- > -> #factory_interface
- );
-
- syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
- #arg_types,
- #return_type
- > -> syrette::interfaces::any_factory::AnyFactory
- );
+ #decl_interfaces
}
.into()
}
@@ -199,6 +251,10 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream
///
/// # Arguments
/// {Implementation} -> {Interface}
+/// * (Zero or more) Flags. Like `a = true, b = false`
+///
+/// # Flags
+/// - `async` - Mark as async.
///
/// # Examples
/// ```
@@ -218,9 +274,17 @@ pub fn declare_interface(input: TokenStream) -> TokenStream
let DeclareInterfaceArgs {
implementation,
interface,
+ flags,
} = parse_macro_input!(input);
- generate_caster(&implementation, &interface).into()
+ let opt_async_flag = flags
+ .iter()
+ .find(|flag| flag.flag.to_string().as_str() == "async");
+
+ let is_async =
+ opt_async_flag.map_or_else(|| false, |async_flag| async_flag.is_on.value);
+
+ generate_caster(&implementation, &interface, is_async).into()
}
/// Declares the name of a dependency.
diff --git a/macros/src/libs/intertrait_macros/gen_caster.rs b/macros/src/libs/intertrait_macros/gen_caster.rs
index 9bac09e..df743e2 100644
--- a/macros/src/libs/intertrait_macros/gen_caster.rs
+++ b/macros/src/libs/intertrait_macros/gen_caster.rs
@@ -22,15 +22,29 @@ const CASTER_FN_NAME_PREFIX: &[u8] = b"__";
const FN_BUF_LEN: usize = CASTER_FN_NAME_PREFIX.len() + Simple::LENGTH;
-pub fn generate_caster(ty: &impl ToTokens, dst_trait: &impl ToTokens) -> TokenStream
+pub fn generate_caster(
+ ty: &impl ToTokens,
+ dst_trait: &impl ToTokens,
+ sync: bool,
+) -> TokenStream
{
let fn_ident = create_caster_fn_ident();
- let new_caster = quote! {
- syrette::libs::intertrait::Caster::<dyn #dst_trait>::new(
- |from| from.downcast::<#ty>().unwrap(),
- |from| from.downcast::<#ty>().unwrap(),
- )
+ let new_caster = if sync {
+ quote! {
+ syrette::libs::intertrait::Caster::<dyn #dst_trait>::new_sync(
+ |from| from.downcast::<#ty>().unwrap(),
+ |from| from.downcast::<#ty>().unwrap(),
+ |from| from.downcast::<#ty>().unwrap()
+ )
+ }
+ } else {
+ quote! {
+ syrette::libs::intertrait::Caster::<dyn #dst_trait>::new(
+ |from| from.downcast::<#ty>().unwrap(),
+ |from| from.downcast::<#ty>().unwrap(),
+ )
+ }
};
quote! {
diff --git a/macros/src/macro_flag.rs b/macros/src/macro_flag.rs
new file mode 100644
index 0000000..257a059
--- /dev/null
+++ b/macros/src/macro_flag.rs
@@ -0,0 +1,27 @@
+use syn::parse::{Parse, ParseStream};
+use syn::{Ident, LitBool, Token};
+
+#[derive(Debug)]
+pub struct MacroFlag
+{
+ pub flag: Ident,
+ pub is_on: LitBool,
+}
+
+impl Parse for MacroFlag
+{
+ fn parse(input: ParseStream) -> syn::Result<Self>
+ {
+ let input_forked = input.fork();
+
+ let flag: Ident = input_forked.parse()?;
+
+ input.parse::<Ident>()?;
+
+ input.parse::<Token![=]>()?;
+
+ let is_on: LitBool = input.parse()?;
+
+ Ok(Self { flag, is_on })
+ }
+}
diff --git a/macros/src/util/mod.rs b/macros/src/util/mod.rs
index 4f2a594..0705853 100644
--- a/macros/src/util/mod.rs
+++ b/macros/src/util/mod.rs
@@ -1,3 +1,4 @@
pub mod item_impl;
pub mod iterator_ext;
+pub mod string;
pub mod syn_path;
diff --git a/macros/src/util/string.rs b/macros/src/util/string.rs
new file mode 100644
index 0000000..90cccee
--- /dev/null
+++ b/macros/src/util/string.rs
@@ -0,0 +1,12 @@
+use once_cell::sync::Lazy;
+use regex::Regex;
+
+static CAMELCASE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"([a-z])([A-Z])").unwrap());
+
+pub fn camelcase_to_snakecase(camelcased: &str) -> String
+{
+ CAMELCASE_RE
+ .replace(camelcased, "${1}_$2")
+ .to_string()
+ .to_lowercase()
+}
diff --git a/src/async_di_container.rs b/src/async_di_container.rs
new file mode 100644
index 0000000..374746f
--- /dev/null
+++ b/src/async_di_container.rs
@@ -0,0 +1,1110 @@
+//! Asynchronous dependency injection container.
+//!
+//! # Examples
+//! ```
+//! use std::collections::HashMap;
+//! use std::error::Error;
+//!
+//! use syrette::{injectable, AsyncDIContainer};
+//!
+//! trait IDatabaseService
+//! {
+//! fn get_all_records(&self, table_name: String) -> HashMap<String, String>;
+//! }
+//!
+//! struct DatabaseService {}
+//!
+//! #[injectable(IDatabaseService, { async = true })]
+//! impl DatabaseService
+//! {
+//! fn new() -> Self
+//! {
+//! Self {}
+//! }
+//! }
+//!
+//! impl IDatabaseService for DatabaseService
+//! {
+//! fn get_all_records(&self, table_name: String) -> HashMap<String, String>
+//! {
+//! // Do stuff here
+//! HashMap::<String, String>::new()
+//! }
+//! }
+//!
+//! #[tokio::main]
+//! async fn main() -> Result<(), Box<dyn Error>>
+//! {
+//! let mut di_container = AsyncDIContainer::new();
+//!
+//! di_container
+//! .bind::<dyn IDatabaseService>()
+//! .to::<DatabaseService>()?;
+//!
+//! let database_service = di_container
+//! .get::<dyn IDatabaseService>()
+//! .await?
+//! .transient()?;
+//!
+//! Ok(())
+//! }
+//! ```
+//!
+//! ---
+//!
+//! *This module is only available if Syrette is built with the "async" feature.*
+use std::any::type_name;
+use std::marker::PhantomData;
+
+#[cfg(feature = "factory")]
+use crate::castable_factory::threadsafe::ThreadsafeCastableFactory;
+use crate::di_container_binding_map::DIContainerBindingMap;
+use crate::errors::async_di_container::{
+ AsyncBindingBuilderError,
+ AsyncBindingScopeConfiguratorError,
+ AsyncBindingWhenConfiguratorError,
+ AsyncDIContainerError,
+};
+use crate::interfaces::async_injectable::AsyncInjectable;
+use crate::libs::intertrait::cast::{CastArc, CastBox};
+use crate::provider::r#async::{
+ AsyncProvidable,
+ AsyncSingletonProvider,
+ AsyncTransientTypeProvider,
+ IAsyncProvider,
+};
+use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr};
+
+/// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
+pub struct AsyncBindingWhenConfigurator<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ di_container: &'di_container mut AsyncDIContainer,
+ interface_phantom: PhantomData<Interface>,
+}
+
+impl<'di_container, Interface> AsyncBindingWhenConfigurator<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ {
+ Self {
+ di_container,
+ interface_phantom: PhantomData,
+ }
+ }
+
+ /// Configures the binding to have a name.
+ ///
+ /// # Errors
+ /// Will return Err if no binding for the interface already exists.
+ pub fn when_named(
+ &mut self,
+ name: &'static str,
+ ) -> Result<(), AsyncBindingWhenConfiguratorError>
+ {
+ let binding = self
+ .di_container
+ .bindings
+ .remove::<Interface>(None)
+ .map_or_else(
+ || {
+ Err(AsyncBindingWhenConfiguratorError::BindingNotFound(
+ type_name::<Interface>(),
+ ))
+ },
+ Ok,
+ )?;
+
+ self.di_container
+ .bindings
+ .set::<Interface>(Some(name), binding);
+
+ Ok(())
+ }
+}
+
+/// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
+pub struct AsyncBindingScopeConfigurator<'di_container, Interface, Implementation>
+where
+ Interface: 'static + ?Sized,
+ Implementation: AsyncInjectable,
+{
+ di_container: &'di_container mut AsyncDIContainer,
+ interface_phantom: PhantomData<Interface>,
+ implementation_phantom: PhantomData<Implementation>,
+}
+
+impl<'di_container, Interface, Implementation>
+ AsyncBindingScopeConfigurator<'di_container, Interface, Implementation>
+where
+ Interface: 'static + ?Sized,
+ Implementation: AsyncInjectable,
+{
+ fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ {
+ Self {
+ di_container,
+ interface_phantom: PhantomData,
+ implementation_phantom: PhantomData,
+ }
+ }
+
+ /// Configures the binding to be in a transient scope.
+ ///
+ /// This is the default.
+ pub fn in_transient_scope(&mut self) -> AsyncBindingWhenConfigurator<Interface>
+ {
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(AsyncTransientTypeProvider::<Implementation>::new()),
+ );
+
+ AsyncBindingWhenConfigurator::new(self.di_container)
+ }
+
+ /// Configures the binding to be in a singleton scope.
+ ///
+ /// # Errors
+ /// Will return Err if resolving the implementation fails.
+ pub async fn in_singleton_scope(
+ &mut self,
+ ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingScopeConfiguratorError>
+ {
+ let singleton: ThreadsafeSingletonPtr<Implementation> =
+ ThreadsafeSingletonPtr::from(
+ Implementation::resolve(self.di_container, Vec::new())
+ .await
+ .map_err(
+ AsyncBindingScopeConfiguratorError::SingletonResolveFailed,
+ )?,
+ );
+
+ self.di_container
+ .bindings
+ .set::<Interface>(None, Box::new(AsyncSingletonProvider::new(singleton)));
+
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ }
+}
+
+/// Binding builder for type `Interface` inside a [`AsyncDIContainer`].
+pub struct AsyncBindingBuilder<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ di_container: &'di_container mut AsyncDIContainer,
+ interface_phantom: PhantomData<Interface>,
+}
+
+impl<'di_container, Interface> AsyncBindingBuilder<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ fn new(di_container: &'di_container mut AsyncDIContainer) -> Self
+ {
+ Self {
+ di_container,
+ interface_phantom: PhantomData,
+ }
+ }
+
+ /// Creates a binding of type `Interface` to type `Implementation` inside of the
+ /// associated [`AsyncDIContainer`].
+ ///
+ /// The scope of the binding is transient. But that can be changed by using the
+ /// returned [`AsyncBindingScopeConfigurator`]
+ ///
+ /// # Errors
+ /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
+ /// the interface.
+ pub fn to<Implementation>(
+ &mut self,
+ ) -> Result<
+ AsyncBindingScopeConfigurator<Interface, Implementation>,
+ AsyncBindingBuilderError,
+ >
+ where
+ Implementation: AsyncInjectable,
+ {
+ if self.di_container.bindings.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
+
+ let mut binding_scope_configurator =
+ AsyncBindingScopeConfigurator::new(self.di_container);
+
+ binding_scope_configurator.in_transient_scope();
+
+ Ok(binding_scope_configurator)
+ }
+
+ /// Creates a binding of factory type `Interface` to a factory inside of the
+ /// associated [`AsyncDIContainer`].
+ ///
+ /// *This function is only available if Syrette is built with the "factory" feature.*
+ ///
+ /// # Errors
+ /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
+ /// the interface.
+ #[cfg(feature = "factory")]
+ pub fn to_factory<Args, Return>(
+ &mut self,
+ factory_func: &'static (dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>
+ + Send
+ + Sync),
+ ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
+ where
+ Args: 'static,
+ Return: 'static + ?Sized,
+ Interface: crate::interfaces::factory::IFactory<Args, Return>,
+ {
+ if self.di_container.bindings.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
+
+ let factory_impl = ThreadsafeCastableFactory::new(factory_func);
+
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
+ crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ )),
+ );
+
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ }
+
+ /// Creates a binding of type `Interface` to a factory that takes no arguments
+ /// inside of the associated [`AsyncDIContainer`].
+ ///
+ /// *This function is only available if Syrette is built with the "factory" feature.*
+ ///
+ /// # Errors
+ /// Will return Err if the associated [`AsyncDIContainer`] already have a binding for
+ /// the interface.
+ #[cfg(feature = "factory")]
+ pub fn to_default_factory<Return>(
+ &mut self,
+ factory_func: &'static (dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>
+ + Send
+ + Sync),
+ ) -> Result<AsyncBindingWhenConfigurator<Interface>, AsyncBindingBuilderError>
+ where
+ Return: 'static + ?Sized,
+ {
+ if self.di_container.bindings.has::<Interface>(None) {
+ return Err(AsyncBindingBuilderError::BindingAlreadyExists(type_name::<
+ Interface,
+ >(
+ )));
+ }
+
+ let factory_impl = ThreadsafeCastableFactory::new(factory_func);
+
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(crate::provider::r#async::AsyncFactoryProvider::new(
+ crate::ptr::ThreadsafeFactoryPtr::new(factory_impl),
+ )),
+ );
+
+ Ok(AsyncBindingWhenConfigurator::new(self.di_container))
+ }
+}
+
+/// Dependency injection container.
+pub struct AsyncDIContainer
+{
+ bindings: DIContainerBindingMap<dyn IAsyncProvider>,
+}
+
+impl AsyncDIContainer
+{
+ /// Returns a new `AsyncDIContainer`.
+ #[must_use]
+ pub fn new() -> Self
+ {
+ Self {
+ bindings: DIContainerBindingMap::new(),
+ }
+ }
+
+ /// Returns a new [`AsyncBindingBuilder`] for the given interface.
+ pub fn bind<Interface>(&mut self) -> AsyncBindingBuilder<Interface>
+ where
+ Interface: 'static + ?Sized,
+ {
+ AsyncBindingBuilder::<Interface>::new(self)
+ }
+
+ /// Returns the type bound with `Interface`.
+ ///
+ /// # Errors
+ /// Will return `Err` if:
+ /// - No binding for `Interface` exists
+ /// - Resolving the binding for fails
+ /// - Casting the binding for fails
+ pub async fn get<Interface>(
+ &self,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.get_bound::<Interface>(Vec::new(), None).await
+ }
+
+ /// Returns the type bound with `Interface` and the specified name.
+ ///
+ /// # Errors
+ /// Will return `Err` if:
+ /// - No binding for `Interface` with name `name` exists
+ /// - Resolving the binding fails
+ /// - Casting the binding for fails
+ pub async fn get_named<Interface>(
+ &self,
+ name: &'static str,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.get_bound::<Interface>(Vec::new(), Some(name)).await
+ }
+
+ #[doc(hidden)]
+ pub async fn get_bound<Interface>(
+ &self,
+ dependency_history: Vec<&'static str>,
+ name: Option<&'static str>,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ let binding_providable = self
+ .get_binding_providable::<Interface>(name, dependency_history)
+ .await?;
+
+ Self::handle_binding_providable(binding_providable)
+ }
+
+ fn handle_binding_providable<Interface>(
+ binding_providable: AsyncProvidable,
+ ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ match binding_providable {
+ AsyncProvidable::Transient(transient_binding) => {
+ Ok(SomeThreadsafePtr::Transient(
+ transient_binding.cast::<Interface>().map_err(|_| {
+ AsyncDIContainerError::CastFailed(type_name::<Interface>())
+ })?,
+ ))
+ }
+ AsyncProvidable::Singleton(singleton_binding) => {
+ Ok(SomeThreadsafePtr::ThreadsafeSingleton(
+ singleton_binding.cast::<Interface>().map_err(|_| {
+ AsyncDIContainerError::CastFailed(type_name::<Interface>())
+ })?,
+ ))
+ }
+ #[cfg(feature = "factory")]
+ AsyncProvidable::Factory(factory_binding) => {
+ match factory_binding.clone().cast::<Interface>() {
+ Ok(factory) => Ok(SomeThreadsafePtr::ThreadsafeFactory(factory)),
+ Err(_err) => {
+ use crate::interfaces::factory::IFactory;
+
+ let default_factory =
+ factory_binding
+ .cast::<dyn IFactory<(), Interface>>()
+ .map_err(|_| {
+ AsyncDIContainerError::CastFailed(type_name::<
+ Interface,
+ >(
+ ))
+ })?;
+
+ Ok(SomeThreadsafePtr::Transient(default_factory()))
+ }
+ }
+ }
+ }
+ }
+
+ async fn get_binding_providable<Interface>(
+ &self,
+ name: Option<&'static str>,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, AsyncDIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.bindings
+ .get::<Interface>(name)
+ .map_or_else(
+ || {
+ Err(AsyncDIContainerError::BindingNotFound {
+ interface: type_name::<Interface>(),
+ name,
+ })
+ },
+ Ok,
+ )?
+ .provide(self, dependency_history)
+ .await
+ .map_err(|err| AsyncDIContainerError::BindingResolveFailed {
+ reason: err,
+ interface: type_name::<Interface>(),
+ })
+ }
+}
+
+impl Default for AsyncDIContainer
+{
+ fn default() -> Self
+ {
+ Self::new()
+ }
+}
+
+#[cfg(test)]
+mod tests
+{
+ use std::error::Error;
+
+ use async_trait::async_trait;
+ use mockall::mock;
+
+ use super::*;
+ use crate::errors::injectable::InjectableError;
+ use crate::ptr::TransientPtr;
+
+ mod subjects
+ {
+ //! Test subjects.
+
+ use std::fmt::Debug;
+
+ use async_trait::async_trait;
+ use syrette_macros::declare_interface;
+
+ use super::AsyncDIContainer;
+ use crate::interfaces::async_injectable::AsyncInjectable;
+ use crate::ptr::TransientPtr;
+
+ pub trait IUserManager
+ {
+ fn add_user(&self, user_id: i128);
+
+ fn remove_user(&self, user_id: i128);
+ }
+
+ pub struct UserManager {}
+
+ impl UserManager
+ {
+ pub fn new() -> Self
+ {
+ Self {}
+ }
+ }
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&self, _user_id: i128)
+ {
+ // ...
+ }
+
+ fn remove_user(&self, _user_id: i128)
+ {
+ // ...
+ }
+ }
+
+ use crate as syrette;
+
+ declare_interface!(UserManager -> IUserManager);
+
+ #[async_trait]
+ impl AsyncInjectable for UserManager
+ {
+ async fn resolve(
+ _: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError>
+ where
+ Self: Sized,
+ {
+ Ok(TransientPtr::new(Self::new()))
+ }
+ }
+
+ pub trait INumber
+ {
+ fn get(&self) -> i32;
+
+ fn set(&mut self, number: i32);
+ }
+
+ impl PartialEq for dyn INumber
+ {
+ fn eq(&self, other: &Self) -> bool
+ {
+ self.get() == other.get()
+ }
+ }
+
+ impl Debug for dyn INumber
+ {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ f.write_str(format!("{}", self.get()).as_str())
+ }
+ }
+
+ pub struct Number
+ {
+ pub num: i32,
+ }
+
+ impl Number
+ {
+ pub fn new() -> Self
+ {
+ Self { num: 0 }
+ }
+ }
+
+ impl INumber for Number
+ {
+ fn get(&self) -> i32
+ {
+ self.num
+ }
+
+ fn set(&mut self, number: i32)
+ {
+ self.num = number;
+ }
+ }
+
+ declare_interface!(Number -> INumber, async = true);
+
+ #[async_trait]
+ impl AsyncInjectable for Number
+ {
+ async fn resolve(
+ _: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<TransientPtr<Self>, crate::errors::injectable::InjectableError>
+ where
+ Self: Sized,
+ {
+ Ok(TransientPtr::new(Self::new()))
+ }
+ }
+ }
+
+ #[test]
+ fn can_bind_to() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ fn can_bind_to_transient() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_transient_scope();
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_transient_scope()
+ .when_named("regular")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_bind_to_singleton() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_singleton_scope()
+ .await?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_bind_to_singleton_when_named() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_singleton_scope()
+ .await?
+ .when_named("cool")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ #[cfg(feature = "factory")]
+ fn can_bind_to_factory() -> Result<(), Box<dyn Error>>
+ {
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container.bind::<IUserManagerFactory>().to_factory(&|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
+
+ user_manager
+ })?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ #[cfg(feature = "factory")]
+ fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>>
+ {
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<IUserManagerFactory>()
+ .to_factory(&|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
+
+ user_manager
+ })?
+ .when_named("awesome")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Transient(TransientPtr::new(
+ subjects::UserManager::new(),
+ )))
+ });
+
+ di_container
+ .bindings
+ .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider));
+
+ di_container
+ .get::<dyn subjects::IUserManager>()
+ .await?
+ .transient()?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get_named() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Transient(TransientPtr::new(
+ subjects::UserManager::new(),
+ )))
+ });
+
+ di_container
+ .bindings
+ .set::<dyn subjects::IUserManager>(Some("special"), Box::new(mock_provider));
+
+ di_container
+ .get_named::<dyn subjects::IUserManager>("special")
+ .await?
+ .transient()?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get_singleton() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ let mut singleton = ThreadsafeSingletonPtr::new(subjects::Number::new());
+
+ ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820;
+
+ mock_provider
+ .expect_provide()
+ .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone())));
+
+ di_container
+ .bindings
+ .set::<dyn subjects::INumber>(None, Box::new(mock_provider));
+
+ let first_number_rc = di_container
+ .get::<dyn subjects::INumber>()
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.get(), 2820);
+
+ let second_number_rc = di_container
+ .get::<dyn subjects::INumber>()
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref());
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn can_get_singleton_named() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ let mut singleton = ThreadsafeSingletonPtr::new(subjects::Number::new());
+
+ ThreadsafeSingletonPtr::get_mut(&mut singleton).unwrap().num = 2820;
+
+ mock_provider
+ .expect_provide()
+ .returning_st(move |_, _| Ok(AsyncProvidable::Singleton(singleton.clone())));
+
+ di_container
+ .bindings
+ .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider));
+
+ let first_number_rc = di_container
+ .get_named::<dyn subjects::INumber>("cool")
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.get(), 2820);
+
+ let second_number_rc = di_container
+ .get_named::<dyn subjects::INumber>("cool")
+ .await?
+ .threadsafe_singleton()?;
+
+ assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref());
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ #[cfg(feature = "factory")]
+ async fn can_get_factory() -> Result<(), Box<dyn Error>>
+ {
+ trait IUserManager
+ {
+ fn add_user(&mut self, user_id: i128);
+
+ fn remove_user(&mut self, user_id: i128);
+ }
+
+ struct UserManager
+ {
+ users: Vec<i128>,
+ }
+
+ impl UserManager
+ {
+ fn new(users: Vec<i128>) -> Self
+ {
+ Self { users }
+ }
+ }
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&mut self, user_id: i128)
+ {
+ self.users.push(user_id);
+ }
+
+ fn remove_user(&mut self, user_id: i128)
+ {
+ let user_index =
+ self.users.iter().position(|user| *user == user_id).unwrap();
+
+ self.users.remove(user_index);
+ }
+ }
+
+ use crate as syrette;
+
+ #[crate::factory(async = true)]
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
+ &|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ },
+ )),
+ ))
+ });
+
+ di_container
+ .bindings
+ .set::<IUserManagerFactory>(None, Box::new(mock_provider));
+
+ di_container
+ .get::<IUserManagerFactory>()
+ .await?
+ .threadsafe_factory()?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ #[cfg(feature = "factory")]
+ async fn can_get_factory_named() -> Result<(), Box<dyn Error>>
+ {
+ trait IUserManager
+ {
+ fn add_user(&mut self, user_id: i128);
+
+ fn remove_user(&mut self, user_id: i128);
+ }
+
+ struct UserManager
+ {
+ users: Vec<i128>,
+ }
+
+ impl UserManager
+ {
+ fn new(users: Vec<i128>) -> Self
+ {
+ Self { users }
+ }
+ }
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&mut self, user_id: i128)
+ {
+ self.users.push(user_id);
+ }
+
+ fn remove_user(&mut self, user_id: i128)
+ {
+ let user_index =
+ self.users.iter().position(|user| *user == user_id).unwrap();
+
+ self.users.remove(user_index);
+ }
+ }
+
+ use crate as syrette;
+
+ #[crate::factory(async = true)]
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+
+ mock! {
+ Provider {}
+
+ #[async_trait]
+ impl IAsyncProvider for Provider
+ {
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+ }
+ }
+
+ let mut di_container: AsyncDIContainer = AsyncDIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(AsyncProvidable::Factory(
+ crate::ptr::ThreadsafeFactoryPtr::new(ThreadsafeCastableFactory::new(
+ &|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ },
+ )),
+ ))
+ });
+
+ di_container
+ .bindings
+ .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider));
+
+ di_container
+ .get_named::<IUserManagerFactory>("special")
+ .await?
+ .threadsafe_factory()?;
+
+ Ok(())
+ }
+}
diff --git a/src/castable_factory.rs b/src/castable_factory/blocking.rs
index 5ff4db0..5ff4db0 100644
--- a/src/castable_factory.rs
+++ b/src/castable_factory/blocking.rs
diff --git a/src/castable_factory/mod.rs b/src/castable_factory/mod.rs
new file mode 100644
index 0000000..530cc82
--- /dev/null
+++ b/src/castable_factory/mod.rs
@@ -0,0 +1,2 @@
+pub mod blocking;
+pub mod threadsafe;
diff --git a/src/castable_factory/threadsafe.rs b/src/castable_factory/threadsafe.rs
new file mode 100644
index 0000000..7be055c
--- /dev/null
+++ b/src/castable_factory/threadsafe.rs
@@ -0,0 +1,88 @@
+#![allow(clippy::module_name_repetitions)]
+use crate::interfaces::any_factory::{AnyFactory, AnyThreadsafeFactory};
+use crate::interfaces::factory::IFactory;
+use crate::ptr::TransientPtr;
+
+pub struct ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + Send + Sync),
+}
+
+impl<Args, ReturnInterface> ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ pub fn new(
+ func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>>
+ + Send
+ + Sync),
+ ) -> Self
+ {
+ Self { func }
+ }
+}
+
+impl<Args, ReturnInterface> IFactory<Args, ReturnInterface>
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+}
+
+impl<Args, ReturnInterface> Fn<Args> for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ extern "rust-call" fn call(&self, args: Args) -> Self::Output
+ {
+ self.func.call(args)
+ }
+}
+
+impl<Args, ReturnInterface> FnMut<Args>
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output
+ {
+ self.call(args)
+ }
+}
+
+impl<Args, ReturnInterface> FnOnce<Args>
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+ type Output = TransientPtr<ReturnInterface>;
+
+ extern "rust-call" fn call_once(self, args: Args) -> Self::Output
+ {
+ self.call(args)
+ }
+}
+
+impl<Args, ReturnInterface> AnyFactory
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+}
+
+impl<Args, ReturnInterface> AnyThreadsafeFactory
+ for ThreadsafeCastableFactory<Args, ReturnInterface>
+where
+ Args: 'static,
+ ReturnInterface: 'static + ?Sized,
+{
+}
diff --git a/src/di_container.rs b/src/di_container.rs
index e42175b..b0e5af1 100644
--- a/src/di_container.rs
+++ b/src/di_container.rs
@@ -1,4 +1,4 @@
-//! Dependency injection container and other related utilities.
+//! Dependency injection container.
//!
//! # Examples
//! ```
@@ -53,7 +53,7 @@ use std::any::type_name;
use std::marker::PhantomData;
#[cfg(feature = "factory")]
-use crate::castable_factory::CastableFactory;
+use crate::castable_factory::blocking::CastableFactory;
use crate::di_container_binding_map::DIContainerBindingMap;
use crate::errors::di_container::{
BindingBuilderError,
@@ -63,7 +63,12 @@ use crate::errors::di_container::{
};
use crate::interfaces::injectable::Injectable;
use crate::libs::intertrait::cast::{CastBox, CastRc};
-use crate::provider::{Providable, SingletonProvider, TransientTypeProvider};
+use crate::provider::blocking::{
+ IProvider,
+ Providable,
+ SingletonProvider,
+ TransientTypeProvider,
+};
use crate::ptr::{SingletonPtr, SomePtr};
/// When configurator for a binding for type 'Interface' inside a [`DIContainer`].
@@ -256,7 +261,7 @@ where
self.di_container.bindings.set::<Interface>(
None,
- Box::new(crate::provider::FactoryProvider::new(
+ Box::new(crate::provider::blocking::FactoryProvider::new(
crate::ptr::FactoryPtr::new(factory_impl),
)),
);
@@ -290,7 +295,7 @@ where
self.di_container.bindings.set::<Interface>(
None,
- Box::new(crate::provider::FactoryProvider::new(
+ Box::new(crate::provider::blocking::FactoryProvider::new(
crate::ptr::FactoryPtr::new(factory_impl),
)),
);
@@ -302,7 +307,7 @@ where
/// Dependency injection container.
pub struct DIContainer
{
- bindings: DIContainerBindingMap,
+ bindings: DIContainerBindingMap<dyn IProvider>,
}
impl DIContainer
@@ -416,7 +421,16 @@ impl DIContainer
Interface: 'static + ?Sized,
{
self.bindings
- .get::<Interface>(name)?
+ .get::<Interface>(name)
+ .map_or_else(
+ || {
+ Err(DIContainerError::BindingNotFound {
+ interface: type_name::<Interface>(),
+ name,
+ })
+ },
+ Ok,
+ )?
.provide(self, dependency_history)
.map_err(|err| DIContainerError::BindingResolveFailed {
reason: err,
@@ -442,7 +456,7 @@ mod tests
use super::*;
use crate::errors::injectable::InjectableError;
- use crate::provider::IProvider;
+ use crate::provider::blocking::IProvider;
use crate::ptr::TransientPtr;
mod subjects
diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs
index 4df889d..4aa246e 100644
--- a/src/di_container_binding_map.rs
+++ b/src/di_container_binding_map.rs
@@ -1,10 +1,7 @@
-use std::any::{type_name, TypeId};
+use std::any::TypeId;
use ahash::AHashMap;
-use crate::errors::di_container::DIContainerError;
-use crate::provider::IProvider;
-
#[derive(Debug, PartialEq, Eq, Hash)]
struct DIContainerBindingKey
{
@@ -12,12 +9,16 @@ struct DIContainerBindingKey
name: Option<&'static str>,
}
-pub struct DIContainerBindingMap
+pub struct DIContainerBindingMap<Provider>
+where
+ Provider: 'static + ?Sized,
{
- bindings: AHashMap<DIContainerBindingKey, Box<dyn IProvider>>,
+ bindings: AHashMap<DIContainerBindingKey, Box<Provider>>,
}
-impl DIContainerBindingMap
+impl<Provider> DIContainerBindingMap<Provider>
+where
+ Provider: 'static + ?Sized,
{
pub fn new() -> Self
{
@@ -26,33 +27,22 @@ impl DIContainerBindingMap
}
}
- pub fn get<Interface>(
- &self,
- name: Option<&'static str>,
- ) -> Result<&dyn IProvider, DIContainerError>
+ pub fn get<Interface>(&self, name: Option<&'static str>) -> Option<&Provider>
where
Interface: 'static + ?Sized,
{
let interface_typeid = TypeId::of::<Interface>();
- Ok(self
- .bindings
+ self.bindings
.get(&DIContainerBindingKey {
type_id: interface_typeid,
name,
})
- .ok_or_else(|| DIContainerError::BindingNotFound {
- interface: type_name::<Interface>(),
- name,
- })?
- .as_ref())
+ .map(|provider| provider.as_ref())
}
- pub fn set<Interface>(
- &mut self,
- name: Option<&'static str>,
- provider: Box<dyn IProvider>,
- ) where
+ pub fn set<Interface>(&mut self, name: Option<&'static str>, provider: Box<Provider>)
+ where
Interface: 'static + ?Sized,
{
let interface_typeid = TypeId::of::<Interface>();
@@ -69,7 +59,7 @@ impl DIContainerBindingMap
pub fn remove<Interface>(
&mut self,
name: Option<&'static str>,
- ) -> Option<Box<dyn IProvider>>
+ ) -> Option<Box<Provider>>
where
Interface: 'static + ?Sized,
{
diff --git a/src/errors/async_di_container.rs b/src/errors/async_di_container.rs
new file mode 100644
index 0000000..4f5e50a
--- /dev/null
+++ b/src/errors/async_di_container.rs
@@ -0,0 +1,79 @@
+//! Error types for [`AsyncDIContainer`] and it's related structs.
+//!
+//! ---
+//!
+//! *This module is only available if Syrette is built with the "async" feature.*
+//!
+//! [`AsyncDIContainer`]: crate::async_di_container::AsyncDIContainer
+
+use crate::errors::injectable::InjectableError;
+
+/// Error type for [`AsyncDIContainer`].
+///
+/// [`AsyncDIContainer`]: crate::async_di_container::AsyncDIContainer
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncDIContainerError
+{
+ /// Unable to cast a binding for a interface.
+ #[error("Unable to cast binding for interface '{0}'")]
+ CastFailed(&'static str),
+
+ /// Failed to resolve a binding for a interface.
+ #[error("Failed to resolve binding for interface '{interface}'")]
+ BindingResolveFailed
+ {
+ /// The reason for the problem.
+ #[source]
+ reason: InjectableError,
+
+ /// The affected bound interface.
+ interface: &'static str,
+ },
+
+ /// No binding exists for a interface (and optionally a name).
+ #[error(
+ "No binding exists for interface '{interface}' {}",
+ .name.map_or_else(String::new, |name| format!("with name '{}'", name))
+ )]
+ BindingNotFound
+ {
+ /// The interface that doesn't have a binding.
+ interface: &'static str,
+
+ /// The name of the binding if one exists.
+ name: Option<&'static str>,
+ },
+}
+
+/// Error type for [`AsyncBindingBuilder`].
+///
+/// [`AsyncBindingBuilder`]: crate::async_di_container::AsyncBindingBuilder
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncBindingBuilderError
+{
+ /// A binding already exists for a interface.
+ #[error("Binding already exists for interface '{0}'")]
+ BindingAlreadyExists(&'static str),
+}
+
+/// Error type for [`AsyncBindingScopeConfigurator`].
+///
+/// [`AsyncBindingScopeConfigurator`]: crate::async_di_container::AsyncBindingScopeConfigurator
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncBindingScopeConfiguratorError
+{
+ /// Resolving a singleton failed.
+ #[error("Resolving the given singleton failed")]
+ SingletonResolveFailed(#[from] InjectableError),
+}
+
+/// Error type for [`AsyncBindingWhenConfigurator`].
+///
+/// [`AsyncBindingWhenConfigurator`]: crate::async_di_container::AsyncBindingWhenConfigurator
+#[derive(thiserror::Error, Debug)]
+pub enum AsyncBindingWhenConfiguratorError
+{
+ /// A binding for a interface wasn't found.
+ #[error("A binding for interface '{0}' wasn't found'")]
+ BindingNotFound(&'static str),
+}
diff --git a/src/errors/injectable.rs b/src/errors/injectable.rs
index 4b9af96..ed161cb 100644
--- a/src/errors/injectable.rs
+++ b/src/errors/injectable.rs
@@ -3,7 +3,7 @@
//!
//! [`Injectable`]: crate::interfaces::injectable::Injectable
-use super::di_container::DIContainerError;
+use crate::errors::di_container::DIContainerError;
/// Error type for structs that implement [`Injectable`].
///
@@ -23,6 +23,18 @@ pub enum InjectableError
affected: &'static str,
},
+ /// Failed to resolve dependencies.
+ #[cfg(feature = "async")]
+ #[error("Failed to resolve a dependency of '{affected}'")]
+ AsyncResolveFailed
+ {
+ /// The reason for the problem.
+ #[source]
+ reason: Box<crate::errors::async_di_container::AsyncDIContainerError>,
+
+ /// The affected injectable type.
+ affected: &'static str,
+ },
/// Detected circular dependencies.
#[error("Detected circular dependencies. {dependency_trace}")]
DetectedCircular
diff --git a/src/errors/mod.rs b/src/errors/mod.rs
index 7d66ddf..c3930b0 100644
--- a/src/errors/mod.rs
+++ b/src/errors/mod.rs
@@ -3,3 +3,6 @@
pub mod di_container;
pub mod injectable;
pub mod ptr;
+
+#[cfg(feature = "async")]
+pub mod async_di_container;
diff --git a/src/errors/ptr.rs b/src/errors/ptr.rs
index e0c3d05..56621c1 100644
--- a/src/errors/ptr.rs
+++ b/src/errors/ptr.rs
@@ -17,3 +17,21 @@ pub enum SomePtrError
found: &'static str,
},
}
+
+/// Error type for [`SomeThreadsafePtr`].
+///
+/// [`SomeThreadsafePtr`]: crate::ptr::SomeThreadsafePtr
+#[derive(thiserror::Error, Debug)]
+pub enum SomeThreadsafePtrError
+{
+ /// Tried to get as a wrong threadsafe smart pointer type.
+ #[error("Wrong threadsafe smart pointer type. Expected {expected}, found {found}")]
+ WrongPtrType
+ {
+ /// The expected smart pointer type.
+ expected: &'static str,
+
+ /// The found smart pointer type.
+ found: &'static str,
+ },
+}
diff --git a/src/interfaces/any_factory.rs b/src/interfaces/any_factory.rs
index 887bb61..1bf9208 100644
--- a/src/interfaces/any_factory.rs
+++ b/src/interfaces/any_factory.rs
@@ -2,7 +2,7 @@
use std::fmt::Debug;
-use crate::libs::intertrait::CastFrom;
+use crate::libs::intertrait::{CastFrom, CastFromSync};
/// Interface for any factory to ever exist.
pub trait AnyFactory: CastFrom {}
@@ -14,3 +14,14 @@ impl Debug for dyn AnyFactory
f.write_str("{}")
}
}
+
+/// Interface for any threadsafe factory to ever exist.
+pub trait AnyThreadsafeFactory: CastFromSync {}
+
+impl Debug for dyn AnyThreadsafeFactory
+{
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ f.write_str("{}")
+ }
+}
diff --git a/src/interfaces/async_injectable.rs b/src/interfaces/async_injectable.rs
new file mode 100644
index 0000000..badc3c5
--- /dev/null
+++ b/src/interfaces/async_injectable.rs
@@ -0,0 +1,35 @@
+//! Interface for structs that can be injected into or be injected to.
+//!
+//! *This module is only available if Syrette is built with the "async" feature.*
+use std::fmt::Debug;
+
+use async_trait::async_trait;
+
+use crate::async_di_container::AsyncDIContainer;
+use crate::errors::injectable::InjectableError;
+use crate::libs::intertrait::CastFromSync;
+use crate::ptr::TransientPtr;
+
+/// Interface for structs that can be injected into or be injected to.
+#[async_trait]
+pub trait AsyncInjectable: CastFromSync
+{
+ /// Resolves the dependencies of the injectable.
+ ///
+ /// # Errors
+ /// Will return `Err` if resolving the dependencies fails.
+ async fn resolve(
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<TransientPtr<Self>, InjectableError>
+ where
+ Self: Sized;
+}
+
+impl Debug for dyn AsyncInjectable
+{
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ f.write_str("{}")
+ }
+}
diff --git a/src/interfaces/mod.rs b/src/interfaces/mod.rs
index 73dde04..ddb3bba 100644
--- a/src/interfaces/mod.rs
+++ b/src/interfaces/mod.rs
@@ -8,3 +8,6 @@ pub mod any_factory;
#[cfg(feature = "factory")]
pub mod factory;
+
+#[cfg(feature = "async")]
+pub mod async_injectable;
diff --git a/src/lib.rs b/src/lib.rs
index 8908143..9fdfa0f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,6 +12,11 @@ pub mod errors;
pub mod interfaces;
pub mod ptr;
+#[cfg(feature = "async")]
+pub mod async_di_container;
+
+#[cfg(feature = "async")]
+pub use async_di_container::AsyncDIContainer;
pub use di_container::DIContainer;
pub use syrette_macros::*;
@@ -75,9 +80,8 @@ macro_rules! di_container_bind {
///
/// A default factory is a factory that doesn't take any arguments.
///
-/// More tedious ways to accomplish what this macro does would either be by using
-/// the [`factory`] macro or by manually declaring the interfaces
-/// with the [`declare_interface`] macro.
+/// The more tedious way to accomplish what this macro does would be by using
+/// the [`factory`] macro.
///
/// *This macro is only available if Syrette is built with the "factory" feature.*
///
@@ -95,43 +99,19 @@ macro_rules! di_container_bind {
///
/// declare_default_factory!(dyn IParser);
/// ```
-///
-/// The expanded equivelent of this would be
-///
-/// ```
-/// # use syrette::declare_default_factory;
-/// #
-/// trait IParser {
-/// // Methods and etc here...
-/// }
-///
-/// syrette::declare_interface!(
-/// syrette::castable_factory::CastableFactory<
-/// (),
-/// dyn IParser,
-/// > -> syrette::interfaces::factory::IFactory<(), dyn IParser>
-/// );
-///
-/// syrette::declare_interface!(
-/// syrette::castable_factory::CastableFactory<
-/// (),
-/// dyn IParser,
-/// > -> syrette::interfaces::any_factory::AnyFactory
-/// );
-/// ```
#[macro_export]
#[cfg(feature = "factory")]
macro_rules! declare_default_factory {
($interface: ty) => {
syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
+ syrette::castable_factory::blocking::CastableFactory<
(),
$interface,
> -> syrette::interfaces::factory::IFactory<(), $interface>
);
syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
+ syrette::castable_factory::blocking::CastableFactory<
(),
$interface,
> -> syrette::interfaces::any_factory::AnyFactory
diff --git a/src/libs/intertrait/mod.rs b/src/libs/intertrait/mod.rs
index 2d62871..bdae4c7 100644
--- a/src/libs/intertrait/mod.rs
+++ b/src/libs/intertrait/mod.rs
@@ -23,7 +23,7 @@
//! MIT license (LICENSE-MIT or <http://opensource.org/licenses/MIT>)
//!
//! at your option.
-use std::any::{Any, TypeId};
+use std::any::{type_name, Any, TypeId};
use std::rc::Rc;
use std::sync::Arc;
@@ -60,7 +60,10 @@ static CASTER_MAP: Lazy<AHashMap<(TypeId, TypeId), BoxedCaster>> = Lazy::new(||
fn cast_arc_panic<Trait: ?Sized + 'static>(_: Arc<dyn Any + Sync + Send>) -> Arc<Trait>
{
- panic!("Prepend [sync] to the list of target traits for Sync + Send types")
+ panic!(
+ "Interface trait '{}' has not been marked async",
+ type_name::<Trait>()
+ )
}
/// A `Caster` knows how to cast a reference to or `Box` of a trait object for `Any`
diff --git a/src/libs/mod.rs b/src/libs/mod.rs
index 8d5583d..b1c7a74 100644
--- a/src/libs/mod.rs
+++ b/src/libs/mod.rs
@@ -1,3 +1,5 @@
pub mod intertrait;
+#[cfg(feature = "async")]
+pub extern crate async_trait;
pub extern crate linkme;
diff --git a/src/provider/async.rs b/src/provider/async.rs
new file mode 100644
index 0000000..93ae03a
--- /dev/null
+++ b/src/provider/async.rs
@@ -0,0 +1,135 @@
+#![allow(clippy::module_name_repetitions)]
+use std::marker::PhantomData;
+
+use async_trait::async_trait;
+
+use crate::async_di_container::AsyncDIContainer;
+use crate::errors::injectable::InjectableError;
+use crate::interfaces::async_injectable::AsyncInjectable;
+use crate::ptr::{ThreadsafeSingletonPtr, TransientPtr};
+
+#[derive(strum_macros::Display, Debug)]
+pub enum AsyncProvidable
+{
+ Transient(TransientPtr<dyn AsyncInjectable>),
+ Singleton(ThreadsafeSingletonPtr<dyn AsyncInjectable>),
+ #[cfg(feature = "factory")]
+ Factory(
+ crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+ ),
+}
+
+#[async_trait]
+pub trait IAsyncProvider: Send + Sync
+{
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>;
+}
+
+pub struct AsyncTransientTypeProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ injectable_phantom: PhantomData<InjectableType>,
+}
+
+impl<InjectableType> AsyncTransientTypeProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ pub fn new() -> Self
+ {
+ Self {
+ injectable_phantom: PhantomData,
+ }
+ }
+}
+
+#[async_trait]
+impl<InjectableType> IAsyncProvider for AsyncTransientTypeProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ async fn provide(
+ &self,
+ di_container: &AsyncDIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>
+ {
+ Ok(AsyncProvidable::Transient(
+ InjectableType::resolve(di_container, dependency_history).await?,
+ ))
+ }
+}
+
+pub struct AsyncSingletonProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ singleton: ThreadsafeSingletonPtr<InjectableType>,
+}
+
+impl<InjectableType> AsyncSingletonProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ pub fn new(singleton: ThreadsafeSingletonPtr<InjectableType>) -> Self
+ {
+ Self { singleton }
+ }
+}
+
+#[async_trait]
+impl<InjectableType> IAsyncProvider for AsyncSingletonProvider<InjectableType>
+where
+ InjectableType: AsyncInjectable,
+{
+ async fn provide(
+ &self,
+ _di_container: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>
+ {
+ Ok(AsyncProvidable::Singleton(self.singleton.clone()))
+ }
+}
+
+#[cfg(feature = "factory")]
+pub struct AsyncFactoryProvider
+{
+ factory: crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+}
+
+#[cfg(feature = "factory")]
+impl AsyncFactoryProvider
+{
+ pub fn new(
+ factory: crate::ptr::ThreadsafeFactoryPtr<
+ dyn crate::interfaces::any_factory::AnyThreadsafeFactory,
+ >,
+ ) -> Self
+ {
+ Self { factory }
+ }
+}
+
+#[cfg(feature = "factory")]
+#[async_trait]
+impl IAsyncProvider for AsyncFactoryProvider
+{
+ async fn provide(
+ &self,
+ _di_container: &AsyncDIContainer,
+ _dependency_history: Vec<&'static str>,
+ ) -> Result<AsyncProvidable, InjectableError>
+ {
+ Ok(AsyncProvidable::Factory(self.factory.clone()))
+ }
+}
diff --git a/src/provider.rs b/src/provider/blocking.rs
index 13674b9..13674b9 100644
--- a/src/provider.rs
+++ b/src/provider/blocking.rs
diff --git a/src/provider/mod.rs b/src/provider/mod.rs
new file mode 100644
index 0000000..7fb96bb
--- /dev/null
+++ b/src/provider/mod.rs
@@ -0,0 +1,4 @@
+pub mod blocking;
+
+#[cfg(feature = "async")]
+pub mod r#async;
diff --git a/src/ptr.rs b/src/ptr.rs
index 44fc15c..33f8a95 100644
--- a/src/ptr.rs
+++ b/src/ptr.rs
@@ -2,10 +2,11 @@
//! Smart pointer type aliases.
use std::rc::Rc;
+use std::sync::Arc;
use paste::paste;
-use crate::errors::ptr::SomePtrError;
+use crate::errors::ptr::{SomePtrError, SomeThreadsafePtrError};
/// A smart pointer for a interface in the transient scope.
pub type TransientPtr<Interface> = Box<Interface>;
@@ -13,44 +14,34 @@ pub type TransientPtr<Interface> = Box<Interface>;
/// A smart pointer to a interface in the singleton scope.
pub type SingletonPtr<Interface> = Rc<Interface>;
+/// A threadsafe smart pointer to a interface in the singleton scope.
+pub type ThreadsafeSingletonPtr<Interface> = Arc<Interface>;
+
/// A smart pointer to a factory.
#[cfg(feature = "factory")]
pub type FactoryPtr<FactoryInterface> = Rc<FactoryInterface>;
-/// Some smart pointer.
-#[derive(strum_macros::IntoStaticStr)]
-pub enum SomePtr<Interface>
-where
- Interface: 'static + ?Sized,
-{
- /// A smart pointer to a interface in the transient scope.
- Transient(TransientPtr<Interface>),
-
- /// A smart pointer to a interface in the singleton scope.
- Singleton(SingletonPtr<Interface>),
-
- /// A smart pointer to a factory.
- #[cfg(feature = "factory")]
- Factory(FactoryPtr<Interface>),
-}
+/// A threadsafe smart pointer to a factory.
+#[cfg(feature = "factory")]
+pub type ThreadsafeFactoryPtr<FactoryInterface> = Arc<FactoryInterface>;
macro_rules! create_as_variant_fn {
- ($variant: ident) => {
+ ($enum: ident, $variant: ident) => {
paste! {
#[doc =
- "Returns as " [<$variant:lower>] ".\n"
+ "Returns as the `" [<$variant>] "` variant.\n"
"\n"
"# Errors\n"
- "Will return Err if it's not a " [<$variant:lower>] "."
+ "Will return Err if it's not the `" [<$variant>] "` variant."
]
- pub fn [<$variant:lower>](self) -> Result<[<$variant Ptr>]<Interface>, SomePtrError>
+ pub fn [<$variant:snake>](self) -> Result<[<$variant Ptr>]<Interface>, [<$enum Error>]>
{
- if let SomePtr::$variant(ptr) = self {
+ if let $enum::$variant(ptr) = self {
return Ok(ptr);
}
- Err(SomePtrError::WrongPtrType {
+ Err([<$enum Error>]::WrongPtrType {
expected: stringify!($variant),
found: self.into()
})
@@ -59,14 +50,60 @@ macro_rules! create_as_variant_fn {
};
}
+/// Some smart pointer.
+#[derive(strum_macros::IntoStaticStr)]
+pub enum SomePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ /// A smart pointer to a interface in the transient scope.
+ Transient(TransientPtr<Interface>),
+
+ /// A smart pointer to a interface in the singleton scope.
+ Singleton(SingletonPtr<Interface>),
+
+ /// A smart pointer to a factory.
+ #[cfg(feature = "factory")]
+ Factory(FactoryPtr<Interface>),
+}
+
impl<Interface> SomePtr<Interface>
where
Interface: 'static + ?Sized,
{
- create_as_variant_fn!(Transient);
+ create_as_variant_fn!(SomePtr, Transient);
+
+ create_as_variant_fn!(SomePtr, Singleton);
+
+ #[cfg(feature = "factory")]
+ create_as_variant_fn!(SomePtr, Factory);
+}
+
+/// Some threadsafe smart pointer.
+#[derive(strum_macros::IntoStaticStr)]
+pub enum SomeThreadsafePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ /// A smart pointer to a interface in the transient scope.
+ Transient(TransientPtr<Interface>),
+
+ /// A smart pointer to a interface in the singleton scope.
+ ThreadsafeSingleton(ThreadsafeSingletonPtr<Interface>),
+
+ /// A smart pointer to a factory.
+ #[cfg(feature = "factory")]
+ ThreadsafeFactory(ThreadsafeFactoryPtr<Interface>),
+}
+
+impl<Interface> SomeThreadsafePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ create_as_variant_fn!(SomeThreadsafePtr, Transient);
- create_as_variant_fn!(Singleton);
+ create_as_variant_fn!(SomeThreadsafePtr, ThreadsafeSingleton);
#[cfg(feature = "factory")]
- create_as_variant_fn!(Factory);
+ create_as_variant_fn!(SomeThreadsafePtr, ThreadsafeFactory);
}