From 94b8e52eeefdabab98dfdb5bf4c91f75d778150c Mon Sep 17 00:00:00 2001 From: HampusM Date: Thu, 28 Jul 2022 20:38:33 +0200 Subject: refactor: tidy up DI container internals --- src/di_container.rs | 199 ++++++++++++++++------------------------ src/di_container_binding_map.rs | 55 +++++++++++ src/lib.rs | 1 + 3 files changed, 134 insertions(+), 121 deletions(-) create mode 100644 src/di_container_binding_map.rs (limited to 'src') diff --git a/src/di_container.rs b/src/di_container.rs index e59d8f4..e18fc3a 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -1,18 +1,29 @@ -use std::any::{type_name, TypeId}; +use std::any::type_name; use std::marker::PhantomData; -use std::rc::Rc; -use ahash::AHashMap; -use error_stack::{Report, ResultExt}; +use error_stack::{report, Report, ResultExt}; #[cfg(feature = "factory")] use crate::castable_factory::CastableFactory; +use crate::di_container_binding_map::DIContainerBindingMap; use crate::errors::di_container::{BindingBuilderError, DIContainerError}; use crate::interfaces::injectable::Injectable; +use crate::libs::intertrait::cast::error::CastError; use crate::libs::intertrait::cast::{CastBox, CastRc}; -use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider}; +use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; use crate::ptr::{SingletonPtr, TransientPtr}; +fn unable_to_cast_binding(err: Report) -> Report +where + Interface: 'static + ?Sized, +{ + err.change_context(DIContainerError) + .attach_printable(format!( + "Unable to cast binding for interface '{}'", + type_name::() + )) +} + /// Binding builder for type `Interface` inside a [`DIContainer`]. pub struct BindingBuilder<'di_container_lt, Interface> where @@ -40,12 +51,9 @@ where where Implementation: Injectable, { - let interface_typeid = TypeId::of::(); - - self.di_container.bindings.insert( - interface_typeid, - Rc::new(TransientTypeProvider::::new()), - ); + self.di_container + .bindings + .set::(Box::new(TransientTypeProvider::::new())); } /// Creates a binding of type `Interface` to a new singleton of type `Implementation` @@ -59,8 +67,6 @@ where where Implementation: Injectable, { - let interface_typeid = TypeId::of::(); - let singleton: SingletonPtr = SingletonPtr::from( Implementation::resolve(self.di_container) .change_context(BindingBuilderError)?, @@ -68,7 +74,7 @@ where self.di_container .bindings - .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton))); + .set::(Box::new(SingletonProvider::new(singleton))); Ok(()) } @@ -84,16 +90,13 @@ where Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory, { - let interface_typeid = TypeId::of::(); - let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.insert( - interface_typeid, - Rc::new(crate::provider::FactoryProvider::new( - crate::ptr::FactoryPtr::new(factory_impl), + self.di_container.bindings.set::(Box::new( + crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( + factory_impl, )), - ); + )); } } @@ -144,7 +147,7 @@ where /// ``` pub struct DIContainer { - bindings: AHashMap>, + bindings: DIContainerBindingMap, } impl DIContainer @@ -154,7 +157,7 @@ impl DIContainer pub fn new() -> Self { Self { - bindings: AHashMap::new(), + bindings: DIContainerBindingMap::new(), } } @@ -180,39 +183,18 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::(); - - let interface_name = type_name::(); + let binding_providable = self.get_binding_providable::()?; - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; - - let binding_providable = binding - .provide(self) - .change_context(DIContainerError) - .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Transient(binding_injectable) => { - let interface_result = binding_injectable.cast::(); - - match interface_result { - Ok(interface) => Ok(interface), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not injectable", - interface_name - ))), + if let Providable::Transient(binding_transient) = binding_providable { + return binding_transient + .cast::() + .map_err(unable_to_cast_binding::); } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not transient", + type_name::() + ))) } /// Returns the singleton instance bound with `Interface`. @@ -229,39 +211,18 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::(); - - let interface_name = type_name::(); + let binding_providable = self.get_binding_providable::()?; - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; - - let binding_providable = binding - .provide(self) - .change_context(DIContainerError) - .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Singleton(binding_singleton) => { - let interface_result = binding_singleton.cast::(); - - match interface_result { - Ok(interface) => Ok(interface), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not a singleton", - interface_name - ))), + if let Providable::Singleton(binding_singleton) = binding_providable { + return binding_singleton + .cast::() + .map_err(unable_to_cast_binding::); } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a singleton", + type_name::() + ))) } /// Returns the factory bound with factory type `Interface`. @@ -279,39 +240,34 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::(); + let binding_providable = self.get_binding_providable::()?; - let interface_name = type_name::(); + if let Providable::Factory(binding_factory) = binding_providable { + return binding_factory + .cast::() + .map_err(unable_to_cast_binding::); + } - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a factory", + type_name::() + ))) + } - let binding_providable = binding + fn get_binding_providable( + &self, + ) -> error_stack::Result + where + Interface: 'static + ?Sized, + { + self.bindings + .get::()? .provide(self) .change_context(DIContainerError) .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Factory(binding_factory) => { - let factory_result = binding_factory.cast::(); - - match factory_result { - Ok(factory) => Ok(factory), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not a factory", - interface_name - ))), - } + "Failed to resolve binding for interface '{}'", + type_name::() + )) } } @@ -332,6 +288,7 @@ mod tests use super::*; use crate::errors::injectable::ResolveError; + use crate::provider::IProvider; use crate::ptr::TransientPtr; #[test] @@ -376,11 +333,11 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container.bind::().to::(); - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); } #[test] @@ -425,13 +382,13 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container .bind::() .to_singleton::()?; - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); Ok(()) } @@ -475,7 +432,7 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container.bind::().to_factory(&|| { let user_manager: TransientPtr = @@ -484,7 +441,7 @@ mod tests user_manager }); - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); } #[test] @@ -546,7 +503,7 @@ mod tests di_container .bindings - .insert(TypeId::of::(), Rc::new(mock_provider)); + .set::(Box::new(mock_provider)); di_container.get::()?; @@ -635,7 +592,7 @@ mod tests di_container .bindings - .insert(TypeId::of::(), Rc::new(mock_provider)); + .set::(Box::new(mock_provider)); let first_number_rc = di_container.get_singleton::()?; @@ -723,7 +680,7 @@ mod tests di_container .bindings - .insert(TypeId::of::(), Rc::new(mock_provider)); + .set::(Box::new(mock_provider)); di_container.get_factory::()?; diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs new file mode 100644 index 0000000..b505321 --- /dev/null +++ b/src/di_container_binding_map.rs @@ -0,0 +1,55 @@ +use std::any::{type_name, TypeId}; + +use ahash::AHashMap; +use error_stack::report; + +use crate::{errors::di_container::DIContainerError, provider::IProvider}; + +pub struct DIContainerBindingMap +{ + bindings: AHashMap>, +} + +impl DIContainerBindingMap +{ + pub fn new() -> Self + { + Self { + bindings: AHashMap::new(), + } + } + + pub fn get(&self) -> error_stack::Result<&dyn IProvider, DIContainerError> + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::(); + + Ok(self + .bindings + .get(&interface_typeid) + .ok_or_else(|| { + report!(DIContainerError).attach_printable(format!( + "No binding exists for {}", + type_name::() + )) + })? + .as_ref()) + } + + pub fn set(&mut self, provider: Box) + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::(); + + self.bindings.insert(interface_typeid, provider); + } + + /// Only used by tests in the ``di_container`` module. + #[cfg(test)] + pub fn count(&self) -> usize + { + self.bindings.len() + } +} diff --git a/src/lib.rs b/src/lib.rs index 4e56ae0..d7314ad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ pub mod castable_factory; pub mod libs; // Private +mod di_container_binding_map; mod provider; /// Shortcut for creating a DI container binding for a injectable without a declared interface. -- cgit v1.2.3-18-g5258