From 3ed020425bfd1fc5fedfa89a7ce20207bedcf5bc Mon Sep 17 00:00:00 2001 From: HampusM Date: Fri, 23 Sep 2022 22:19:08 +0200 Subject: fix: prevent problems caused by non send + sync traits --- macros/src/fn_trait.rs | 20 ++++++++++++++- macros/src/lib.rs | 7 +++++- src/async_di_container.rs | 64 +++++++++++++++++++++++++++-------------------- 3 files changed, 62 insertions(+), 29 deletions(-) diff --git a/macros/src/fn_trait.rs b/macros/src/fn_trait.rs index f9b3514..9820f02 100644 --- a/macros/src/fn_trait.rs +++ b/macros/src/fn_trait.rs @@ -2,7 +2,7 @@ use quote::ToTokens; use syn::parse::Parse; use syn::punctuated::Punctuated; use syn::token::Paren; -use syn::{parenthesized, Ident, Token, Type}; +use syn::{parenthesized, parse_str, Ident, Token, TraitBound, Type}; /// A function trait. `dyn Fn(u32) -> String` #[derive(Debug, Clone)] @@ -14,6 +14,15 @@ pub struct FnTrait pub inputs: Punctuated, pub r_arrow_token: Token![->], pub output: Type, + pub trait_bounds: Punctuated, +} + +impl FnTrait +{ + pub fn add_trait_bound(&mut self, trait_bound: TraitBound) + { + self.trait_bounds.push(trait_bound); + } } impl Parse for FnTrait @@ -45,6 +54,7 @@ impl Parse for FnTrait inputs, r_arrow_token, output, + trait_bounds: Punctuated::new(), }) } } @@ -64,5 +74,13 @@ impl ToTokens for FnTrait self.r_arrow_token.to_tokens(tokens); self.output.to_tokens(tokens); + + if !self.trait_bounds.is_empty() { + let plus: Token![+] = parse_str("+").unwrap(); + + plus.to_tokens(tokens); + + self.trait_bounds.to_tokens(tokens); + } } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2cd57f0..b0ccc86 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -6,7 +6,7 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse, parse_macro_input}; +use syn::{parse, parse_macro_input, parse_str}; mod decl_def_factory_args; mod declare_interface_args; @@ -236,6 +236,11 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke ) .unwrap(); + if is_threadsafe { + factory_interface.add_trait_bound(parse_str("Send").unwrap()); + factory_interface.add_trait_bound(parse_str("Sync").unwrap()); + } + type_alias.ty = Box::new(Type::Verbatim(factory_interface.to_token_stream())); let decl_interfaces = if is_threadsafe { diff --git a/src/async_di_container.rs b/src/async_di_container.rs index 7e01c66..c67900e 100644 --- a/src/async_di_container.rs +++ b/src/async_di_container.rs @@ -7,7 +7,7 @@ //! //! use syrette::{injectable, AsyncDIContainer}; //! -//! trait IDatabaseService +//! trait IDatabaseService: Send + Sync //! { //! fn get_all_records(&self, table_name: String) -> HashMap; //! } @@ -83,7 +83,7 @@ use crate::ptr::{SomeThreadsafePtr, ThreadsafeSingletonPtr}; /// When configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. pub struct AsyncBindingWhenConfigurator where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { di_container: Arc, interface_phantom: PhantomData, @@ -91,7 +91,7 @@ where impl AsyncBindingWhenConfigurator where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { fn new(di_container: Arc) -> Self { @@ -130,7 +130,7 @@ where /// Scope configurator for a binding for type 'Interface' inside a [`AsyncDIContainer`]. pub struct AsyncBindingScopeConfigurator where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, Implementation: AsyncInjectable, { di_container: Arc, @@ -140,7 +140,7 @@ where impl AsyncBindingScopeConfigurator where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, Implementation: AsyncInjectable, { fn new(di_container: Arc) -> Self @@ -196,7 +196,7 @@ where /// Binding builder for type `Interface` inside a [`AsyncDIContainer`]. pub struct AsyncBindingBuilder where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { di_container: Arc, interface_phantom: PhantomData, @@ -204,7 +204,7 @@ where impl AsyncBindingBuilder where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { fn new(di_container: Arc) -> Self { @@ -267,9 +267,11 @@ where where Args: 'static, Return: 'static + ?Sized, - Interface: Fn, - FactoryFunc: Fn<(Arc,), Output = Box<(dyn Fn)>> - + Send + Interface: Fn + Send + Sync, + FactoryFunc: Fn< + (Arc,), + Output = Box<(dyn Fn + Send + Sync)>, + > + Send + Sync, { let mut bindings_lock = self.di_container.bindings.lock().await; @@ -315,7 +317,9 @@ where FactoryFunc: Fn< (Arc,), Output = Box< - (dyn Fn>), + (dyn Fn> + + Send + + Sync), >, > + Send + Sync, @@ -405,7 +409,7 @@ impl AsyncDIContainer #[must_use] pub fn bind(self: &mut Arc) -> AsyncBindingBuilder where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { AsyncBindingBuilder::::new(self.clone()) } @@ -421,7 +425,7 @@ impl AsyncDIContainer self: &Arc, ) -> Result, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { self.get_bound::(Vec::new(), None).await } @@ -438,7 +442,7 @@ impl AsyncDIContainer name: &'static str, ) -> Result, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { self.get_bound::(Vec::new(), Some(name)).await } @@ -450,7 +454,7 @@ impl AsyncDIContainer name: Option<&'static str>, ) -> Result, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { let binding_providable = self .get_binding_providable::(name, dependency_history) @@ -464,7 +468,7 @@ impl AsyncDIContainer binding_providable: AsyncProvidable, ) -> Result, AsyncDIContainerError> where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { match binding_providable { AsyncProvidable::Transient(transient_binding) => { @@ -553,7 +557,7 @@ impl AsyncDIContainer dependency_history: Vec<&'static str>, ) -> Result where - Interface: 'static + ?Sized, + Interface: 'static + ?Sized + Send + Sync, { let provider; @@ -610,7 +614,7 @@ mod tests use crate::interfaces::async_injectable::AsyncInjectable; use crate::ptr::TransientPtr; - pub trait IUserManager + pub trait IUserManager: Send + Sync { fn add_user(&self, user_id: i128); @@ -658,7 +662,7 @@ mod tests } } - pub trait INumber + pub trait INumber: Send + Sync { fn get(&self) -> i32; @@ -845,8 +849,11 @@ mod tests #[cfg(feature = "factory")] async fn can_bind_to_factory() -> Result<(), Box> { - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + use crate as syrette; + use crate::factory; + + #[factory(threadsafe = true)] + type IUserManagerFactory = dyn Fn() -> dyn subjects::IUserManager; let mut di_container = AsyncDIContainer::new(); @@ -877,8 +884,11 @@ mod tests #[cfg(feature = "factory")] async fn can_bind_to_factory_when_named() -> Result<(), Box> { - type IUserManagerFactory = - dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + use crate as syrette; + use crate::factory; + + #[factory(threadsafe = true)] + type IUserManagerFactory = dyn Fn() -> dyn subjects::IUserManager; let mut di_container = AsyncDIContainer::new(); @@ -1144,7 +1154,7 @@ mod tests #[cfg(feature = "factory")] async fn can_get_factory() -> Result<(), Box> { - trait IUserManager + trait IUserManager: Send + Sync { fn add_user(&mut self, user_id: i128); @@ -1207,7 +1217,7 @@ mod tests mock_provider.expect_do_clone().returning(|| { type FactoryFunc = Box< - (dyn Fn<(Vec,), Output = TransientPtr>) + (dyn Fn<(Vec,), Output = TransientPtr> + Send + Sync) >; let mut inner_mock_provider = MockProvider::new(); @@ -1254,7 +1264,7 @@ mod tests #[cfg(feature = "factory")] async fn can_get_factory_named() -> Result<(), Box> { - trait IUserManager + trait IUserManager: Send + Sync { fn add_user(&mut self, user_id: i128); @@ -1317,7 +1327,7 @@ mod tests mock_provider.expect_do_clone().returning(|| { type FactoryFunc = Box< - (dyn Fn<(Vec,), Output = TransientPtr>) + (dyn Fn<(Vec,), Output = TransientPtr> + Send + Sync) >; let mut inner_mock_provider = MockProvider::new(); -- cgit v1.2.3-18-g5258