diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/di_container.rs | 272 | ||||
-rw-r--r-- | src/errors/di_container.rs | 13 | ||||
-rw-r--r-- | src/provider.rs | 43 | ||||
-rw-r--r-- | src/ptr.rs | 2 |
4 files changed, 297 insertions, 33 deletions
diff --git a/src/di_container.rs b/src/di_container.rs index d1c757d..e59d8f4 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -7,11 +7,11 @@ use error_stack::{Report, ResultExt}; #[cfg(feature = "factory")] use crate::castable_factory::CastableFactory; -use crate::errors::di_container::DIContainerError; +use crate::errors::di_container::{BindingBuilderError, DIContainerError}; use crate::interfaces::injectable::Injectable; -use crate::libs::intertrait::cast::CastBox; -use crate::provider::{IProvider, InjectableTypeProvider, Providable}; -use crate::ptr::TransientPtr; +use crate::libs::intertrait::cast::{CastBox, CastRc}; +use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider}; +use crate::ptr::{SingletonPtr, TransientPtr}; /// Binding builder for type `Interface` inside a [`DIContainer`]. pub struct BindingBuilder<'di_container_lt, Interface> @@ -44,10 +44,35 @@ where self.di_container.bindings.insert( interface_typeid, - Rc::new(InjectableTypeProvider::<Implementation>::new()), + Rc::new(TransientTypeProvider::<Implementation>::new()), ); } + /// Creates a binding of type `Interface` to a new singleton of type `Implementation` + /// inside of the associated [`DIContainer`]. + /// + /// # Errors + /// Will return Err if creating the singleton fails. + pub fn to_singleton<Implementation>( + &mut self, + ) -> error_stack::Result<(), BindingBuilderError> + where + Implementation: Injectable, + { + let interface_typeid = TypeId::of::<Interface>(); + + let singleton: SingletonPtr<Implementation> = SingletonPtr::from( + Implementation::resolve(self.di_container) + .change_context(BindingBuilderError)?, + ); + + self.di_container + .bindings + .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton))); + + Ok(()) + } + /// Creates a binding of factory type `Interface` to a factory inside of the /// associated [`DIContainer`]. #[cfg(feature = "factory")] @@ -173,21 +198,69 @@ impl DIContainer ))?; match binding_providable { - Providable::Injectable(binding_injectable) => { - let interface_box_result = binding_injectable.cast::<Interface>(); + Providable::Transient(binding_injectable) => { + let interface_result = binding_injectable.cast::<Interface>(); - match interface_box_result { - Ok(interface_box) => Ok(interface_box), + match interface_result { + Ok(interface) => Ok(interface), Err(_) => Err(Report::new(DIContainerError).attach_printable( format!("Unable to cast binding for {}", interface_name), )), } } - Providable::Factory(_) => Err(Report::new(DIContainerError) - .attach_printable(format!( - "Binding for {} is not injectable", - interface_name - ))), + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not injectable", + interface_name + ))), + } + } + + /// Returns the singleton instance bound with `Interface`. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` exists + /// - Resolving the binding for `Interface` fails + /// - Casting the binding for `Interface` fails + /// - The binding for `Interface` is not a singleton + pub fn get_singleton<Interface>( + &self, + ) -> error_stack::Result<SingletonPtr<Interface>, DIContainerError> + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::<Interface>(); + + let interface_name = type_name::<Interface>(); + + let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { + Report::new(DIContainerError) + .attach_printable(format!("No binding exists for {}", interface_name)) + })?; + + let binding_providable = binding + .provide(self) + .change_context(DIContainerError) + .attach_printable(format!( + "Failed to resolve binding for interface {}", + interface_name + ))?; + + match binding_providable { + Providable::Singleton(binding_singleton) => { + let interface_result = binding_singleton.cast::<Interface>(); + + match interface_result { + Ok(interface) => Ok(interface), + Err(_) => Err(Report::new(DIContainerError).attach_printable( + format!("Unable to cast binding for {}", interface_name), + )), + } + } + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not a singleton", + interface_name + ))), } } @@ -225,22 +298,19 @@ impl DIContainer match binding_providable { Providable::Factory(binding_factory) => { - use crate::libs::intertrait::cast::CastRc; + let factory_result = binding_factory.cast::<Interface>(); - let factory_box_result = binding_factory.cast::<Interface>(); - - match factory_box_result { - Ok(interface_box) => Ok(interface_box), + match factory_result { + Ok(factory) => Ok(factory), Err(_) => Err(Report::new(DIContainerError).attach_printable( format!("Unable to cast binding for {}", interface_name), )), } } - Providable::Injectable(_) => Err(Report::new(DIContainerError) - .attach_printable(format!( - "Binding for {} is not a factory", - interface_name - ))), + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not a factory", + interface_name + ))), } } } @@ -256,6 +326,8 @@ impl Default for DIContainer #[cfg(test)] mod tests { + use std::fmt::Debug; + use mockall::mock; use super::*; @@ -312,6 +384,59 @@ mod tests } #[test] + fn can_bind_to_singleton() -> error_stack::Result<(), BindingBuilderError> + { + trait IUserManager + { + fn add_user(&self, user_id: i128); + + fn remove_user(&self, user_id: i128); + } + + struct UserManager {} + + impl IUserManager for UserManager + { + fn add_user(&self, _user_id: i128) + { + // ... + } + + fn remove_user(&self, _user_id: i128) + { + // ... + } + } + + impl Injectable for UserManager + { + fn resolve( + _di_container: &DIContainer, + ) -> error_stack::Result< + TransientPtr<Self>, + crate::errors::injectable::ResolveError, + > + where + Self: Sized, + { + Ok(TransientPtr::new(Self {})) + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.len(), 0); + + di_container + .bind::<dyn IUserManager>() + .to_singleton::<UserManager>()?; + + assert_eq!(di_container.bindings.len(), 1); + + Ok(()) + } + + #[test] #[cfg(feature = "factory")] fn can_bind_to_factory() { @@ -416,9 +541,7 @@ mod tests let mut mock_provider = MockProvider::new(); mock_provider.expect_provide().returning(|_| { - Ok(Providable::Injectable( - TransientPtr::new(UserManager::new()), - )) + Ok(Providable::Transient(TransientPtr::new(UserManager::new()))) }); di_container @@ -431,6 +554,101 @@ mod tests } #[test] + fn can_get_singleton() -> error_stack::Result<(), DIContainerError> + { + trait INumber + { + fn get(&self) -> i32; + + fn set(&mut self, number: i32); + } + + impl PartialEq for dyn INumber + { + fn eq(&self, other: &Self) -> bool + { + self.get() == other.get() + } + } + + impl Debug for dyn INumber + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + f.write_str(format!("{}", self.get()).as_str()) + } + } + + struct Number + { + num: i32, + } + + use crate as syrette; + use crate::injectable; + + #[injectable(INumber)] + impl Number + { + fn new() -> Self + { + Self { num: 0 } + } + } + + impl INumber for Number + { + fn get(&self) -> i32 + { + self.num + } + + fn set(&mut self, number: i32) + { + self.num = number; + } + } + + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + ) -> error_stack::Result<Providable, ResolveError>; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = SingletonPtr::new(Number::new()); + + SingletonPtr::get_mut(&mut singleton).unwrap().set(2820); + + mock_provider + .expect_provide() + .returning_st(move |_| Ok(Providable::Singleton(singleton.clone()))); + + di_container + .bindings + .insert(TypeId::of::<dyn INumber>(), Rc::new(mock_provider)); + + let first_number_rc = di_container.get_singleton::<dyn INumber>()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container.get_singleton::<dyn INumber>()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + + #[test] #[cfg(feature = "factory")] fn can_get_factory() -> error_stack::Result<(), DIContainerError> { diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs index f2b8add..3b8c717 100644 --- a/src/errors/di_container.rs +++ b/src/errors/di_container.rs @@ -15,3 +15,16 @@ impl Display for DIContainerError } impl Context for DIContainerError {} + +#[derive(Debug)] +pub struct BindingBuilderError; + +impl Display for BindingBuilderError +{ + fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result + { + fmt.write_str("A binding builder error has occurred") + } +} + +impl Context for BindingBuilderError {} diff --git a/src/provider.rs b/src/provider.rs index eb5b5d8..2e832f8 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -4,14 +4,15 @@ use std::marker::PhantomData; use crate::errors::injectable::ResolveError; use crate::interfaces::any_factory::AnyFactory; use crate::interfaces::injectable::Injectable; -use crate::ptr::{FactoryPtr, TransientPtr}; +use crate::ptr::{FactoryPtr, SingletonPtr, TransientPtr}; use crate::DIContainer; extern crate error_stack; pub enum Providable { - Injectable(TransientPtr<dyn Injectable>), + Transient(TransientPtr<dyn Injectable>), + Singleton(SingletonPtr<dyn Injectable>), #[allow(dead_code)] Factory(FactoryPtr<dyn AnyFactory>), } @@ -24,14 +25,14 @@ pub trait IProvider ) -> error_stack::Result<Providable, ResolveError>; } -pub struct InjectableTypeProvider<InjectableType> +pub struct TransientTypeProvider<InjectableType> where InjectableType: Injectable, { injectable_phantom: PhantomData<InjectableType>, } -impl<InjectableType> InjectableTypeProvider<InjectableType> +impl<InjectableType> TransientTypeProvider<InjectableType> where InjectableType: Injectable, { @@ -43,7 +44,7 @@ where } } -impl<InjectableType> IProvider for InjectableTypeProvider<InjectableType> +impl<InjectableType> IProvider for TransientTypeProvider<InjectableType> where InjectableType: Injectable, { @@ -52,12 +53,42 @@ where di_container: &DIContainer, ) -> error_stack::Result<Providable, ResolveError> { - Ok(Providable::Injectable(InjectableType::resolve( + Ok(Providable::Transient(InjectableType::resolve( di_container, )?)) } } +pub struct SingletonProvider<InjectableType> +where + InjectableType: Injectable, +{ + singleton: SingletonPtr<InjectableType>, +} + +impl<InjectableType> SingletonProvider<InjectableType> +where + InjectableType: Injectable, +{ + pub fn new(singleton: SingletonPtr<InjectableType>) -> Self + { + Self { singleton } + } +} + +impl<InjectableType> IProvider for SingletonProvider<InjectableType> +where + InjectableType: Injectable, +{ + fn provide( + &self, + _di_container: &DIContainer, + ) -> error_stack::Result<Providable, ResolveError> + { + Ok(Providable::Singleton(self.singleton.clone())) + } +} + #[cfg(feature = "factory")] pub struct FactoryProvider { @@ -3,4 +3,6 @@ use std::rc::Rc; pub type TransientPtr<Interface> = Box<Interface>; +pub type SingletonPtr<Interface> = Rc<Interface>; + pub type FactoryPtr<FactoryInterface> = Rc<FactoryInterface>; |