diff options
author | HampusM <hampus@hampusmat.com> | 2022-07-28 20:38:33 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-07-31 12:17:51 +0200 |
commit | 94b8e52eeefdabab98dfdb5bf4c91f75d778150c (patch) | |
tree | bf57b9d717778c6be9798e47828c9bed45f28aa4 /src/di_container.rs | |
parent | 545e8efddf217f300b26b930f8345d8573c30ec7 (diff) |
refactor: tidy up DI container internals
Diffstat (limited to 'src/di_container.rs')
-rw-r--r-- | src/di_container.rs | 199 |
1 files changed, 78 insertions, 121 deletions
diff --git a/src/di_container.rs b/src/di_container.rs index e59d8f4..e18fc3a 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -1,18 +1,29 @@ -use std::any::{type_name, TypeId}; +use std::any::type_name; use std::marker::PhantomData; -use std::rc::Rc; -use ahash::AHashMap; -use error_stack::{Report, ResultExt}; +use error_stack::{report, Report, ResultExt}; #[cfg(feature = "factory")] use crate::castable_factory::CastableFactory; +use crate::di_container_binding_map::DIContainerBindingMap; use crate::errors::di_container::{BindingBuilderError, DIContainerError}; use crate::interfaces::injectable::Injectable; +use crate::libs::intertrait::cast::error::CastError; use crate::libs::intertrait::cast::{CastBox, CastRc}; -use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider}; +use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; use crate::ptr::{SingletonPtr, TransientPtr}; +fn unable_to_cast_binding<Interface>(err: Report<CastError>) -> Report<DIContainerError> +where + Interface: 'static + ?Sized, +{ + err.change_context(DIContainerError) + .attach_printable(format!( + "Unable to cast binding for interface '{}'", + type_name::<Interface>() + )) +} + /// Binding builder for type `Interface` inside a [`DIContainer`]. pub struct BindingBuilder<'di_container_lt, Interface> where @@ -40,12 +51,9 @@ where where Implementation: Injectable, { - let interface_typeid = TypeId::of::<Interface>(); - - self.di_container.bindings.insert( - interface_typeid, - Rc::new(TransientTypeProvider::<Implementation>::new()), - ); + self.di_container + .bindings + .set::<Interface>(Box::new(TransientTypeProvider::<Implementation>::new())); } /// Creates a binding of type `Interface` to a new singleton of type `Implementation` @@ -59,8 +67,6 @@ where where Implementation: Injectable, { - let interface_typeid = TypeId::of::<Interface>(); - let singleton: SingletonPtr<Implementation> = SingletonPtr::from( Implementation::resolve(self.di_container) .change_context(BindingBuilderError)?, @@ -68,7 +74,7 @@ where self.di_container .bindings - .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton))); + .set::<Interface>(Box::new(SingletonProvider::new(singleton))); Ok(()) } @@ -84,16 +90,13 @@ where Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory<Args, Return>, { - let interface_typeid = TypeId::of::<Interface>(); - let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.insert( - interface_typeid, - Rc::new(crate::provider::FactoryProvider::new( - crate::ptr::FactoryPtr::new(factory_impl), + self.di_container.bindings.set::<Interface>(Box::new( + crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( + factory_impl, )), - ); + )); } } @@ -144,7 +147,7 @@ where /// ``` pub struct DIContainer { - bindings: AHashMap<TypeId, Rc<dyn IProvider>>, + bindings: DIContainerBindingMap, } impl DIContainer @@ -154,7 +157,7 @@ impl DIContainer pub fn new() -> Self { Self { - bindings: AHashMap::new(), + bindings: DIContainerBindingMap::new(), } } @@ -180,39 +183,18 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::<Interface>(); - - let interface_name = type_name::<Interface>(); + let binding_providable = self.get_binding_providable::<Interface>()?; - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; - - let binding_providable = binding - .provide(self) - .change_context(DIContainerError) - .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Transient(binding_injectable) => { - let interface_result = binding_injectable.cast::<Interface>(); - - match interface_result { - Ok(interface) => Ok(interface), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not injectable", - interface_name - ))), + if let Providable::Transient(binding_transient) = binding_providable { + return binding_transient + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not transient", + type_name::<Interface>() + ))) } /// Returns the singleton instance bound with `Interface`. @@ -229,39 +211,18 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::<Interface>(); - - let interface_name = type_name::<Interface>(); + let binding_providable = self.get_binding_providable::<Interface>()?; - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; - - let binding_providable = binding - .provide(self) - .change_context(DIContainerError) - .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Singleton(binding_singleton) => { - let interface_result = binding_singleton.cast::<Interface>(); - - match interface_result { - Ok(interface) => Ok(interface), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not a singleton", - interface_name - ))), + if let Providable::Singleton(binding_singleton) = binding_providable { + return binding_singleton + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a singleton", + type_name::<Interface>() + ))) } /// Returns the factory bound with factory type `Interface`. @@ -279,39 +240,34 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::<Interface>(); + let binding_providable = self.get_binding_providable::<Interface>()?; - let interface_name = type_name::<Interface>(); + if let Providable::Factory(binding_factory) = binding_providable { + return binding_factory + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); + } - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a factory", + type_name::<Interface>() + ))) + } - let binding_providable = binding + fn get_binding_providable<Interface>( + &self, + ) -> error_stack::Result<Providable, DIContainerError> + where + Interface: 'static + ?Sized, + { + self.bindings + .get::<Interface>()? .provide(self) .change_context(DIContainerError) .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Factory(binding_factory) => { - let factory_result = binding_factory.cast::<Interface>(); - - match factory_result { - Ok(factory) => Ok(factory), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not a factory", - interface_name - ))), - } + "Failed to resolve binding for interface '{}'", + type_name::<Interface>() + )) } } @@ -332,6 +288,7 @@ mod tests use super::*; use crate::errors::injectable::ResolveError; + use crate::provider::IProvider; use crate::ptr::TransientPtr; #[test] @@ -376,11 +333,11 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container.bind::<dyn IUserManager>().to::<UserManager>(); - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); } #[test] @@ -425,13 +382,13 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container .bind::<dyn IUserManager>() .to_singleton::<UserManager>()?; - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); Ok(()) } @@ -475,7 +432,7 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container.bind::<IUserManagerFactory>().to_factory(&|| { let user_manager: TransientPtr<dyn IUserManager> = @@ -484,7 +441,7 @@ mod tests user_manager }); - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); } #[test] @@ -546,7 +503,7 @@ mod tests di_container .bindings - .insert(TypeId::of::<dyn IUserManager>(), Rc::new(mock_provider)); + .set::<dyn IUserManager>(Box::new(mock_provider)); di_container.get::<dyn IUserManager>()?; @@ -635,7 +592,7 @@ mod tests di_container .bindings - .insert(TypeId::of::<dyn INumber>(), Rc::new(mock_provider)); + .set::<dyn INumber>(Box::new(mock_provider)); let first_number_rc = di_container.get_singleton::<dyn INumber>()?; @@ -723,7 +680,7 @@ mod tests di_container .bindings - .insert(TypeId::of::<IUserManagerFactory>(), Rc::new(mock_provider)); + .set::<IUserManagerFactory>(Box::new(mock_provider)); di_container.get_factory::<IUserManagerFactory>()?; |