diff options
-rw-r--r-- | src/di_container.rs | 199 | ||||
-rw-r--r-- | src/di_container_binding_map.rs | 55 | ||||
-rw-r--r-- | src/lib.rs | 1 |
3 files changed, 134 insertions, 121 deletions
diff --git a/src/di_container.rs b/src/di_container.rs index e59d8f4..e18fc3a 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -1,18 +1,29 @@ -use std::any::{type_name, TypeId}; +use std::any::type_name; use std::marker::PhantomData; -use std::rc::Rc; -use ahash::AHashMap; -use error_stack::{Report, ResultExt}; +use error_stack::{report, Report, ResultExt}; #[cfg(feature = "factory")] use crate::castable_factory::CastableFactory; +use crate::di_container_binding_map::DIContainerBindingMap; use crate::errors::di_container::{BindingBuilderError, DIContainerError}; use crate::interfaces::injectable::Injectable; +use crate::libs::intertrait::cast::error::CastError; use crate::libs::intertrait::cast::{CastBox, CastRc}; -use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider}; +use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; use crate::ptr::{SingletonPtr, TransientPtr}; +fn unable_to_cast_binding<Interface>(err: Report<CastError>) -> Report<DIContainerError> +where + Interface: 'static + ?Sized, +{ + err.change_context(DIContainerError) + .attach_printable(format!( + "Unable to cast binding for interface '{}'", + type_name::<Interface>() + )) +} + /// Binding builder for type `Interface` inside a [`DIContainer`]. pub struct BindingBuilder<'di_container_lt, Interface> where @@ -40,12 +51,9 @@ where where Implementation: Injectable, { - let interface_typeid = TypeId::of::<Interface>(); - - self.di_container.bindings.insert( - interface_typeid, - Rc::new(TransientTypeProvider::<Implementation>::new()), - ); + self.di_container + .bindings + .set::<Interface>(Box::new(TransientTypeProvider::<Implementation>::new())); } /// Creates a binding of type `Interface` to a new singleton of type `Implementation` @@ -59,8 +67,6 @@ where where Implementation: Injectable, { - let interface_typeid = TypeId::of::<Interface>(); - let singleton: SingletonPtr<Implementation> = SingletonPtr::from( Implementation::resolve(self.di_container) .change_context(BindingBuilderError)?, @@ -68,7 +74,7 @@ where self.di_container .bindings - .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton))); + .set::<Interface>(Box::new(SingletonProvider::new(singleton))); Ok(()) } @@ -84,16 +90,13 @@ where Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory<Args, Return>, { - let interface_typeid = TypeId::of::<Interface>(); - let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.insert( - interface_typeid, - Rc::new(crate::provider::FactoryProvider::new( - crate::ptr::FactoryPtr::new(factory_impl), + self.di_container.bindings.set::<Interface>(Box::new( + crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( + factory_impl, )), - ); + )); } } @@ -144,7 +147,7 @@ where /// ``` pub struct DIContainer { - bindings: AHashMap<TypeId, Rc<dyn IProvider>>, + bindings: DIContainerBindingMap, } impl DIContainer @@ -154,7 +157,7 @@ impl DIContainer pub fn new() -> Self { Self { - bindings: AHashMap::new(), + bindings: DIContainerBindingMap::new(), } } @@ -180,39 +183,18 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::<Interface>(); - - let interface_name = type_name::<Interface>(); + let binding_providable = self.get_binding_providable::<Interface>()?; - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; - - let binding_providable = binding - .provide(self) - .change_context(DIContainerError) - .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Transient(binding_injectable) => { - let interface_result = binding_injectable.cast::<Interface>(); - - match interface_result { - Ok(interface) => Ok(interface), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not injectable", - interface_name - ))), + if let Providable::Transient(binding_transient) = binding_providable { + return binding_transient + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not transient", + type_name::<Interface>() + ))) } /// Returns the singleton instance bound with `Interface`. @@ -229,39 +211,18 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::<Interface>(); - - let interface_name = type_name::<Interface>(); + let binding_providable = self.get_binding_providable::<Interface>()?; - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; - - let binding_providable = binding - .provide(self) - .change_context(DIContainerError) - .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Singleton(binding_singleton) => { - let interface_result = binding_singleton.cast::<Interface>(); - - match interface_result { - Ok(interface) => Ok(interface), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not a singleton", - interface_name - ))), + if let Providable::Singleton(binding_singleton) = binding_providable { + return binding_singleton + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a singleton", + type_name::<Interface>() + ))) } /// Returns the factory bound with factory type `Interface`. @@ -279,39 +240,34 @@ impl DIContainer where Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::<Interface>(); + let binding_providable = self.get_binding_providable::<Interface>()?; - let interface_name = type_name::<Interface>(); + if let Providable::Factory(binding_factory) = binding_providable { + return binding_factory + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); + } - let binding = self.bindings.get(&interface_typeid).ok_or_else(|| { - Report::new(DIContainerError) - .attach_printable(format!("No binding exists for {}", interface_name)) - })?; + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a factory", + type_name::<Interface>() + ))) + } - let binding_providable = binding + fn get_binding_providable<Interface>( + &self, + ) -> error_stack::Result<Providable, DIContainerError> + where + Interface: 'static + ?Sized, + { + self.bindings + .get::<Interface>()? .provide(self) .change_context(DIContainerError) .attach_printable(format!( - "Failed to resolve binding for interface {}", - interface_name - ))?; - - match binding_providable { - Providable::Factory(binding_factory) => { - let factory_result = binding_factory.cast::<Interface>(); - - match factory_result { - Ok(factory) => Ok(factory), - Err(_) => Err(Report::new(DIContainerError).attach_printable( - format!("Unable to cast binding for {}", interface_name), - )), - } - } - _ => Err(Report::new(DIContainerError).attach_printable(format!( - "Binding for {} is not a factory", - interface_name - ))), - } + "Failed to resolve binding for interface '{}'", + type_name::<Interface>() + )) } } @@ -332,6 +288,7 @@ mod tests use super::*; use crate::errors::injectable::ResolveError; + use crate::provider::IProvider; use crate::ptr::TransientPtr; #[test] @@ -376,11 +333,11 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container.bind::<dyn IUserManager>().to::<UserManager>(); - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); } #[test] @@ -425,13 +382,13 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container .bind::<dyn IUserManager>() .to_singleton::<UserManager>()?; - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); Ok(()) } @@ -475,7 +432,7 @@ mod tests let mut di_container: DIContainer = DIContainer::new(); - assert_eq!(di_container.bindings.len(), 0); + assert_eq!(di_container.bindings.count(), 0); di_container.bind::<IUserManagerFactory>().to_factory(&|| { let user_manager: TransientPtr<dyn IUserManager> = @@ -484,7 +441,7 @@ mod tests user_manager }); - assert_eq!(di_container.bindings.len(), 1); + assert_eq!(di_container.bindings.count(), 1); } #[test] @@ -546,7 +503,7 @@ mod tests di_container .bindings - .insert(TypeId::of::<dyn IUserManager>(), Rc::new(mock_provider)); + .set::<dyn IUserManager>(Box::new(mock_provider)); di_container.get::<dyn IUserManager>()?; @@ -635,7 +592,7 @@ mod tests di_container .bindings - .insert(TypeId::of::<dyn INumber>(), Rc::new(mock_provider)); + .set::<dyn INumber>(Box::new(mock_provider)); let first_number_rc = di_container.get_singleton::<dyn INumber>()?; @@ -723,7 +680,7 @@ mod tests di_container .bindings - .insert(TypeId::of::<IUserManagerFactory>(), Rc::new(mock_provider)); + .set::<IUserManagerFactory>(Box::new(mock_provider)); di_container.get_factory::<IUserManagerFactory>()?; diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs new file mode 100644 index 0000000..b505321 --- /dev/null +++ b/src/di_container_binding_map.rs @@ -0,0 +1,55 @@ +use std::any::{type_name, TypeId}; + +use ahash::AHashMap; +use error_stack::report; + +use crate::{errors::di_container::DIContainerError, provider::IProvider}; + +pub struct DIContainerBindingMap +{ + bindings: AHashMap<TypeId, Box<dyn IProvider>>, +} + +impl DIContainerBindingMap +{ + pub fn new() -> Self + { + Self { + bindings: AHashMap::new(), + } + } + + pub fn get<Interface>(&self) -> error_stack::Result<&dyn IProvider, DIContainerError> + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::<Interface>(); + + Ok(self + .bindings + .get(&interface_typeid) + .ok_or_else(|| { + report!(DIContainerError).attach_printable(format!( + "No binding exists for {}", + type_name::<Interface>() + )) + })? + .as_ref()) + } + + pub fn set<Interface>(&mut self, provider: Box<dyn IProvider>) + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::<Interface>(); + + self.bindings.insert(interface_typeid, provider); + } + + /// Only used by tests in the ``di_container`` module. + #[cfg(test)] + pub fn count(&self) -> usize + { + self.bindings.len() + } +} @@ -22,6 +22,7 @@ pub mod castable_factory; pub mod libs; // Private +mod di_container_binding_map; mod provider; /// Shortcut for creating a DI container binding for a injectable without a declared interface. |