diff options
author | HampusM <hampus@hampusmat.com> | 2022-09-23 22:19:08 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-09-23 22:19:08 +0200 |
commit | 3ed020425bfd1fc5fedfa89a7ce20207bedcf5bc (patch) | |
tree | 7ac971df9005f82445a4d01e4c5dec2c938a63d9 | |
parent | 145a257775f2397ceba0941ac2a2642cf3382dcb (diff) |
fix: prevent problems caused by non send + sync traits
-rw-r--r-- | macros/src/fn_trait.rs | 20 | ||||
-rw-r--r-- | macros/src/lib.rs | 7 | ||||
-rw-r--r-- | src/async_di_container.rs | 64 |
3 files changed, 62 insertions, 29 deletions
diff --git a/macros/src/fn_trait.rs b/macros/src/fn_trait.rs index f9b3514..9820f02 100644 --- a/macros/src/fn_trait.rs +++ b/macros/src/fn_trait.rs @@ -2,7 +2,7 @@ use quote::ToTokens; use syn::parse::Parse; use syn::punctuated::Punctuated; use syn::token::Paren; -use syn::{parenthesized, Ident, Token, Type}; +use syn::{parenthesized, parse_str, Ident, Token, TraitBound, Type}; /// A function trait. `dyn Fn(u32) -> String` #[derive(Debug, Clone)] @@ -14,6 +14,15 @@ pub struct FnTrait pub inputs: Punctuated<Type, Token![,]>, pub r_arrow_token: Token![->], pub output: Type, + pub trait_bounds: Punctuated<TraitBound, Token![+]>, +} + +impl FnTrait +{ + pub fn add_trait_bound(&mut self, trait_bound: TraitBound) + { + self.trait_bounds.push(trait_bound); + } } impl Parse for FnTrait @@ -45,6 +54,7 @@ impl Parse for FnTrait inputs, r_arrow_token, output, + trait_bounds: Punctuated::new(), }) } } @@ -64,5 +74,13 @@ impl ToTokens for FnTrait self.r_arrow_token.to_tokens(tokens); self.output.to_tokens(tokens); + + if !self.trait_bounds.is_empty() { + let plus: Token![+] = parse_str("+").unwrap(); + + plus.to_tokens(tokens); + + self.trait_bounds.to_tokens(tokens); + } } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2cd57f0..b0ccc86 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -6,7 +6,7 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse, parse_macro_input}; +use syn::{parse, parse_macro_input, parse_str}; mod decl_def_factory_args; mod declare_interface_args; @@ -236,6 +236,11 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke ) .unwrap(); + if is_threadsafe { + factory_interface.add_trait_bound(parse_str("Send").unwrap()); + factory_interface.add_trait_bound(parse_str("Sync").unwrap()); + } + type_alias.ty = Box::new(Type::Verbatim(factory_interface.to_token_stream())); let decl_interfaces = if is_threadsafe { diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 7e01c66..c67900e 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -7,7 +7,7 @@ //! //! use syrette::{injectable, AsyncDIContainer}; //! -//! trait IDatabaseService +//! trait IDatabaseService: Send + Sync //! { //! fn get_all_records(&self, table_name: String) -> HashMap<String, String>; //! } @@ -83,7 +83,7 @@ use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr}; /// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. pub struct AsyncBindingWhenConfigurator<Interface> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { di_container: Arc<AsyncDIContainer>, interface_phantom: PhantomData<Interface>, @@ -91,7 +91,7 @@ where impl<Interface> AsyncBindingWhenConfigurator<Interface> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { fn new(di_container: Arc<AsyncDIContainer>) -> Self { @@ -130,7 +130,7 @@ where /// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. pub struct AsyncBindingScopeConfigurator<Interface, Implementation> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, Implementation: AsyncInjectable, { di_container: Arc<AsyncDIContainer>, @@ -140,7 +140,7 @@ where impl<Interface, Implementation> AsyncBindingScopeConfigurator<Interface, Implementation> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, Implementation: AsyncInjectable, { fn new(di_container: Arc<AsyncDIContainer>) -> Self @@ -196,7 +196,7 @@ where /// Binding builder for type `Interface` inside a [`AsyncDIContainer`]. pub struct AsyncBindingBuilder<Interface> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { di_container: Arc<AsyncDIContainer>, interface_phantom: PhantomData<Interface>, @@ -204,7 +204,7 @@ where impl<Interface> AsyncBindingBuilder<Interface> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { fn new(di_container: Arc<AsyncDIContainer>) -> Self { @@ -267,9 +267,11 @@ where where Args: 'static, Return: 'static + ?Sized, - Interface: Fn<Args, Output = Return>, - FactoryFunc: Fn<(Arc<AsyncDIContainer>,), Output = Box<(dyn Fn<Args, Output = Return>)>> - + Send + Interface: Fn<Args, Output = Return> + Send + Sync, + FactoryFunc: Fn< + (Arc<AsyncDIContainer>,), + Output = Box<(dyn Fn<Args, Output = Return> + Send + Sync)>, + > + Send + Sync, { let mut bindings_lock = self.di_container.bindings.lock().await; @@ -315,7 +317,9 @@ where FactoryFunc: Fn< (Arc<AsyncDIContainer>,), Output = Box< - (dyn Fn<Args, Output = crate::future::BoxFuture<'static, Return>>), + (dyn Fn<Args, Output = crate::future::BoxFuture<'static, Return>> + + Send + + Sync), >, > + Send + Sync, @@ -405,7 +409,7 @@ impl AsyncDIContainer #[must_use] pub fn bind<Interface>(self: &mut Arc<Self>) -> AsyncBindingBuilder<Interface> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { AsyncBindingBuilder::<Interface>::new(self.clone()) } @@ -421,7 +425,7 @@ impl AsyncDIContainer self: &Arc<Self>, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { self.get_bound::<Interface>(Vec::new(), None).await } @@ -438,7 +442,7 @@ impl AsyncDIContainer name: &'static str, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { self.get_bound::<Interface>(Vec::new(), Some(name)).await } @@ -450,7 +454,7 @@ impl AsyncDIContainer name: Option<&'static str>, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { let binding_providable = self .get_binding_providable::<Interface>(name, dependency_history) @@ -464,7 +468,7 @@ impl AsyncDIContainer binding_providable: AsyncProvidable, ) -> Result<SomeThreadsafePtr<Interface>, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { match binding_providable { AsyncProvidable::Transient(transient_binding) => { @@ -553,7 +557,7 @@ impl AsyncDIContainer dependency_history: Vec<&'static str>, ) -> Result<AsyncProvidable, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { let provider; @@ -610,7 +614,7 @@ mod tests use crate::interfaces::async_injectable::AsyncInjectable; use crate::ptr::TransientPtr; - pub trait IUserManager + pub trait IUserManager: Send + Sync { fn add_user(&self, user_id: i128); @@ -658,7 +662,7 @@ mod tests } } - pub trait INumber + pub trait INumber: Send + Sync { fn get(&self) -> i32; @@ -845,8 +849,11 @@ mod tests #[cfg(feature = "factory")] async fn can_bind_to_factory() -> Result<(), Box<dyn Error>> { - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + use crate as syrette; + use crate::factory; + + #[factory(threadsafe = true)] + type IUserManagerFactory = dyn Fn() -> dyn subjects::IUserManager; let mut di_container = AsyncDIContainer::new(); @@ -877,8 +884,11 @@ mod tests #[cfg(feature = "factory")] async fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>> { - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + use crate as syrette; + use crate::factory; + + #[factory(threadsafe = true)] + type IUserManagerFactory = dyn Fn() -> dyn subjects::IUserManager; let mut di_container = AsyncDIContainer::new(); @@ -1144,7 +1154,7 @@ mod tests #[cfg(feature = "factory")] async fn can_get_factory() -> Result<(), Box<dyn Error>> { - trait IUserManager + trait IUserManager: Send + Sync { fn add_user(&mut self, user_id: i128); @@ -1207,7 +1217,7 @@ mod tests mock_provider.expect_do_clone().returning(|| { type FactoryFunc = Box< - (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>) + (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>> + Send + Sync) >; let mut inner_mock_provider = MockProvider::new(); @@ -1254,7 +1264,7 @@ mod tests #[cfg(feature = "factory")] async fn can_get_factory_named() -> Result<(), Box<dyn Error>> { - trait IUserManager + trait IUserManager: Send + Sync { fn add_user(&mut self, user_id: i128); @@ -1317,7 +1327,7 @@ mod tests mock_provider.expect_do_clone().returning(|| { type FactoryFunc = Box< - (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>>) + (dyn Fn<(Vec<i128>,), Output = TransientPtr<dyn IUserManager>> + Send + Sync) >; let mut inner_mock_provider = MockProvider::new(); |