diff options
| -rw-r--r-- | macros/src/factory/build_declare_interfaces.rs | 12 | ||||
| -rw-r--r-- | src/di_container/asynchronous/mod.rs | 26 | ||||
| -rw-r--r-- | src/private/castable_factory/threadsafe.rs | 109 | ||||
| -rw-r--r-- | src/private/factory.rs | 9 | 
4 files changed, 81 insertions, 75 deletions
| diff --git a/macros/src/factory/build_declare_interfaces.rs b/macros/src/factory/build_declare_interfaces.rs index f256c7b..1e2d62e 100644 --- a/macros/src/factory/build_declare_interfaces.rs +++ b/macros/src/factory/build_declare_interfaces.rs @@ -12,19 +12,19 @@ pub fn build_declare_factory_interfaces(          quote! {              syrette::declare_interface!(                  syrette::private::castable_factory::threadsafe::ThreadsafeCastableFactory< -                    (std::sync::Arc<syrette::AsyncDIContainer>,), -                    #factory_interface +                    #factory_interface, +                    syrette::di_container::asynchronous::AsyncDIContainer,                  > -> syrette::private::factory::IThreadsafeFactory< -                    (std::sync::Arc<syrette::AsyncDIContainer>,), -                    #factory_interface +                    #factory_interface, +                    syrette::di_container::asynchronous::AsyncDIContainer,                  >,                  threadsafe_sharable = true              );              syrette::declare_interface!(                  syrette::private::castable_factory::threadsafe::ThreadsafeCastableFactory< -                    (std::sync::Arc<syrette::AsyncDIContainer>,), -                    #factory_interface +                    #factory_interface, +                    syrette::di_container::asynchronous::AsyncDIContainer,                  > -> syrette::private::any_factory::AnyThreadsafeFactory,                  threadsafe_sharable = true              ); diff --git a/src/di_container/asynchronous/mod.rs b/src/di_container/asynchronous/mod.rs index c2b4f6f..e651d81 100644 --- a/src/di_container/asynchronous/mod.rs +++ b/src/di_container/asynchronous/mod.rs @@ -279,7 +279,7 @@ impl AsyncDIContainer                  use crate::private::factory::IThreadsafeFactory;                  let factory = factory_binding -                    .cast::<dyn IThreadsafeFactory<(Arc<AsyncDIContainer>,), Interface>>() +                    .cast::<dyn IThreadsafeFactory<Interface, Self>>()                      .map_err(|err| match err {                          CastError::NotArcCastable(_) => {                              AsyncDIContainerError::InterfaceNotAsync( @@ -306,11 +306,13 @@ impl AsyncDIContainer                  use crate::private::factory::IThreadsafeFactory;                  use crate::ptr::TransientPtr; +                type DefaultFactoryFn<Interface> = dyn IThreadsafeFactory< +                    dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync, +                    AsyncDIContainer, +                >; +                  let default_factory = Self::cast_factory_binding::< -                    dyn IThreadsafeFactory< -                        (Arc<AsyncDIContainer>,), -                        dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync, -                    >, +                    DefaultFactoryFn<Interface>,                  >(binding, "default factory")?;                  Ok(SomePtr::Transient(default_factory(self.clone())())) @@ -321,13 +323,15 @@ impl AsyncDIContainer                  use crate::private::factory::IThreadsafeFactory;                  use crate::ptr::TransientPtr; +                type AsyncDefaultFactoryFn<Interface> = dyn IThreadsafeFactory< +                    dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>> +                        + Send +                        + Sync, +                    AsyncDIContainer, +                >; +                  let async_default_factory = Self::cast_factory_binding::< -                    dyn IThreadsafeFactory< -                        (Arc<AsyncDIContainer>,), -                        dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>> -                            + Send -                            + Sync, -                    >, +                    AsyncDefaultFactoryFn<Interface>,                  >(                      binding, "async default factory"                  )?; diff --git a/src/private/castable_factory/threadsafe.rs b/src/private/castable_factory/threadsafe.rs index 5b19844..cb8a04b 100644 --- a/src/private/castable_factory/threadsafe.rs +++ b/src/private/castable_factory/threadsafe.rs @@ -1,26 +1,29 @@  use std::any::type_name;  use std::fmt::Debug; -use std::marker::Tuple; +use std::sync::Arc;  use crate::private::any_factory::{AnyFactory, AnyThreadsafeFactory};  use crate::private::factory::IThreadsafeFactory;  use crate::ptr::TransientPtr; -pub struct ThreadsafeCastableFactory<Args, ReturnInterface> +pub struct ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  { -    func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + Send + Sync), +    func: &'static (dyn Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>> +                  + Send +                  + Sync),  } -impl<Args, ReturnInterface> ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> +    ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  {      pub fn new( -        func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> +        func: &'static (dyn Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>>                        + Send                        + Sync),      ) -> Self @@ -29,94 +32,83 @@ where      }  } -impl<Args, ReturnInterface> IThreadsafeFactory<Args, ReturnInterface> -    for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> IThreadsafeFactory<ReturnInterface, DIContainerT> +    for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  {  } -impl<Args, ReturnInterface> Fn<Args> for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> Fn<(Arc<DIContainerT>,)> +    for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  { -    extern "rust-call" fn call(&self, args: Args) -> Self::Output +    extern "rust-call" fn call(&self, args: (Arc<DIContainerT>,)) -> Self::Output      {          self.func.call(args)      }  } -impl<Args, ReturnInterface> FnMut<Args> -    for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> FnMut<(Arc<DIContainerT>,)> +    for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  { -    extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output +    extern "rust-call" fn call_mut(&mut self, args: (Arc<DIContainerT>,)) +        -> Self::Output      {          self.call(args)      }  } -impl<Args, ReturnInterface> FnOnce<Args> -    for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> FnOnce<(Arc<DIContainerT>,)> +    for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  {      type Output = TransientPtr<ReturnInterface>; -    extern "rust-call" fn call_once(self, args: Args) -> Self::Output +    extern "rust-call" fn call_once(self, args: (Arc<DIContainerT>,)) -> Self::Output      {          self.call(args)      }  } -impl<Args, ReturnInterface> AnyFactory -    for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> AnyFactory +    for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  {  } -impl<Args, ReturnInterface> AnyThreadsafeFactory -    for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> AnyThreadsafeFactory +    for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  {  } -impl<Args, ReturnInterface> Debug for ThreadsafeCastableFactory<Args, ReturnInterface> +impl<ReturnInterface, DIContainerT> Debug +    for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>  where -    Args: Tuple + 'static, +    DIContainerT: 'static,      ReturnInterface: 'static + ?Sized,  {      #[cfg(not(tarpaulin_include))]      fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result      { -        let mut args = type_name::<Args>(); - -        if args.len() < 2 { -            return Err(std::fmt::Error); -        } - -        args = args -            .get(1..args.len() - 1) -            .map_or_else(|| Err(std::fmt::Error), Ok)?; - -        if args.ends_with(',') { -            args = args -                .get(..args.len() - 1) -                .map_or_else(|| Err(std::fmt::Error), Ok)?; -        } -          let ret = type_name::<TransientPtr<ReturnInterface>>(); -        formatter.write_fmt(format_args!("ThreadsafeCastableFactory ({args}) -> {ret}",)) +        formatter.write_fmt(format_args!( +            "ThreadsafeCastableFactory (Arc<AsyncDIContainer>) -> {ret} {{ ... }}", +        ))      }  } @@ -124,6 +116,7 @@ where  mod tests  {      use super::*; +    use crate::di_container::asynchronous::MockAsyncDIContainer;      #[derive(Debug, PartialEq, Eq)]      struct Bacon @@ -134,11 +127,13 @@ mod tests      #[test]      fn can_call()      { -        let castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| { -            TransientPtr::new(Bacon { heal_amount }) +        let castable_factory = ThreadsafeCastableFactory::new(&|_| { +            TransientPtr::new(Bacon { heal_amount: 27 })          }); -        let output = castable_factory.call((27,)); +        let mock_di_container = Arc::new(MockAsyncDIContainer::new()); + +        let output = castable_factory.call((mock_di_container,));          assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 27 }));      } @@ -146,11 +141,13 @@ mod tests      #[test]      fn can_call_mut()      { -        let mut castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| { -            TransientPtr::new(Bacon { heal_amount }) +        let mut castable_factory = ThreadsafeCastableFactory::new(&|_| { +            TransientPtr::new(Bacon { heal_amount: 1092 })          }); -        let output = castable_factory.call_mut((1092,)); +        let mock_di_container = Arc::new(MockAsyncDIContainer::new()); + +        let output = castable_factory.call_mut((mock_di_container,));          assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 1092 }));      } @@ -158,11 +155,13 @@ mod tests      #[test]      fn can_call_once()      { -        let castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| { -            TransientPtr::new(Bacon { heal_amount }) +        let castable_factory = ThreadsafeCastableFactory::new(&|_| { +            TransientPtr::new(Bacon { heal_amount: 547 })          }); -        let output = castable_factory.call_once((547,)); +        let mock_di_container = Arc::new(MockAsyncDIContainer::new()); + +        let output = castable_factory.call_once((mock_di_container,));          assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 547 }));      } diff --git a/src/private/factory.rs b/src/private/factory.rs index 94e1023..730338f 100644 --- a/src/private/factory.rs +++ b/src/private/factory.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "async")] +use std::sync::Arc; +  use crate::private::cast::CastFrom;  use crate::ptr::TransientPtr; @@ -11,10 +14,10 @@ where  /// Interface for a threadsafe factory.  #[cfg(feature = "async")] -pub trait IThreadsafeFactory<Args, ReturnInterface>: -    Fn<Args, Output = TransientPtr<ReturnInterface>> + crate::private::cast::CastFromArc +pub trait IThreadsafeFactory<ReturnInterface, DIContainerT>: +    Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>> +    + crate::private::cast::CastFromArc  where -    Args: std::marker::Tuple,      ReturnInterface: 'static + ?Sized,  {  } | 
