aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-09-23 22:19:08 +0200
committerHampusM <hampus@hampusmat.com>2022-09-23 22:19:08 +0200
commit3ed020425bfd1fc5fedfa89a7ce20207bedcf5bc (patch)
tree7ac971df9005f82445a4d01e4c5dec2c938a63d9
parent145a257775f2397ceba0941ac2a2642cf3382dcb (diff)
fix: prevent problems caused by non send + sync traits
-rw-r--r--macros/src/fn_trait.rs20
-rw-r--r--macros/src/lib.rs7
-rw-r--r--src/async_di_container.rs64
3 files changed, 62 insertions, 29 deletions
diff --git a/macros/src/fn_trait.rs b/macros/src/fn_trait.rs
index f9b3514..9820f02 100644
--- a/macros/src/fn_trait.rs
+++ b/macros/src/fn_trait.rs
@@ -2,7 +2,7 @@ use quote::ToTokens;
use syn::parse::Parse;
use syn::punctuated::Punctuated;
use syn::token::Paren;
-use syn::{parenthesized, Ident, Token, Type};
+use syn::{parenthesized, parse_str, Ident, Token, TraitBound, Type};
/// A function trait. `dyn Fn(u32) -> String`
#[derive(Debug, Clone)]
@@ -14,6 +14,15 @@ pub struct FnTrait
pub inputs: Punctuated<Type, Token![,]>,
pub r_arrow_token: Token![->],
pub output: Type,
+ pub trait_bounds: Punctuated<TraitBound, Token![+]>,
+}
+
+impl FnTrait
+{
+ pub fn add_trait_bound(&mut self, trait_bound: TraitBound)
+ {
+ self.trait_bounds.push(trait_bound);
+ }
}
impl Parse for FnTrait
@@ -45,6 +54,7 @@ impl Parse for FnTrait
inputs,
r_arrow_token,
output,
+ trait_bounds: Punctuated::new(),
})
}
}
@@ -64,5 +74,13 @@ impl ToTokens for FnTrait
self.r_arrow_token.to_tokens(tokens);
self.output.to_tokens(tokens);
+
+ if !self.trait_bounds.is_empty() {
+ let plus: Token![+] = parse_str("+").unwrap();
+
+ plus.to_tokens(tokens);
+
+ self.trait_bounds.to_tokens(tokens);
+ }
}
}
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 2cd57f0..b0ccc86 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -6,7 +6,7 @@
use proc_macro::TokenStream;
use quote::quote;
-use syn::{parse, parse_macro_input};
+use syn::{parse, parse_macro_input, parse_str};
mod decl_def_factory_args;
mod declare_interface_args;
@@ -236,6 +236,11 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke
)
.unwrap();
+ if is_threadsafe {
+ factory_interface.add_trait_bound(parse_str("Send").unwrap());
+ factory_interface.add_trait_bound(parse_str("Sync").unwrap());
+ }
+
type_alias.ty = Box::new(Type::Verbatim(factory_interface.to_token_stream()));
let decl_interfaces = if is_threadsafe {
diff --git a/src/async_di_container.rs b/src/async_di_container.rs
index 7e01c66..c67900e 100644
--- a/src/async_di_container.rs
+++ b/src/async_di_container.rs
@@ -7,7 +7,7 @@
//!
//! use syrette::{injectable, AsyncDIContainer};
//!
-//! trait IDatabaseService
+//! trait IDatabaseService: Send + Sync
//! {
//! fn get_all_records(&self, table_name: String) -> HashMap<String, String>;
//! }
@@ -83,7 +83,7 @@ use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr};
/// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
pub struct AsyncBindingWhenConfigurator<Interface>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
di_container: Arc<AsyncDIContainer>,
interface_phantom: PhantomData<Interface>,
@@ -91,7 +91,7 @@ where
impl<Interface> AsyncBindingWhenConfigurator<Interface>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
fn new(di_container: Arc<AsyncDIContainer>) -> Self
{
@@ -130,7 +130,7 @@ where
/// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`].
pub struct AsyncBindingScopeConfigurator<Interface, Implementation>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
Implementation: AsyncInjectable,
{
di_container: Arc<AsyncDIContainer>,
@@ -140,7 +140,7 @@ where
impl<Interface, Implementation> AsyncBindingScopeConfigurator<Interface, Implementation>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
Implementation: AsyncInjectable,
{
fn new(di_container: Arc<AsyncDIContainer>) -> Self
@@ -196,7 +196,7 @@ where
/// Binding builder for type `Interface` inside a [`AsyncDIContainer`].
pub struct AsyncBindingBuilder<Interface>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
di_container: Arc<AsyncDIContainer>,
interface_phantom: PhantomData<Interface>,
@@ -204,7 +204,7 @@ where
impl<Interface> AsyncBindingBuilder<Interface>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
fn new(di_container: Arc<AsyncDIContainer>) -> Self
{
@@ -267,9 +267,11 @@ where
where
Args: 'static,
Return: 'static + ?Sized,
- Interface: Fn<Args, Output = Return>,
- FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = Box<(dyn Fn<Args, Output = Return>)>>
- + Send
+ Interface: Fn<Args, Output = Return> + Send + Sync,
+ FactoryFunc: Fn<
+ (Arc<AsyncDIContainer>,),
+ Output = Box<(dyn Fn<Args, Output = Return> + Send + Sync)>,
+ > + Send
+ Sync,
{
let mut bindings_lock = self.di_container.bindings.lock().await;
@@ -315,7 +317,9 @@ where
FactoryFunc: Fn<
(Arc<AsyncDIContainer>,),
Output = Box<
- (dyn Fn<Args, Output = crate::future::BoxFuture<'static, Return>>),
+ (dyn Fn<Args, Output = crate::future::BoxFuture<'static, Return>>
+ + Send
+ + Sync),
>,
> + Send
+ Sync,
@@ -405,7 +409,7 @@ impl AsyncDIContainer
#[must_use]
pub fn bind<Interface>(self: &mut Arc<Self>) -> AsyncBindingBuilder<Interface>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
AsyncBindingBuilder::<Interface>::new(self.clone())
}
@@ -421,7 +425,7 @@ impl AsyncDIContainer
self: &Arc<Self>,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
self.get_bound::<Interface>(Vec::new(), None).await
}
@@ -438,7 +442,7 @@ impl AsyncDIContainer
name: &'static str,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
self.get_bound::<Interface>(Vec::new(), Some(name)).await
}
@@ -450,7 +454,7 @@ impl AsyncDIContainer
name: Option<&'static str>,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
let binding_providable = self
.get_binding_providable::<Interface>(name, dependency_history)
@@ -464,7 +468,7 @@ impl AsyncDIContainer
binding_providable: AsyncProvidable,
) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
match binding_providable {
AsyncProvidable::Transient(transient_binding) => {
@@ -553,7 +557,7 @@ impl AsyncDIContainer
dependency_history: Vec<&'static str>,
) -> Result<AsyncProvidable, AsyncDIContainerError>
where
- Interface: 'static + ?Sized,
+ Interface: 'static + ?Sized + Send + Sync,
{
let provider;
@@ -610,7 +614,7 @@ mod tests
use crate::interfaces::async_injectable::AsyncInjectable;
use crate::ptr::TransientPtr;
- pub trait IUserManager
+ pub trait IUserManager: Send + Sync
{
fn add_user(&self, user_id: i128);
@@ -658,7 +662,7 @@ mod tests
}
}
- pub trait INumber
+ pub trait INumber: Send + Sync
{
fn get(&self) -> i32;
@@ -845,8 +849,11 @@ mod tests
#[cfg(feature = "factory")]
async fn can_bind_to_factory() -> Result<(), Box<dyn Error>>
{
- type IUserManagerFactory =
- dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
+ use crate as syrette;
+ use crate::factory;
+
+ #[factory(threadsafe = true)]
+ type IUserManagerFactory = dyn Fn() -> dyn subjects::IUserManager;
let mut di_container = AsyncDIContainer::new();
@@ -877,8 +884,11 @@ mod tests
#[cfg(feature = "factory")]
async fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>>
{
- type IUserManagerFactory =
- dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
+ use crate as syrette;
+ use crate::factory;
+
+ #[factory(threadsafe = true)]
+ type IUserManagerFactory = dyn Fn() -> dyn subjects::IUserManager;
let mut di_container = AsyncDIContainer::new();
@@ -1144,7 +1154,7 @@ mod tests
#[cfg(feature = "factory")]
async fn can_get_factory() -> Result<(), Box<dyn Error>>
{
- trait IUserManager
+ trait IUserManager: Send + Sync
{
fn add_user(&mut self, user_id: i128);
@@ -1207,7 +1217,7 @@ mod tests
mock_provider.expect_do_clone().returning(|| {
type FactoryFunc = Box<
- (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>)
+ (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>> + Send + Sync)
>;
let mut inner_mock_provider = MockProvider::new();
@@ -1254,7 +1264,7 @@ mod tests
#[cfg(feature = "factory")]
async fn can_get_factory_named() -> Result<(), Box<dyn Error>>
{
- trait IUserManager
+ trait IUserManager: Send + Sync
{
fn add_user(&mut self, user_id: i128);
@@ -1317,7 +1327,7 @@ mod tests
mock_provider.expect_do_clone().returning(|| {
type FactoryFunc = Box<
- (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>)
+ (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>> + Send + Sync)
>;
let mut inner_mock_provider = MockProvider::new();