diff options
author | HampusM <hampus@hampusmat.com> | 2023-09-18 20:35:55 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2023-09-18 20:35:55 +0200 |
commit | 6d729a4d20944b990c341149729a810a2898cdff (patch) | |
tree | f64218f129b5f7c168e64ede3b99fddb7faca8ac | |
parent | e4fdf58b42c61482741cb12e1faa24cbd50698e8 (diff) |
refactor: make threadsafe castable factory take DI container param
-rw-r--r-- | macros/src/factory/build_declare_interfaces.rs | 12 | ||||
-rw-r--r-- | src/di_container/asynchronous/mod.rs | 26 | ||||
-rw-r--r-- | src/private/castable_factory/threadsafe.rs | 109 | ||||
-rw-r--r-- | src/private/factory.rs | 9 |
4 files changed, 81 insertions, 75 deletions
diff --git a/macros/src/factory/build_declare_interfaces.rs b/macros/src/factory/build_declare_interfaces.rs index f256c7b..1e2d62e 100644 --- a/macros/src/factory/build_declare_interfaces.rs +++ b/macros/src/factory/build_declare_interfaces.rs @@ -12,19 +12,19 @@ pub fn build_declare_factory_interfaces( quote! { syrette::declare_interface!( syrette::private::castable_factory::threadsafe::ThreadsafeCastableFactory< - (std::sync::Arc<syrette::AsyncDIContainer>,), - #factory_interface + #factory_interface, + syrette::di_container::asynchronous::AsyncDIContainer, > -> syrette::private::factory::IThreadsafeFactory< - (std::sync::Arc<syrette::AsyncDIContainer>,), - #factory_interface + #factory_interface, + syrette::di_container::asynchronous::AsyncDIContainer, >, threadsafe_sharable = true ); syrette::declare_interface!( syrette::private::castable_factory::threadsafe::ThreadsafeCastableFactory< - (std::sync::Arc<syrette::AsyncDIContainer>,), - #factory_interface + #factory_interface, + syrette::di_container::asynchronous::AsyncDIContainer, > -> syrette::private::any_factory::AnyThreadsafeFactory, threadsafe_sharable = true ); diff --git a/src/di_container/asynchronous/mod.rs b/src/di_container/asynchronous/mod.rs index c2b4f6f..e651d81 100644 --- a/src/di_container/asynchronous/mod.rs +++ b/src/di_container/asynchronous/mod.rs @@ -279,7 +279,7 @@ impl AsyncDIContainer use crate::private::factory::IThreadsafeFactory; let factory = factory_binding - .cast::<dyn IThreadsafeFactory<(Arc<AsyncDIContainer>,), Interface>>() + .cast::<dyn IThreadsafeFactory<Interface, Self>>() .map_err(|err| match err { CastError::NotArcCastable(_) => { AsyncDIContainerError::InterfaceNotAsync( @@ -306,11 +306,13 @@ impl AsyncDIContainer use crate::private::factory::IThreadsafeFactory; use crate::ptr::TransientPtr; + type DefaultFactoryFn<Interface> = dyn IThreadsafeFactory< + dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync, + AsyncDIContainer, + >; + let default_factory = Self::cast_factory_binding::< - dyn IThreadsafeFactory< - (Arc<AsyncDIContainer>,), - dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync, - >, + DefaultFactoryFn<Interface>, >(binding, "default factory")?; Ok(SomePtr::Transient(default_factory(self.clone())())) @@ -321,13 +323,15 @@ impl AsyncDIContainer use crate::private::factory::IThreadsafeFactory; use crate::ptr::TransientPtr; + type AsyncDefaultFactoryFn<Interface> = dyn IThreadsafeFactory< + dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>> + + Send + + Sync, + AsyncDIContainer, + >; + let async_default_factory = Self::cast_factory_binding::< - dyn IThreadsafeFactory< - (Arc<AsyncDIContainer>,), - dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>> - + Send - + Sync, - >, + AsyncDefaultFactoryFn<Interface>, >( binding, "async default factory" )?; diff --git a/src/private/castable_factory/threadsafe.rs b/src/private/castable_factory/threadsafe.rs index 5b19844..cb8a04b 100644 --- a/src/private/castable_factory/threadsafe.rs +++ b/src/private/castable_factory/threadsafe.rs @@ -1,26 +1,29 @@ use std::any::type_name; use std::fmt::Debug; -use std::marker::Tuple; +use std::sync::Arc; use crate::private::any_factory::{AnyFactory, AnyThreadsafeFactory}; use crate::private::factory::IThreadsafeFactory; use crate::ptr::TransientPtr; -pub struct ThreadsafeCastableFactory<Args, ReturnInterface> +pub struct ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { - func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + Send + Sync), + func: &'static (dyn Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>> + + Send + + Sync), } -impl<Args, ReturnInterface> ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> + ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { pub fn new( - func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + func: &'static (dyn Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>> + Send + Sync), ) -> Self @@ -29,94 +32,83 @@ where } } -impl<Args, ReturnInterface> IThreadsafeFactory<Args, ReturnInterface> - for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> IThreadsafeFactory<ReturnInterface, DIContainerT> + for ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { } -impl<Args, ReturnInterface> Fn<Args> for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> Fn<(Arc<DIContainerT>,)> + for ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { - extern "rust-call" fn call(&self, args: Args) -> Self::Output + extern "rust-call" fn call(&self, args: (Arc<DIContainerT>,)) -> Self::Output { self.func.call(args) } } -impl<Args, ReturnInterface> FnMut<Args> - for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> FnMut<(Arc<DIContainerT>,)> + for ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { - extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output + extern "rust-call" fn call_mut(&mut self, args: (Arc<DIContainerT>,)) + -> Self::Output { self.call(args) } } -impl<Args, ReturnInterface> FnOnce<Args> - for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> FnOnce<(Arc<DIContainerT>,)> + for ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { type Output = TransientPtr<ReturnInterface>; - extern "rust-call" fn call_once(self, args: Args) -> Self::Output + extern "rust-call" fn call_once(self, args: (Arc<DIContainerT>,)) -> Self::Output { self.call(args) } } -impl<Args, ReturnInterface> AnyFactory - for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> AnyFactory + for ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { } -impl<Args, ReturnInterface> AnyThreadsafeFactory - for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> AnyThreadsafeFactory + for ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { } -impl<Args, ReturnInterface> Debug for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> Debug + for ThreadsafeCastableFactory<ReturnInterface, DIContainerT> where - Args: Tuple + 'static, + DIContainerT: 'static, ReturnInterface: 'static + ?Sized, { #[cfg(not(tarpaulin_include))] fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut args = type_name::<Args>(); - - if args.len() < 2 { - return Err(std::fmt::Error); - } - - args = args - .get(1..args.len() - 1) - .map_or_else(|| Err(std::fmt::Error), Ok)?; - - if args.ends_with(',') { - args = args - .get(..args.len() - 1) - .map_or_else(|| Err(std::fmt::Error), Ok)?; - } - let ret = type_name::<TransientPtr<ReturnInterface>>(); - formatter.write_fmt(format_args!("ThreadsafeCastableFactory ({args}) -> {ret}",)) + formatter.write_fmt(format_args!( + "ThreadsafeCastableFactory (Arc<AsyncDIContainer>) -> {ret} {{ ... }}", + )) } } @@ -124,6 +116,7 @@ where mod tests { use super::*; + use crate::di_container::asynchronous::MockAsyncDIContainer; #[derive(Debug, PartialEq, Eq)] struct Bacon @@ -134,11 +127,13 @@ mod tests #[test] fn can_call() { - let castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| { - TransientPtr::new(Bacon { heal_amount }) + let castable_factory = ThreadsafeCastableFactory::new(&|_| { + TransientPtr::new(Bacon { heal_amount: 27 }) }); - let output = castable_factory.call((27,)); + let mock_di_container = Arc::new(MockAsyncDIContainer::new()); + + let output = castable_factory.call((mock_di_container,)); assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 27 })); } @@ -146,11 +141,13 @@ mod tests #[test] fn can_call_mut() { - let mut castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| { - TransientPtr::new(Bacon { heal_amount }) + let mut castable_factory = ThreadsafeCastableFactory::new(&|_| { + TransientPtr::new(Bacon { heal_amount: 1092 }) }); - let output = castable_factory.call_mut((1092,)); + let mock_di_container = Arc::new(MockAsyncDIContainer::new()); + + let output = castable_factory.call_mut((mock_di_container,)); assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 1092 })); } @@ -158,11 +155,13 @@ mod tests #[test] fn can_call_once() { - let castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| { - TransientPtr::new(Bacon { heal_amount }) + let castable_factory = ThreadsafeCastableFactory::new(&|_| { + TransientPtr::new(Bacon { heal_amount: 547 }) }); - let output = castable_factory.call_once((547,)); + let mock_di_container = Arc::new(MockAsyncDIContainer::new()); + + let output = castable_factory.call_once((mock_di_container,)); assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 547 })); } diff --git a/src/private/factory.rs b/src/private/factory.rs index 94e1023..730338f 100644 --- a/src/private/factory.rs +++ b/src/private/factory.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "async")] +use std::sync::Arc; + use crate::private::cast::CastFrom; use crate::ptr::TransientPtr; @@ -11,10 +14,10 @@ where /// Interface for a threadsafe factory. #[cfg(feature = "async")] -pub trait IThreadsafeFactory<Args, ReturnInterface>: - Fn<Args, Output = TransientPtr<ReturnInterface>> + crate::private::cast::CastFromArc +pub trait IThreadsafeFactory<ReturnInterface, DIContainerT>: + Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>> + + crate::private::cast::CastFromArc where - Args: std::marker::Tuple, ReturnInterface: 'static + ?Sized, { } |