diff options
Diffstat (limited to 'src/di_container.rs')
-rw-r--r-- | src/di_container.rs | 272 |
1 files changed, 245 insertions, 27 deletions
diff --git a/src/di_container.rs b/src/di_container.rs index d1c757d..e59d8f4 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -7,11 +7,11 @@ use error_stack::{Report, ResultExt}; #[cfg(feature = "factory")] use crate::castable_factory::CastableFactory; -use crate::errors::di_container::DIContainerError; +use crate::errors::di_container::{BindingBuilderError, DIContainerError}; use crate::interfaces::injectable::Injectable; -use crate::libs::intertrait::cast::CastBox; -use crate::provider::{IProvider, InjectableTypeProvider, Providable}; -use crate::ptr::TransientPtr; +use crate::libs::intertrait::cast::{CastBox, CastRc}; +use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider}; +use crate::ptr::{SingletonPtr, TransientPtr}; /// Binding builder for type `Interface` inside a [`DIContainer`]. pub struct BindingBuilder<'di_container_lt, Interface> @@ -44,10 +44,35 @@ where self.di_container.bindings.insert( interface_typeid, - Rc::new(InjectableTypeProvider::<Implementation>::new()), + Rc::new(TransientTypeProvider::<Implementation>::new()), ); } + /// Creates a binding of type `Interface` to a new singleton of type `Implementation` + /// inside of the associated [`DIContainer`]. + /// + /// # Errors + /// Will return Err if creating the singleton fails. + pub fn to_singleton<Implementation>( + &mut self, + ) -> error_stack::Result<(), BindingBuilderError> + where + Implementation: Injectable, + { + let interface_typeid = TypeId::of::<Interface>(); + + let singleton: SingletonPtr<Implementation> = SingletonPtr::from( + Implementation::resolve(self.di_container) + .change_context(BindingBuilderError)?, + ); + + self.di_container + .bindings + .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton))); + + Ok(()) + } + /// Creates a binding of factory type `Interface` to a factory inside of the /// associated [`DIContainer`]. #[cfg(feature = "factory")] @@ -173,21 +198,69 @@ impl DIContainer ))?; match binding_providable { - Providable::Injectable(binding_injectable) => { - let interface_box_result = binding_injectable.cast::<Interface>(); + Providable::Transient(binding_injectable) => { + let interface_result = binding_injectable.cast::<Interface>(); - match interface_box_result { - Ok(interface_box) => Ok(interface_box), + match interface_result { + Ok(interface) => Ok(interface), Err(_) => Err(Report::new(DIContainerError).attach_printable( format!("Unable to cast binding for {}", interface_name), )), } } - Providable::Factory(_) => Err(Report::new(DIContainerError) - .attach_printable(format!( - "Binding for {} is not injectable", - interface_name - ))), + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not injectable", + interface_name + ))), + } + } + + /// Returns the singleton instance bound with `Interface`. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` exists + /// - Resolving the binding for `Interface` fails + /// - Casting the binding for `Interface` fails + /// - The binding for `Interface` is not a singleton + pub fn get_singleton<Interface>( + &self, + ) -> error_stack::Result<SingletonPtr<Interface>, DIContainerError> + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::<Interface>(); + + let interface_name = type_name::<Interface>(); + + let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { + Report::new(DIContainerError) + .attach_printable(format!("No binding exists for {}", interface_name)) + })?; + + let binding_providable = binding + .provide(self) + .change_context(DIContainerError) + .attach_printable(format!( + "Failed to resolve binding for interface {}", + interface_name + ))?; + + match binding_providable { + Providable::Singleton(binding_singleton) => { + let interface_result = binding_singleton.cast::<Interface>(); + + match interface_result { + Ok(interface) => Ok(interface), + Err(_) => Err(Report::new(DIContainerError).attach_printable( + format!("Unable to cast binding for {}", interface_name), + )), + } + } + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not a singleton", + interface_name + ))), } } @@ -225,22 +298,19 @@ impl DIContainer match binding_providable { Providable::Factory(binding_factory) => { - use crate::libs::intertrait::cast::CastRc; + let factory_result = binding_factory.cast::<Interface>(); - let factory_box_result = binding_factory.cast::<Interface>(); - - match factory_box_result { - Ok(interface_box) => Ok(interface_box), + match factory_result { + Ok(factory) => Ok(factory), Err(_) => Err(Report::new(DIContainerError).attach_printable( format!("Unable to cast binding for {}", interface_name), )), } } - Providable::Injectable(_) => Err(Report::new(DIContainerError) - .attach_printable(format!( - "Binding for {} is not a factory", - interface_name - ))), + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not a factory", + interface_name + ))), } } } @@ -256,6 +326,8 @@ impl Default for DIContainer #[cfg(test)] mod tests { + use std::fmt::Debug; + use mockall::mock; use super::*; @@ -312,6 +384,59 @@ mod tests } #[test] + fn can_bind_to_singleton() -> error_stack::Result<(), BindingBuilderError> + { + trait IUserManager + { + fn add_user(&self, user_id: i128); + + fn remove_user(&self, user_id: i128); + } + + struct UserManager {} + + impl IUserManager for UserManager + { + fn add_user(&self, _user_id: i128) + { + // ... + } + + fn remove_user(&self, _user_id: i128) + { + // ... + } + } + + impl Injectable for UserManager + { + fn resolve( + _di_container: &DIContainer, + ) -> error_stack::Result< + TransientPtr<Self>, + crate::errors::injectable::ResolveError, + > + where + Self: Sized, + { + Ok(TransientPtr::new(Self {})) + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.len(), 0); + + di_container + .bind::<dyn IUserManager>() + .to_singleton::<UserManager>()?; + + assert_eq!(di_container.bindings.len(), 1); + + Ok(()) + } + + #[test] #[cfg(feature = "factory")] fn can_bind_to_factory() { @@ -416,9 +541,7 @@ mod tests let mut mock_provider = MockProvider::new(); mock_provider.expect_provide().returning(|_| { - Ok(Providable::Injectable( - TransientPtr::new(UserManager::new()), - )) + Ok(Providable::Transient(TransientPtr::new(UserManager::new()))) }); di_container @@ -431,6 +554,101 @@ mod tests } #[test] + fn can_get_singleton() -> error_stack::Result<(), DIContainerError> + { + trait INumber + { + fn get(&self) -> i32; + + fn set(&mut self, number: i32); + } + + impl PartialEq for dyn INumber + { + fn eq(&self, other: &Self) -> bool + { + self.get() == other.get() + } + } + + impl Debug for dyn INumber + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + f.write_str(format!("{}", self.get()).as_str()) + } + } + + struct Number + { + num: i32, + } + + use crate as syrette; + use crate::injectable; + + #[injectable(INumber)] + impl Number + { + fn new() -> Self + { + Self { num: 0 } + } + } + + impl INumber for Number + { + fn get(&self) -> i32 + { + self.num + } + + fn set(&mut self, number: i32) + { + self.num = number; + } + } + + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + ) -> error_stack::Result<Providable, ResolveError>; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = SingletonPtr::new(Number::new()); + + SingletonPtr::get_mut(&mut singleton).unwrap().set(2820); + + mock_provider + .expect_provide() + .returning_st(move |_| Ok(Providable::Singleton(singleton.clone()))); + + di_container + .bindings + .insert(TypeId::of::<dyn INumber>(), Rc::new(mock_provider)); + + let first_number_rc = di_container.get_singleton::<dyn INumber>()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container.get_singleton::<dyn INumber>()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + + #[test] #[cfg(feature = "factory")] fn can_get_factory() -> error_stack::Result<(), DIContainerError> { |