From 3388f857b32cf1893d7b54582c8fd16e4965550b Mon Sep 17 00:00:00 2001 From: HampusM Date: Wed, 27 Jul 2022 18:05:34 +0200 Subject: feat: implement binding singletons --- src/di_container.rs | 272 ++++++++++++++++++++++++++++++++++++++++----- src/errors/di_container.rs | 13 +++ src/provider.rs | 43 ++++++- src/ptr.rs | 2 + 4 files changed, 297 insertions(+), 33 deletions(-) diff --git a/src/di_container.rs b/src/di_container.rs index d1c757d..e59d8f4 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -7,11 +7,11 @@ use error_stack::{Report, ResultExt}; #[cfg(feature = "factory")] use crate::castable_factory::CastableFactory; -use crate::errors::di_container::DIContainerError; +use crate::errors::di_container::{BindingBuilderError, DIContainerError}; use crate::interfaces::injectable::Injectable; -use crate::libs::intertrait::cast::CastBox; -use crate::provider::{IProvider, InjectableTypeProvider, Providable}; -use crate::ptr::TransientPtr; +use crate::libs::intertrait::cast::{CastBox, CastRc}; +use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider}; +use crate::ptr::{SingletonPtr, TransientPtr}; /// Binding builder for type `Interface` inside a [`DIContainer`]. pub struct BindingBuilder<'di_container_lt, Interface> @@ -44,10 +44,35 @@ where self.di_container.bindings.insert( interface_typeid, - Rc::new(InjectableTypeProvider::::new()), + Rc::new(TransientTypeProvider::::new()), ); } + /// Creates a binding of type `Interface` to a new singleton of type `Implementation` + /// inside of the associated [`DIContainer`]. + /// + /// # Errors + /// Will return Err if creating the singleton fails. + pub fn to_singleton( + &mut self, + ) -> error_stack::Result<(), BindingBuilderError> + where + Implementation: Injectable, + { + let interface_typeid = TypeId::of::(); + + let singleton: SingletonPtr = SingletonPtr::from( + Implementation::resolve(self.di_container) + .change_context(BindingBuilderError)?, + ); + + self.di_container + .bindings + .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton))); + + Ok(()) + } + /// Creates a binding of factory type `Interface` to a factory inside of the /// associated [`DIContainer`]. #[cfg(feature = "factory")] @@ -173,21 +198,69 @@ impl DIContainer ))?; match binding_providable { - Providable::Injectable(binding_injectable) => { - let interface_box_result = binding_injectable.cast::(); + Providable::Transient(binding_injectable) => { + let interface_result = binding_injectable.cast::(); - match interface_box_result { - Ok(interface_box) => Ok(interface_box), + match interface_result { + Ok(interface) => Ok(interface), Err(_) => Err(Report::new(DIContainerError).attach_printable( format!("Unable to cast binding for {}", interface_name), )), } } - Providable::Factory(_) => Err(Report::new(DIContainerError) - .attach_printable(format!( - "Binding for {} is not injectable", - interface_name - ))), + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not injectable", + interface_name + ))), + } + } + + /// Returns the singleton instance bound with `Interface`. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` exists + /// - Resolving the binding for `Interface` fails + /// - Casting the binding for `Interface` fails + /// - The binding for `Interface` is not a singleton + pub fn get_singleton( + &self, + ) -> error_stack::Result, DIContainerError> + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::(); + + let interface_name = type_name::(); + + let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { + Report::new(DIContainerError) + .attach_printable(format!("No binding exists for {}", interface_name)) + })?; + + let binding_providable = binding + .provide(self) + .change_context(DIContainerError) + .attach_printable(format!( + "Failed to resolve binding for interface {}", + interface_name + ))?; + + match binding_providable { + Providable::Singleton(binding_singleton) => { + let interface_result = binding_singleton.cast::(); + + match interface_result { + Ok(interface) => Ok(interface), + Err(_) => Err(Report::new(DIContainerError).attach_printable( + format!("Unable to cast binding for {}", interface_name), + )), + } + } + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not a singleton", + interface_name + ))), } } @@ -225,22 +298,19 @@ impl DIContainer match binding_providable { Providable::Factory(binding_factory) => { - use crate::libs::intertrait::cast::CastRc; + let factory_result = binding_factory.cast::(); - let factory_box_result = binding_factory.cast::(); - - match factory_box_result { - Ok(interface_box) => Ok(interface_box), + match factory_result { + Ok(factory) => Ok(factory), Err(_) => Err(Report::new(DIContainerError).attach_printable( format!("Unable to cast binding for {}", interface_name), )), } } - Providable::Injectable(_) => Err(Report::new(DIContainerError) - .attach_printable(format!( - "Binding for {} is not a factory", - interface_name - ))), + _ => Err(Report::new(DIContainerError).attach_printable(format!( + "Binding for {} is not a factory", + interface_name + ))), } } } @@ -256,6 +326,8 @@ impl Default for DIContainer #[cfg(test)] mod tests { + use std::fmt::Debug; + use mockall::mock; use super::*; @@ -311,6 +383,59 @@ mod tests assert_eq!(di_container.bindings.len(), 1); } + #[test] + fn can_bind_to_singleton() -> error_stack::Result<(), BindingBuilderError> + { + trait IUserManager + { + fn add_user(&self, user_id: i128); + + fn remove_user(&self, user_id: i128); + } + + struct UserManager {} + + impl IUserManager for UserManager + { + fn add_user(&self, _user_id: i128) + { + // ... + } + + fn remove_user(&self, _user_id: i128) + { + // ... + } + } + + impl Injectable for UserManager + { + fn resolve( + _di_container: &DIContainer, + ) -> error_stack::Result< + TransientPtr, + crate::errors::injectable::ResolveError, + > + where + Self: Sized, + { + Ok(TransientPtr::new(Self {})) + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.len(), 0); + + di_container + .bind::() + .to_singleton::()?; + + assert_eq!(di_container.bindings.len(), 1); + + Ok(()) + } + #[test] #[cfg(feature = "factory")] fn can_bind_to_factory() @@ -416,9 +541,7 @@ mod tests let mut mock_provider = MockProvider::new(); mock_provider.expect_provide().returning(|_| { - Ok(Providable::Injectable( - TransientPtr::new(UserManager::new()), - )) + Ok(Providable::Transient(TransientPtr::new(UserManager::new()))) }); di_container @@ -430,6 +553,101 @@ mod tests Ok(()) } + #[test] + fn can_get_singleton() -> error_stack::Result<(), DIContainerError> + { + trait INumber + { + fn get(&self) -> i32; + + fn set(&mut self, number: i32); + } + + impl PartialEq for dyn INumber + { + fn eq(&self, other: &Self) -> bool + { + self.get() == other.get() + } + } + + impl Debug for dyn INumber + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + f.write_str(format!("{}", self.get()).as_str()) + } + } + + struct Number + { + num: i32, + } + + use crate as syrette; + use crate::injectable; + + #[injectable(INumber)] + impl Number + { + fn new() -> Self + { + Self { num: 0 } + } + } + + impl INumber for Number + { + fn get(&self) -> i32 + { + self.num + } + + fn set(&mut self, number: i32) + { + self.num = number; + } + } + + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + ) -> error_stack::Result; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = SingletonPtr::new(Number::new()); + + SingletonPtr::get_mut(&mut singleton).unwrap().set(2820); + + mock_provider + .expect_provide() + .returning_st(move |_| Ok(Providable::Singleton(singleton.clone()))); + + di_container + .bindings + .insert(TypeId::of::(), Rc::new(mock_provider)); + + let first_number_rc = di_container.get_singleton::()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container.get_singleton::()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + #[test] #[cfg(feature = "factory")] fn can_get_factory() -> error_stack::Result<(), DIContainerError> diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs index f2b8add..3b8c717 100644 --- a/src/errors/di_container.rs +++ b/src/errors/di_container.rs @@ -15,3 +15,16 @@ impl Display for DIContainerError } impl Context for DIContainerError {} + +#[derive(Debug)] +pub struct BindingBuilderError; + +impl Display for BindingBuilderError +{ + fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result + { + fmt.write_str("A binding builder error has occurred") + } +} + +impl Context for BindingBuilderError {} diff --git a/src/provider.rs b/src/provider.rs index eb5b5d8..2e832f8 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -4,14 +4,15 @@ use std::marker::PhantomData; use crate::errors::injectable::ResolveError; use crate::interfaces::any_factory::AnyFactory; use crate::interfaces::injectable::Injectable; -use crate::ptr::{FactoryPtr, TransientPtr}; +use crate::ptr::{FactoryPtr, SingletonPtr, TransientPtr}; use crate::DIContainer; extern crate error_stack; pub enum Providable { - Injectable(TransientPtr), + Transient(TransientPtr), + Singleton(SingletonPtr), #[allow(dead_code)] Factory(FactoryPtr), } @@ -24,14 +25,14 @@ pub trait IProvider ) -> error_stack::Result; } -pub struct InjectableTypeProvider +pub struct TransientTypeProvider where InjectableType: Injectable, { injectable_phantom: PhantomData, } -impl InjectableTypeProvider +impl TransientTypeProvider where InjectableType: Injectable, { @@ -43,7 +44,7 @@ where } } -impl IProvider for InjectableTypeProvider +impl IProvider for TransientTypeProvider where InjectableType: Injectable, { @@ -52,12 +53,42 @@ where di_container: &DIContainer, ) -> error_stack::Result { - Ok(Providable::Injectable(InjectableType::resolve( + Ok(Providable::Transient(InjectableType::resolve( di_container, )?)) } } +pub struct SingletonProvider +where + InjectableType: Injectable, +{ + singleton: SingletonPtr, +} + +impl SingletonProvider +where + InjectableType: Injectable, +{ + pub fn new(singleton: SingletonPtr) -> Self + { + Self { singleton } + } +} + +impl IProvider for SingletonProvider +where + InjectableType: Injectable, +{ + fn provide( + &self, + _di_container: &DIContainer, + ) -> error_stack::Result + { + Ok(Providable::Singleton(self.singleton.clone())) + } +} + #[cfg(feature = "factory")] pub struct FactoryProvider { diff --git a/src/ptr.rs b/src/ptr.rs index 62445fd..00c74f4 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -3,4 +3,6 @@ use std::rc::Rc; pub type TransientPtr = Box; +pub type SingletonPtr = Rc; + pub type FactoryPtr = Rc; -- cgit v1.2.3-18-g5258