diff options
-rw-r--r-- | src/di_container/asynchronous.rs | 41 | ||||
-rw-r--r-- | src/di_container/asynchronous/binding/builder.rs | 40 | ||||
-rw-r--r-- | src/di_container/blocking.rs | 39 | ||||
-rw-r--r-- | src/di_container/blocking/binding/builder.rs | 18 | ||||
-rw-r--r-- | src/provider/async.rs | 122 | ||||
-rw-r--r-- | src/provider/blocking.rs | 77 |
6 files changed, 194 insertions, 143 deletions
diff --git a/src/di_container/asynchronous.rs b/src/di_container/asynchronous.rs index c993b8b..c6308e6 100644 --- a/src/di_container/asynchronous.rs +++ b/src/di_container/asynchronous.rs @@ -347,10 +347,13 @@ impl AsyncDIContainer )) } #[cfg(feature = "factory")] - AsyncProvidable::Factory(factory_binding) => { + AsyncProvidable::Function( + func_bound, + crate::provider::r#async::ProvidableFunctionKind::UserCalled, + ) => { use crate::castable_function::threadsafe::ThreadsafeCastableFunction; - let factory = factory_binding + let factory = func_bound .as_any() .downcast_ref::<ThreadsafeCastableFunction<Interface, Self>>() .ok_or_else(|| AsyncDIContainerError::CastFailed { @@ -361,7 +364,10 @@ impl AsyncDIContainer Ok(SomePtr::ThreadsafeFactory(factory.call(self).into())) } #[cfg(feature = "factory")] - AsyncProvidable::DefaultFactory(binding) => { + AsyncProvidable::Function( + func_bound, + crate::provider::r#async::ProvidableFunctionKind::Instant, + ) => { use crate::castable_function::threadsafe::ThreadsafeCastableFunction; use crate::ptr::TransientPtr; @@ -370,7 +376,7 @@ impl AsyncDIContainer AsyncDIContainer, >; - let default_factory = binding + let default_factory = func_bound .as_any() .downcast_ref::<DefaultFactoryFn<Interface>>() .ok_or_else(|| AsyncDIContainerError::CastFailed { @@ -381,7 +387,10 @@ impl AsyncDIContainer Ok(SomePtr::Transient(default_factory.call(self)())) } #[cfg(feature = "factory")] - AsyncProvidable::AsyncDefaultFactory(binding) => { + AsyncProvidable::Function( + func_bound, + crate::provider::r#async::ProvidableFunctionKind::AsyncInstant, + ) => { use crate::castable_function::threadsafe::ThreadsafeCastableFunction; use crate::future::BoxFuture; use crate::ptr::TransientPtr; @@ -393,7 +402,7 @@ impl AsyncDIContainer AsyncDIContainer, >; - let async_default_factory = binding + let async_default_factory = func_bound .as_any() .downcast_ref::<AsyncDefaultFactoryFn<Interface>>() .ok_or_else(|| AsyncDIContainerError::CastFailed { @@ -652,7 +661,10 @@ mod tests } } + use std::sync::Arc; + use crate::castable_function::threadsafe::ThreadsafeCastableFunction; + use crate::provider::r#async::ProvidableFunctionKind; type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager> + Send + Sync; @@ -672,10 +684,9 @@ mod tests }; inner_mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Factory( - crate::ptr::ThreadsafeFactoryPtr::new( - ThreadsafeCastableFunction::new(factory_func), - ), + Ok(AsyncProvidable::Function( + Arc::new(ThreadsafeCastableFunction::new(factory_func)), + ProvidableFunctionKind::UserCalled, )) }); @@ -734,7 +745,10 @@ mod tests } } + use std::sync::Arc; + use crate::castable_function::threadsafe::ThreadsafeCastableFunction; + use crate::provider::r#async::ProvidableFunctionKind; type IUserManagerFactory = dyn Fn(Vec<i128>) -> TransientPtr<dyn IUserManager> + Send + Sync; @@ -754,10 +768,9 @@ mod tests }; inner_mock_provider.expect_provide().returning(|_, _| { - Ok(AsyncProvidable::Factory( - crate::ptr::ThreadsafeFactoryPtr::new( - ThreadsafeCastableFunction::new(factory_func), - ), + Ok(AsyncProvidable::Function( + Arc::new(ThreadsafeCastableFunction::new(factory_func)), + ProvidableFunctionKind::UserCalled, )) }); diff --git a/src/di_container/asynchronous/binding/builder.rs b/src/di_container/asynchronous/binding/builder.rs index 8465c9a..833517b 100644 --- a/src/di_container/asynchronous/binding/builder.rs +++ b/src/di_container/asynchronous/binding/builder.rs @@ -173,8 +173,10 @@ where Interface: Fn<Args, Output = Return> + Send + Sync, FactoryFunc: Fn(&AsyncDIContainer) -> BoxFn<Args, Return> + Send + Sync, { + use std::sync::Arc; + use crate::castable_function::threadsafe::ThreadsafeCastableFunction; - use crate::provider::r#async::AsyncFactoryVariant; + use crate::provider::r#async::ProvidableFunctionKind; if self .di_container @@ -190,9 +192,9 @@ where self.di_container.set_binding::<Interface>( BindingOptions::new(), - Box::new(crate::provider::r#async::AsyncFactoryProvider::new( - crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), - AsyncFactoryVariant::Normal, + Box::new(crate::provider::r#async::AsyncFunctionProvider::new( + Arc::new(factory_impl), + ProvidableFunctionKind::UserCalled, )), ); @@ -270,8 +272,10 @@ where + Send + Sync, { + use std::sync::Arc; + use crate::castable_function::threadsafe::ThreadsafeCastableFunction; - use crate::provider::r#async::AsyncFactoryVariant; + use crate::provider::r#async::ProvidableFunctionKind; if self .di_container @@ -287,9 +291,9 @@ where self.di_container.set_binding::<Interface>( BindingOptions::new(), - Box::new(crate::provider::r#async::AsyncFactoryProvider::new( - crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), - AsyncFactoryVariant::Normal, + Box::new(crate::provider::r#async::AsyncFunctionProvider::new( + Arc::new(factory_impl), + ProvidableFunctionKind::UserCalled, )), ); @@ -354,8 +358,10 @@ where + Send + Sync, { + use std::sync::Arc; + use crate::castable_function::threadsafe::ThreadsafeCastableFunction; - use crate::provider::r#async::AsyncFactoryVariant; + use crate::provider::r#async::ProvidableFunctionKind; if self .di_container @@ -371,9 +377,9 @@ where self.di_container.set_binding::<Interface>( BindingOptions::new(), - Box::new(crate::provider::r#async::AsyncFactoryProvider::new( - crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), - AsyncFactoryVariant::Default, + Box::new(crate::provider::r#async::AsyncFunctionProvider::new( + Arc::new(factory_impl), + ProvidableFunctionKind::Instant, )), ); @@ -445,8 +451,10 @@ where + Send + Sync, { + use std::sync::Arc; + use crate::castable_function::threadsafe::ThreadsafeCastableFunction; - use crate::provider::r#async::AsyncFactoryVariant; + use crate::provider::r#async::ProvidableFunctionKind; if self .di_container @@ -462,9 +470,9 @@ where self.di_container.set_binding::<Interface>( BindingOptions::new(), - Box::new(crate::provider::r#async::AsyncFactoryProvider::new( - crate::ptr::ThreadsafeFactoryPtr::new(factory_impl), - AsyncFactoryVariant::AsyncDefault, + Box::new(crate::provider::r#async::AsyncFunctionProvider::new( + Arc::new(factory_impl), + ProvidableFunctionKind::AsyncInstant, )), ); diff --git a/src/di_container/blocking.rs b/src/di_container/blocking.rs index d8b0d59..fa3523b 100644 --- a/src/di_container/blocking.rs +++ b/src/di_container/blocking.rs @@ -284,10 +284,13 @@ impl DIContainer })?, )), #[cfg(feature = "factory")] - Providable::Factory(factory_binding) => { + Providable::Function( + func_bound, + crate::provider::blocking::ProvidableFunctionKind::UserCalled, + ) => { use crate::castable_function::CastableFunction; - let factory = factory_binding + let factory = func_bound .as_any() .downcast_ref::<CastableFunction<Interface, Self>>() .ok_or_else(|| DIContainerError::CastFailed { @@ -298,16 +301,19 @@ impl DIContainer Ok(SomePtr::Factory(factory.call(self).into())) } #[cfg(feature = "factory")] - Providable::DefaultFactory(factory_binding) => { + Providable::Function( + func_bound, + crate::provider::blocking::ProvidableFunctionKind::Instant, + ) => { use crate::castable_function::CastableFunction; use crate::ptr::TransientPtr; - type DefaultFactoryFn<Interface> = + type Func<Interface> = CastableFunction<dyn Fn() -> TransientPtr<Interface>, DIContainer>; - let default_factory = factory_binding + let default_factory = func_bound .as_any() - .downcast_ref::<DefaultFactoryFn<Interface>>() + .downcast_ref::<Func<Interface>>() .ok_or_else(|| DIContainerError::CastFailed { interface: type_name::<Interface>(), binding_kind: "default factory", @@ -517,7 +523,10 @@ mod tests #[cfg(feature = "factory")] fn can_get_factory() { + use std::rc::Rc; + use crate::castable_function::CastableFunction; + use crate::provider::blocking::ProvidableFunctionKind; use crate::ptr::FactoryPtr; trait IUserManager @@ -572,9 +581,10 @@ mod tests let mut mock_provider = MockIProvider::new(); mock_provider.expect_provide().returning_st(|_, _| { - Ok(Providable::Factory(FactoryPtr::new(CastableFunction::new( - factory_func, - )))) + Ok(Providable::Function( + Rc::new(CastableFunction::new(factory_func)), + ProvidableFunctionKind::UserCalled, + )) }); di_container @@ -592,8 +602,10 @@ mod tests #[cfg(feature = "factory")] fn can_get_factory_named() { + use std::rc::Rc; + use crate::castable_function::CastableFunction; - use crate::ptr::FactoryPtr; + use crate::provider::blocking::ProvidableFunctionKind; trait IUserManager { @@ -647,9 +659,10 @@ mod tests let mut mock_provider = MockIProvider::new(); mock_provider.expect_provide().returning_st(|_, _| { - Ok(Providable::Factory(FactoryPtr::new(CastableFunction::new( - factory_func, - )))) + Ok(Providable::Function( + Rc::new(CastableFunction::new(factory_func)), + ProvidableFunctionKind::UserCalled, + )) }); di_container.binding_storage.set::<IUserManagerFactory>( diff --git a/src/di_container/blocking/binding/builder.rs b/src/di_container/blocking/binding/builder.rs index ead1a54..345fb02 100644 --- a/src/di_container/blocking/binding/builder.rs +++ b/src/di_container/blocking/binding/builder.rs @@ -181,7 +181,10 @@ where Interface: Fn<Args, Output = crate::ptr::TransientPtr<Return>>, Func: Fn(&DIContainer) -> Box<Interface>, { + use std::rc::Rc; + use crate::castable_function::CastableFunction; + use crate::provider::blocking::ProvidableFunctionKind; if self .di_container @@ -196,9 +199,9 @@ where self.di_container.set_binding::<Interface>( BindingOptions::new(), - Box::new(crate::provider::blocking::FactoryProvider::new( - crate::ptr::FactoryPtr::new(factory_impl), - false, + Box::new(crate::provider::blocking::FunctionProvider::new( + Rc::new(factory_impl), + ProvidableFunctionKind::UserCalled, )), ); @@ -269,7 +272,10 @@ where dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>, >, { + use std::rc::Rc; + use crate::castable_function::CastableFunction; + use crate::provider::blocking::ProvidableFunctionKind; if self .di_container @@ -284,9 +290,9 @@ where self.di_container.set_binding::<Interface>( BindingOptions::new(), - Box::new(crate::provider::blocking::FactoryProvider::new( - crate::ptr::FactoryPtr::new(factory_impl), - true, + Box::new(crate::provider::blocking::FunctionProvider::new( + Rc::new(factory_impl), + ProvidableFunctionKind::Instant, )), ); diff --git a/src/provider/async.rs b/src/provider/async.rs index 68eed87..b011d7a 100644 --- a/src/provider/async.rs +++ b/src/provider/async.rs @@ -15,25 +15,23 @@ pub enum AsyncProvidable<DIContainerT> Transient(TransientPtr<dyn AsyncInjectable<DIContainerT>>), Singleton(ThreadsafeSingletonPtr<dyn AsyncInjectable<DIContainerT>>), #[cfg(feature = "factory")] - Factory( - crate::ptr::ThreadsafeFactoryPtr< - dyn crate::castable_function::threadsafe::AnyThreadsafeCastableFunction, - >, - ), - #[cfg(feature = "factory")] - DefaultFactory( - crate::ptr::ThreadsafeFactoryPtr< - dyn crate::castable_function::threadsafe::AnyThreadsafeCastableFunction, - >, - ), - #[cfg(feature = "factory")] - AsyncDefaultFactory( - crate::ptr::ThreadsafeFactoryPtr< + Function( + std::sync::Arc< dyn crate::castable_function::threadsafe::AnyThreadsafeCastableFunction, >, + ProvidableFunctionKind, ), } +#[cfg(feature = "factory")] +#[derive(Debug, Clone, Copy)] +pub enum ProvidableFunctionKind +{ + UserCalled, + Instant, + AsyncInstant, +} + #[async_trait] #[cfg_attr(test, mockall::automock, allow(dead_code))] pub trait IAsyncProvider<DIContainerT>: Send + Sync @@ -177,40 +175,34 @@ where } #[cfg(feature = "factory")] -#[derive(Clone, Copy)] -pub enum AsyncFactoryVariant -{ - Normal, - Default, - AsyncDefault, -} - -#[cfg(feature = "factory")] -pub struct AsyncFactoryProvider +pub struct AsyncFunctionProvider { - factory: crate::ptr::ThreadsafeFactoryPtr< + function: std::sync::Arc< dyn crate::castable_function::threadsafe::AnyThreadsafeCastableFunction, >, - variant: AsyncFactoryVariant, + providable_func_kind: ProvidableFunctionKind, } #[cfg(feature = "factory")] -impl AsyncFactoryProvider +impl AsyncFunctionProvider { pub fn new( - factory: crate::ptr::ThreadsafeFactoryPtr< + function: std::sync::Arc< dyn crate::castable_function::threadsafe::AnyThreadsafeCastableFunction, >, - variant: AsyncFactoryVariant, + providable_func_kind: ProvidableFunctionKind, ) -> Self { - Self { factory, variant } + Self { + function, + providable_func_kind, + } } } #[cfg(feature = "factory")] #[async_trait] -impl<DIContainerT> IAsyncProvider<DIContainerT> for AsyncFactoryProvider +impl<DIContainerT> IAsyncProvider<DIContainerT> for AsyncFunctionProvider where DIContainerT: Send + Sync, { @@ -220,15 +212,10 @@ where _dependency_history: DependencyHistory, ) -> Result<AsyncProvidable<DIContainerT>, InjectableError> { - Ok(match self.variant { - AsyncFactoryVariant::Normal => AsyncProvidable::Factory(self.factory.clone()), - AsyncFactoryVariant::Default => { - AsyncProvidable::DefaultFactory(self.factory.clone()) - } - AsyncFactoryVariant::AsyncDefault => { - AsyncProvidable::AsyncDefaultFactory(self.factory.clone()) - } - }) + Ok(AsyncProvidable::Function( + self.function.clone(), + self.providable_func_kind, + )) } fn do_clone(&self) -> Box<dyn IAsyncProvider<DIContainerT>> @@ -238,13 +225,13 @@ where } #[cfg(feature = "factory")] -impl Clone for AsyncFactoryProvider +impl Clone for AsyncFunctionProvider { fn clone(&self) -> Self { Self { - factory: self.factory.clone(), - variant: self.variant, + function: self.function.clone(), + providable_func_kind: self.providable_func_kind, } } } @@ -305,13 +292,13 @@ mod tests #[tokio::test] #[cfg(feature = "factory")] - async fn async_factory_provider_works() + async fn function_provider_works() { use std::any::Any; + use std::sync::Arc; use crate::castable_function::threadsafe::AnyThreadsafeCastableFunction; use crate::castable_function::AnyCastableFunction; - use crate::ptr::ThreadsafeFactoryPtr; #[derive(Debug)] struct FooFactory; @@ -326,54 +313,63 @@ mod tests impl AnyThreadsafeCastableFunction for FooFactory {} - let factory_provider = AsyncFactoryProvider::new( - ThreadsafeFactoryPtr::new(FooFactory), - AsyncFactoryVariant::Normal, + let user_called_func_provider = AsyncFunctionProvider::new( + Arc::new(FooFactory), + ProvidableFunctionKind::UserCalled, ); - let default_factory_provider = AsyncFactoryProvider::new( - ThreadsafeFactoryPtr::new(FooFactory), - AsyncFactoryVariant::Default, + let instant_func_provider = AsyncFunctionProvider::new( + Arc::new(FooFactory), + ProvidableFunctionKind::Instant, ); - let async_default_factory_provider = AsyncFactoryProvider::new( - ThreadsafeFactoryPtr::new(FooFactory), - AsyncFactoryVariant::AsyncDefault, + let async_instant_func_provider = AsyncFunctionProvider::new( + Arc::new(FooFactory), + ProvidableFunctionKind::AsyncInstant, ); let di_container = MockAsyncDIContainer::new(); assert!( matches!( - factory_provider + user_called_func_provider .provide(&di_container, MockDependencyHistory::new()) .await .unwrap(), - AsyncProvidable::Factory(_) + AsyncProvidable::Function(_, ProvidableFunctionKind::UserCalled) ), - "The provided type is not a factory" + concat!( + "The provided type is not a AsyncProvidable::Function of kind ", + "ProvidableFunctionKind::UserCalled" + ) ); assert!( matches!( - default_factory_provider + instant_func_provider .provide(&di_container, MockDependencyHistory::new()) .await .unwrap(), - AsyncProvidable::DefaultFactory(_) + AsyncProvidable::Function(_, ProvidableFunctionKind::Instant) ), - "The provided type is not a default factory" + concat!( + "The provided type is not a AsyncProvidable::Function of kind ", + "ProvidableFunctionKind::Instant" + ) ); assert!( matches!( - async_default_factory_provider + async_instant_func_provider .provide(&di_container, MockDependencyHistory::new()) .await .unwrap(), - AsyncProvidable::AsyncDefaultFactory(_) + AsyncProvidable::Function(_, ProvidableFunctionKind::AsyncInstant) ), - "The provided type is not a async default factory" + concat!( + "The provided type is not a AsyncProvidable::Function of kind ", + "ProvidableFunctionKind::AsyncInstant" + ) ); } } diff --git a/src/provider/blocking.rs b/src/provider/blocking.rs index 6475dc7..e7f113b 100644 --- a/src/provider/blocking.rs +++ b/src/provider/blocking.rs @@ -13,13 +13,20 @@ pub enum Providable<DIContainerType> Transient(TransientPtr<dyn Injectable<DIContainerType>>), Singleton(SingletonPtr<dyn Injectable<DIContainerType>>), #[cfg(feature = "factory")] - Factory(crate::ptr::FactoryPtr<dyn crate::castable_function::AnyCastableFunction>), - #[cfg(feature = "factory")] - DefaultFactory( - crate::ptr::FactoryPtr<dyn crate::castable_function::AnyCastableFunction>, + Function( + std::rc::Rc<dyn crate::castable_function::AnyCastableFunction>, + ProvidableFunctionKind, ), } +#[cfg(feature = "factory")] +#[derive(Debug, Clone, Copy)] +pub enum ProvidableFunctionKind +{ + Instant, + UserCalled, +} + #[cfg_attr(test, mockall::automock)] pub trait IProvider<DIContainerType> { @@ -108,31 +115,29 @@ where } #[cfg(feature = "factory")] -pub struct FactoryProvider +pub struct FunctionProvider { - factory: crate::ptr::FactoryPtr<dyn crate::castable_function::AnyCastableFunction>, - is_default_factory: bool, + function: std::rc::Rc<dyn crate::castable_function::AnyCastableFunction>, + providable_func_kind: ProvidableFunctionKind, } #[cfg(feature = "factory")] -impl FactoryProvider +impl FunctionProvider { pub fn new( - factory: crate::ptr::FactoryPtr< - dyn crate::castable_function::AnyCastableFunction, - >, - is_default_factory: bool, + function: std::rc::Rc<dyn crate::castable_function::AnyCastableFunction>, + providable_func_kind: ProvidableFunctionKind, ) -> Self { Self { - factory, - is_default_factory, + function, + providable_func_kind, } } } #[cfg(feature = "factory")] -impl<DIContainerType> IProvider<DIContainerType> for FactoryProvider +impl<DIContainerType> IProvider<DIContainerType> for FunctionProvider { fn provide( &self, @@ -140,11 +145,10 @@ impl<DIContainerType> IProvider<DIContainerType> for FactoryProvider _dependency_history: DependencyHistory, ) -> Result<Providable<DIContainerType>, InjectableError> { - Ok(if self.is_default_factory { - Providable::DefaultFactory(self.factory.clone()) - } else { - Providable::Factory(self.factory.clone()) - }) + Ok(Providable::Function( + self.function.clone(), + self.providable_func_kind, + )) } } @@ -198,12 +202,12 @@ mod tests #[test] #[cfg(feature = "factory")] - fn factory_provider_works() + fn function_provider_works() { use std::any::Any; + use std::rc::Rc; use crate::castable_function::AnyCastableFunction; - use crate::ptr::FactoryPtr; #[derive(Debug)] struct FooFactory; @@ -216,27 +220,38 @@ mod tests } } - let factory_provider = FactoryProvider::new(FactoryPtr::new(FooFactory), false); - let default_factory_provider = - FactoryProvider::new(FactoryPtr::new(FooFactory), true); + let user_called_func_provider = FunctionProvider::new( + Rc::new(FooFactory), + ProvidableFunctionKind::UserCalled, + ); + + let instant_func_provider = + FunctionProvider::new(Rc::new(FooFactory), ProvidableFunctionKind::Instant); let di_container = MockDIContainer::new(); assert!( matches!( - factory_provider.provide(&di_container, MockDependencyHistory::new()), - Ok(Providable::Factory(_)) + user_called_func_provider + .provide(&di_container, MockDependencyHistory::new()), + Ok(Providable::Function(_, ProvidableFunctionKind::UserCalled)) ), - "The provided type is not a factory" + concat!( + "The provided type is not a Providable::Function of kind ", + "ProvidableFunctionKind::UserCalled" + ) ); assert!( matches!( - default_factory_provider + instant_func_provider .provide(&di_container, MockDependencyHistory::new()), - Ok(Providable::DefaultFactory(_)) + Ok(Providable::Function(_, ProvidableFunctionKind::Instant)) ), - "The provided type is not a default factory" + concat!( + "The provided type is not a Providable::Function of kind ", + "ProvidableFunctionKind::Instant" + ) ); } } |