aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2023-09-18 20:35:55 +0200
committerHampusM <hampus@hampusmat.com>2023-09-18 20:35:55 +0200
commit6d729a4d20944b990c341149729a810a2898cdff (patch)
treef64218f129b5f7c168e64ede3b99fddb7faca8ac
parente4fdf58b42c61482741cb12e1faa24cbd50698e8 (diff)
refactor: make threadsafe castable factory take DI container param
-rw-r--r--macros/src/factory/build_declare_interfaces.rs12
-rw-r--r--src/di_container/asynchronous/mod.rs26
-rw-r--r--src/private/castable_factory/threadsafe.rs109
-rw-r--r--src/private/factory.rs9
4 files changed, 81 insertions, 75 deletions
diff --git a/macros/src/factory/build_declare_interfaces.rs b/macros/src/factory/build_declare_interfaces.rs
index f256c7b..1e2d62e 100644
--- a/macros/src/factory/build_declare_interfaces.rs
+++ b/macros/src/factory/build_declare_interfaces.rs
@@ -12,19 +12,19 @@ pub fn build_declare_factory_interfaces(
quote! {
syrette::declare_interface!(
syrette::private::castable_factory::threadsafe::ThreadsafeCastableFactory<
- (std::sync::Arc<syrette::AsyncDIContainer>,),
- #factory_interface
+ #factory_interface,
+ syrette::di_container::asynchronous::AsyncDIContainer,
> -> syrette::private::factory::IThreadsafeFactory<
- (std::sync::Arc<syrette::AsyncDIContainer>,),
- #factory_interface
+ #factory_interface,
+ syrette::di_container::asynchronous::AsyncDIContainer,
>,
threadsafe_sharable = true
);
syrette::declare_interface!(
syrette::private::castable_factory::threadsafe::ThreadsafeCastableFactory<
- (std::sync::Arc<syrette::AsyncDIContainer>,),
- #factory_interface
+ #factory_interface,
+ syrette::di_container::asynchronous::AsyncDIContainer,
> -> syrette::private::any_factory::AnyThreadsafeFactory,
threadsafe_sharable = true
);
diff --git a/src/di_container/asynchronous/mod.rs b/src/di_container/asynchronous/mod.rs
index c2b4f6f..e651d81 100644
--- a/src/di_container/asynchronous/mod.rs
+++ b/src/di_container/asynchronous/mod.rs
@@ -279,7 +279,7 @@ impl AsyncDIContainer
use crate::private::factory::IThreadsafeFactory;
let factory = factory_binding
- .cast::<dyn IThreadsafeFactory<(Arc<AsyncDIContainer>,), Interface>>()
+ .cast::<dyn IThreadsafeFactory<Interface, Self>>()
.map_err(|err| match err {
CastError::NotArcCastable(_) => {
AsyncDIContainerError::InterfaceNotAsync(
@@ -306,11 +306,13 @@ impl AsyncDIContainer
use crate::private::factory::IThreadsafeFactory;
use crate::ptr::TransientPtr;
+ type DefaultFactoryFn<Interface> = dyn IThreadsafeFactory<
+ dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync,
+ AsyncDIContainer,
+ >;
+
let default_factory = Self::cast_factory_binding::<
- dyn IThreadsafeFactory<
- (Arc<AsyncDIContainer>,),
- dyn Fn<(), Output = TransientPtr<Interface>> + Send + Sync,
- >,
+ DefaultFactoryFn<Interface>,
>(binding, "default factory")?;
Ok(SomePtr::Transient(default_factory(self.clone())()))
@@ -321,13 +323,15 @@ impl AsyncDIContainer
use crate::private::factory::IThreadsafeFactory;
use crate::ptr::TransientPtr;
+ type AsyncDefaultFactoryFn<Interface> = dyn IThreadsafeFactory<
+ dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>>
+ + Send
+ + Sync,
+ AsyncDIContainer,
+ >;
+
let async_default_factory = Self::cast_factory_binding::<
- dyn IThreadsafeFactory<
- (Arc<AsyncDIContainer>,),
- dyn Fn<(), Output = BoxFuture<'static, TransientPtr<Interface>>>
- + Send
- + Sync,
- >,
+ AsyncDefaultFactoryFn<Interface>,
>(
binding, "async default factory"
)?;
diff --git a/src/private/castable_factory/threadsafe.rs b/src/private/castable_factory/threadsafe.rs
index 5b19844..cb8a04b 100644
--- a/src/private/castable_factory/threadsafe.rs
+++ b/src/private/castable_factory/threadsafe.rs
@@ -1,26 +1,29 @@
use std::any::type_name;
use std::fmt::Debug;
-use std::marker::Tuple;
+use std::sync::Arc;
use crate::private::any_factory::{AnyFactory, AnyThreadsafeFactory};
use crate::private::factory::IThreadsafeFactory;
use crate::ptr::TransientPtr;
-pub struct ThreadsafeCastableFactory<Args, ReturnInterface>
+pub struct ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
- func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>> + Send + Sync),
+ func: &'static (dyn Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>>
+ + Send
+ + Sync),
}
-impl<Args, ReturnInterface> ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT>
+ ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
pub fn new(
- func: &'static (dyn Fn<Args, Output = TransientPtr<ReturnInterface>>
+ func: &'static (dyn Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>>
+ Send
+ Sync),
) -> Self
@@ -29,94 +32,83 @@ where
}
}
-impl<Args, ReturnInterface> IThreadsafeFactory<Args, ReturnInterface>
- for ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT> IThreadsafeFactory<ReturnInterface, DIContainerT>
+ for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
}
-impl<Args, ReturnInterface> Fn<Args> for ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT> Fn<(Arc<DIContainerT>,)>
+ for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
- extern "rust-call" fn call(&self, args: Args) -> Self::Output
+ extern "rust-call" fn call(&self, args: (Arc<DIContainerT>,)) -> Self::Output
{
self.func.call(args)
}
}
-impl<Args, ReturnInterface> FnMut<Args>
- for ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT> FnMut<(Arc<DIContainerT>,)>
+ for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
- extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output
+ extern "rust-call" fn call_mut(&mut self, args: (Arc<DIContainerT>,))
+ -> Self::Output
{
self.call(args)
}
}
-impl<Args, ReturnInterface> FnOnce<Args>
- for ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT> FnOnce<(Arc<DIContainerT>,)>
+ for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
type Output = TransientPtr<ReturnInterface>;
- extern "rust-call" fn call_once(self, args: Args) -> Self::Output
+ extern "rust-call" fn call_once(self, args: (Arc<DIContainerT>,)) -> Self::Output
{
self.call(args)
}
}
-impl<Args, ReturnInterface> AnyFactory
- for ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT> AnyFactory
+ for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
}
-impl<Args, ReturnInterface> AnyThreadsafeFactory
- for ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT> AnyThreadsafeFactory
+ for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
}
-impl<Args, ReturnInterface> Debug for ThreadsafeCastableFactory<Args, ReturnInterface>
+impl<ReturnInterface, DIContainerT> Debug
+ for ThreadsafeCastableFactory<ReturnInterface, DIContainerT>
where
- Args: Tuple + 'static,
+ DIContainerT: 'static,
ReturnInterface: 'static + ?Sized,
{
#[cfg(not(tarpaulin_include))]
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
{
- let mut args = type_name::<Args>();
-
- if args.len() < 2 {
- return Err(std::fmt::Error);
- }
-
- args = args
- .get(1..args.len() - 1)
- .map_or_else(|| Err(std::fmt::Error), Ok)?;
-
- if args.ends_with(',') {
- args = args
- .get(..args.len() - 1)
- .map_or_else(|| Err(std::fmt::Error), Ok)?;
- }
-
let ret = type_name::<TransientPtr<ReturnInterface>>();
- formatter.write_fmt(format_args!("ThreadsafeCastableFactory ({args}) -> {ret}",))
+ formatter.write_fmt(format_args!(
+ "ThreadsafeCastableFactory (Arc<AsyncDIContainer>) -> {ret} {{ ... }}",
+ ))
}
}
@@ -124,6 +116,7 @@ where
mod tests
{
use super::*;
+ use crate::di_container::asynchronous::MockAsyncDIContainer;
#[derive(Debug, PartialEq, Eq)]
struct Bacon
@@ -134,11 +127,13 @@ mod tests
#[test]
fn can_call()
{
- let castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| {
- TransientPtr::new(Bacon { heal_amount })
+ let castable_factory = ThreadsafeCastableFactory::new(&|_| {
+ TransientPtr::new(Bacon { heal_amount: 27 })
});
- let output = castable_factory.call((27,));
+ let mock_di_container = Arc::new(MockAsyncDIContainer::new());
+
+ let output = castable_factory.call((mock_di_container,));
assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 27 }));
}
@@ -146,11 +141,13 @@ mod tests
#[test]
fn can_call_mut()
{
- let mut castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| {
- TransientPtr::new(Bacon { heal_amount })
+ let mut castable_factory = ThreadsafeCastableFactory::new(&|_| {
+ TransientPtr::new(Bacon { heal_amount: 1092 })
});
- let output = castable_factory.call_mut((1092,));
+ let mock_di_container = Arc::new(MockAsyncDIContainer::new());
+
+ let output = castable_factory.call_mut((mock_di_container,));
assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 1092 }));
}
@@ -158,11 +155,13 @@ mod tests
#[test]
fn can_call_once()
{
- let castable_factory = ThreadsafeCastableFactory::new(&|heal_amount| {
- TransientPtr::new(Bacon { heal_amount })
+ let castable_factory = ThreadsafeCastableFactory::new(&|_| {
+ TransientPtr::new(Bacon { heal_amount: 547 })
});
- let output = castable_factory.call_once((547,));
+ let mock_di_container = Arc::new(MockAsyncDIContainer::new());
+
+ let output = castable_factory.call_once((mock_di_container,));
assert_eq!(output, TransientPtr::new(Bacon { heal_amount: 547 }));
}
diff --git a/src/private/factory.rs b/src/private/factory.rs
index 94e1023..730338f 100644
--- a/src/private/factory.rs
+++ b/src/private/factory.rs
@@ -1,3 +1,6 @@
+#[cfg(feature = "async")]
+use std::sync::Arc;
+
use crate::private::cast::CastFrom;
use crate::ptr::TransientPtr;
@@ -11,10 +14,10 @@ where
/// Interface for a threadsafe factory.
#[cfg(feature = "async")]
-pub trait IThreadsafeFactory<Args, ReturnInterface>:
- Fn<Args, Output = TransientPtr<ReturnInterface>> + crate::private::cast::CastFromArc
+pub trait IThreadsafeFactory<ReturnInterface, DIContainerT>:
+ Fn<(Arc<DIContainerT>,), Output = TransientPtr<ReturnInterface>>
+ + crate::private::cast::CastFromArc
where
- Args: std::marker::Tuple,
ReturnInterface: 'static + ?Sized,
{
}