From e0f90a8e384615c79d7d51c66d19294d75e79391 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 27 Aug 2022 23:41:41 +0200 Subject: feat: implement named bindings --- src/di_container.rs | 404 ++++++++++++++++++++++++++++++++++++---- src/di_container_binding_map.rs | 59 +++++- src/errors/di_container.rs | 27 ++- 3 files changed, 446 insertions(+), 44 deletions(-) (limited to 'src') diff --git a/src/di_container.rs b/src/di_container.rs index 85b0e7a..9d54261 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -54,13 +54,66 @@ use std::marker::PhantomData; use crate::castable_factory::CastableFactory; use crate::di_container_binding_map::DIContainerBindingMap; use crate::errors::di_container::{ - BindingBuilderError, BindingScopeConfiguratorError, DIContainerError, + BindingBuilderError, BindingScopeConfiguratorError, BindingWhenConfiguratorError, + DIContainerError, }; use crate::interfaces::injectable::Injectable; use crate::libs::intertrait::cast::{CastBox, CastRc}; use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; use crate::ptr::{SingletonPtr, SomePtr}; +/// When configurator for a binding for type 'Interface' inside a [`DIContainer`]. +pub struct BindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + di_container: &'di_container mut DIContainer, + interface_phantom: PhantomData, +} + +impl<'di_container, Interface> BindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + fn new(di_container: &'di_container mut DIContainer) -> Self + { + Self { + di_container, + interface_phantom: PhantomData, + } + } + + /// Configures the binding to have a name. + /// + /// # Errors + /// Will return Err if no binding for the interface already exists. + pub fn when_named( + &mut self, + name: &'static str, + ) -> Result<(), BindingWhenConfiguratorError> + { + let binding = self + .di_container + .bindings + .remove::(None) + .map_or_else( + || { + Err(BindingWhenConfiguratorError::BindingNotFound(type_name::< + Interface, + >( + ))) + }, + Ok, + )?; + + self.di_container + .bindings + .set::(Some(name), binding); + + Ok(()) + } +} + /// Scope configurator for a binding for type 'Interface' inside a [`DIContainer`]. pub struct BindingScopeConfigurator<'di_container, Interface, Implementation> where @@ -90,18 +143,23 @@ where /// Configures the binding to be in a transient scope. /// /// This is the default. - pub fn in_transient_scope(&mut self) + pub fn in_transient_scope(&mut self) -> BindingWhenConfigurator { - self.di_container - .bindings - .set::(Box::new(TransientTypeProvider::::new())); + self.di_container.bindings.set::( + None, + Box::new(TransientTypeProvider::::new()), + ); + + BindingWhenConfigurator::new(self.di_container) } /// Configures the binding to be in a singleton scope. /// /// # Errors /// Will return Err if resolving the implementation fails. - pub fn in_singleton_scope(&mut self) -> Result<(), BindingScopeConfiguratorError> + pub fn in_singleton_scope( + &mut self, + ) -> Result, BindingScopeConfiguratorError> { let singleton: SingletonPtr = SingletonPtr::from( Implementation::resolve(self.di_container, Vec::new()) @@ -110,9 +168,9 @@ where self.di_container .bindings - .set::(Box::new(SingletonProvider::new(singleton))); + .set::(None, Box::new(SingletonProvider::new(singleton))); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } } @@ -152,7 +210,7 @@ where where Implementation: Injectable, { - if self.di_container.bindings.has::() { + if self.di_container.bindings.has::(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -178,13 +236,13 @@ where pub fn to_factory( &mut self, factory_func: &'static dyn Fn>, - ) -> Result<(), BindingBuilderError> + ) -> Result, BindingBuilderError> where Args: 'static, Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory, { - if self.di_container.bindings.has::() { + if self.di_container.bindings.has::(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -192,13 +250,14 @@ where let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.set::(Box::new( - crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( - factory_impl, + self.di_container.bindings.set::( + None, + Box::new(crate::provider::FactoryProvider::new( + crate::ptr::FactoryPtr::new(factory_impl), )), - )); + ); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } /// Creates a binding of type `Interface` to a factory that takes no arguments @@ -213,11 +272,11 @@ where pub fn to_default_factory( &mut self, factory_func: &'static dyn Fn<(), Output = crate::ptr::TransientPtr>, - ) -> Result<(), BindingBuilderError> + ) -> Result, BindingBuilderError> where Return: 'static + ?Sized, { - if self.di_container.bindings.has::() { + if self.di_container.bindings.has::(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -225,13 +284,14 @@ where let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.set::(Box::new( - crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( - factory_impl, + self.di_container.bindings.set::( + None, + Box::new(crate::provider::FactoryProvider::new( + crate::ptr::FactoryPtr::new(factory_impl), )), - )); + ); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } } @@ -265,26 +325,53 @@ impl DIContainer /// # Errors /// Will return `Err` if: /// - No binding for `Interface` exists - /// - Resolving the binding for `Interface` fails - /// - Casting the binding for `Interface` fails + /// - Resolving the binding for fails + /// - Casting the binding for fails pub fn get(&self) -> Result, DIContainerError> where Interface: 'static + ?Sized, { - self.get_bound::(Vec::new()) + self.get_bound::(Vec::new(), None) + } + + /// Returns the type bound with `Interface` and the specified name. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` with name `name` exists + /// - Resolving the binding fails + /// - Casting the binding for fails + pub fn get_named( + &self, + name: &'static str, + ) -> Result, DIContainerError> + where + Interface: 'static + ?Sized, + { + self.get_bound::(Vec::new(), Some(name)) } #[doc(hidden)] pub fn get_bound( &self, dependency_history: Vec<&'static str>, + name: Option<&'static str>, ) -> Result, DIContainerError> where Interface: 'static + ?Sized, { let binding_providable = - self.get_binding_providable::(dependency_history)?; + self.get_binding_providable::(name, dependency_history)?; + + Self::handle_binding_providable(binding_providable) + } + fn handle_binding_providable( + binding_providable: Providable, + ) -> Result, DIContainerError> + where + Interface: 'static + ?Sized, + { match binding_providable { Providable::Transient(transient_binding) => Ok(SomePtr::Transient( transient_binding.cast::().map_err(|_| { @@ -318,13 +405,14 @@ impl DIContainer fn get_binding_providable( &self, + name: Option<&'static str>, dependency_history: Vec<&'static str>, ) -> Result where Interface: 'static + ?Sized, { self.bindings - .get::()? + .get::(name)? .provide(self, dependency_history) .map_err(|err| DIContainerError::BindingResolveFailed { reason: err, @@ -493,6 +581,41 @@ mod tests Ok(()) } + #[test] + fn can_bind_to_transient() -> Result<(), Box> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to::()? + .in_transient_scope(); + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] + fn can_bind_to_transient_when_named() -> Result<(), Box> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to::()? + .in_transient_scope() + .when_named("regular")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + #[test] fn can_bind_to_singleton() -> Result<(), Box> { @@ -510,6 +633,24 @@ mod tests Ok(()) } + #[test] + fn can_bind_to_singleton_when_named() -> Result<(), Box> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to::()? + .in_singleton_scope()? + .when_named("cool")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + #[test] #[cfg(feature = "factory")] fn can_bind_to_factory() -> Result<(), Box> @@ -533,6 +674,32 @@ mod tests Ok(()) } + #[test] + #[cfg(feature = "factory")] + fn can_bind_to_factory_when_named() -> Result<(), Box> + { + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to_factory(&|| { + let user_manager: TransientPtr = + TransientPtr::new(subjects::UserManager::new()); + + user_manager + })? + .when_named("awesome")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + #[test] fn can_get() -> Result<(), Box> { @@ -561,9 +728,48 @@ mod tests di_container .bindings - .set::(Box::new(mock_provider)); + .set::(None, Box::new(mock_provider)); + + di_container + .get::()? + .transient()?; + + Ok(()) + } + + #[test] + fn can_get_named() -> Result<(), Box> + { + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(Providable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + di_container + .bindings + .set::(Some("special"), Box::new(mock_provider)); - di_container.get::()?; + di_container + .get_named::("special")? + .transient()?; Ok(()) } @@ -598,7 +804,7 @@ mod tests di_container .bindings - .set::(Box::new(mock_provider)); + .set::(None, Box::new(mock_provider)); let first_number_rc = di_container.get::()?.singleton()?; @@ -612,6 +818,53 @@ mod tests Ok(()) } + #[test] + fn can_get_singleton_named() -> Result<(), Box> + { + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = SingletonPtr::new(subjects::Number::new()); + + SingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; + + mock_provider + .expect_provide() + .returning_st(move |_, _| Ok(Providable::Singleton(singleton.clone()))); + + di_container + .bindings + .set::(Some("cool"), Box::new(mock_provider)); + + let first_number_rc = di_container + .get_named::("cool")? + .singleton()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container + .get_named::("cool")? + .singleton()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + #[test] #[cfg(feature = "factory")] fn can_get_factory() -> Result<(), Box> @@ -688,9 +941,94 @@ mod tests di_container .bindings - .set::(Box::new(mock_provider)); + .set::(None, Box::new(mock_provider)); + + di_container.get::()?.factory()?; - di_container.get::()?; + Ok(()) + } + + #[test] + #[cfg(feature = "factory")] + fn can_get_factory_named() -> Result<(), Box> + { + trait IUserManager + { + fn add_user(&mut self, user_id: i128); + + fn remove_user(&mut self, user_id: i128); + } + + struct UserManager + { + users: Vec, + } + + impl UserManager + { + fn new(users: Vec) -> Self + { + Self { users } + } + } + + impl IUserManager for UserManager + { + fn add_user(&mut self, user_id: i128) + { + self.users.push(user_id); + } + + fn remove_user(&mut self, user_id: i128) + { + let user_index = + self.users.iter().position(|user| *user == user_id).unwrap(); + + self.users.remove(user_index); + } + } + + use crate as syrette; + + #[crate::factory] + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(Vec,), dyn IUserManager>; + + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(Providable::Factory(crate::ptr::FactoryPtr::new( + CastableFactory::new(&|users| { + let user_manager: TransientPtr = + TransientPtr::new(UserManager::new(users)); + + user_manager + }), + ))) + }); + + di_container + .bindings + .set::(Some("special"), Box::new(mock_provider)); + + di_container + .get_named::("special")? + .factory()?; Ok(()) } diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs index 20d040f..d4b46f2 100644 --- a/src/di_container_binding_map.rs +++ b/src/di_container_binding_map.rs @@ -4,9 +4,16 @@ use ahash::AHashMap; use crate::{errors::di_container::DIContainerError, provider::IProvider}; +#[derive(Debug, PartialEq, Eq, Hash)] +struct DIContainerBindingKey +{ + type_id: TypeId, + name: Option<&'static str>, +} + pub struct DIContainerBindingMap { - bindings: AHashMap>, + bindings: AHashMap>, } impl DIContainerBindingMap @@ -18,7 +25,10 @@ impl DIContainerBindingMap } } - pub fn get(&self) -> Result<&dyn IProvider, DIContainerError> + pub fn get( + &self, + name: Option<&'static str>, + ) -> Result<&dyn IProvider, DIContainerError> where Interface: 'static + ?Sized, { @@ -26,27 +36,60 @@ impl DIContainerBindingMap Ok(self .bindings - .get(&interface_typeid) - .ok_or_else(|| DIContainerError::BindingNotFound(type_name::()))? + .get(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) + .ok_or_else(|| DIContainerError::BindingNotFound { + interface: type_name::(), + name, + })? .as_ref()) } - pub fn set(&mut self, provider: Box) + pub fn set( + &mut self, + name: Option<&'static str>, + provider: Box, + ) where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::(); + + self.bindings.insert( + DIContainerBindingKey { + type_id: interface_typeid, + name, + }, + provider, + ); + } + + pub fn remove( + &mut self, + name: Option<&'static str>, + ) -> Option> where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::(); - self.bindings.insert(interface_typeid, provider); + self.bindings.remove(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } - pub fn has(&self) -> bool + pub fn has(&self, name: Option<&'static str>) -> bool where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::(); - self.bindings.contains_key(&interface_typeid) + self.bindings.contains_key(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } /// Only used by tests in the `di_container` module. diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs index 65cd9d1..82a3d55 100644 --- a/src/errors/di_container.rs +++ b/src/errors/di_container.rs @@ -26,9 +26,19 @@ pub enum DIContainerError interface: &'static str, }, - /// No binding exists for a interface. - #[error("No binding exists for interface '{0}'")] - BindingNotFound(&'static str), + /// No binding exists for a interface (and optionally a name). + #[error( + "No binding exists for interface '{interface}' {}", + .name.map_or_else(String::new, |name| format!("with name '{}'", name)) + )] + BindingNotFound + { + /// The interface that doesn't have a binding. + interface: &'static str, + + /// The name of the binding if one exists. + name: Option<&'static str>, + }, } /// Error type for [`BindingBuilder`]. @@ -52,3 +62,14 @@ pub enum BindingScopeConfiguratorError #[error("Resolving the given singleton failed")] SingletonResolveFailed(#[from] InjectableError), } + +/// Error type for [`BindingWhenConfigurator`]. +/// +/// [`BindingWhenConfigurator`]: crate::di_container::BindingWhenConfigurator +#[derive(thiserror::Error, Debug)] +pub enum BindingWhenConfiguratorError +{ + /// A binding for a interface wasn't found. + #[error("A binding for interface '{0}' wasn't found'")] + BindingNotFound(&'static str), +} -- cgit v1.2.3-18-g5258