diff options
author | HampusM <hampus@hampusmat.com> | 2022-08-27 23:41:41 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-08-27 23:41:41 +0200 |
commit | e0f90a8e384615c79d7d51c66d19294d75e79391 (patch) | |
tree | f3df3d1cd92f7d4a978feaa5a9a5f773dd0901ee /src | |
parent | d4078c84a83d121a4e3492955359cedb3b404476 (diff) |
feat: implement named bindings
Diffstat (limited to 'src')
-rw-r--r-- | src/di_container.rs | 404 | ||||
-rw-r--r-- | src/di_container_binding_map.rs | 59 | ||||
-rw-r--r-- | src/errors/di_container.rs | 27 |
3 files changed, 446 insertions, 44 deletions
diff --git a/src/di_container.rs b/src/di_container.rs index 85b0e7a..9d54261 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -54,13 +54,66 @@ use std::marker::PhantomData; use crate::castable_factory::CastableFactory; use crate::di_container_binding_map::DIContainerBindingMap; use crate::errors::di_container::{ - BindingBuilderError, BindingScopeConfiguratorError, DIContainerError, + BindingBuilderError, BindingScopeConfiguratorError, BindingWhenConfiguratorError, + DIContainerError, }; use crate::interfaces::injectable::Injectable; use crate::libs::intertrait::cast::{CastBox, CastRc}; use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; use crate::ptr::{SingletonPtr, SomePtr}; +/// When configurator for a binding for type 'Interface' inside a [`DIContainer`]. +pub struct BindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + di_container: &'di_container mut DIContainer, + interface_phantom: PhantomData<Interface>, +} + +impl<'di_container, Interface> BindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + fn new(di_container: &'di_container mut DIContainer) -> Self + { + Self { + di_container, + interface_phantom: PhantomData, + } + } + + /// Configures the binding to have a name. + /// + /// # Errors + /// Will return Err if no binding for the interface already exists. + pub fn when_named( + &mut self, + name: &'static str, + ) -> Result<(), BindingWhenConfiguratorError> + { + let binding = self + .di_container + .bindings + .remove::<Interface>(None) + .map_or_else( + || { + Err(BindingWhenConfiguratorError::BindingNotFound(type_name::< + Interface, + >( + ))) + }, + Ok, + )?; + + self.di_container + .bindings + .set::<Interface>(Some(name), binding); + + Ok(()) + } +} + /// Scope configurator for a binding for type 'Interface' inside a [`DIContainer`]. pub struct BindingScopeConfigurator<'di_container, Interface, Implementation> where @@ -90,18 +143,23 @@ where /// Configures the binding to be in a transient scope. /// /// This is the default. - pub fn in_transient_scope(&mut self) + pub fn in_transient_scope(&mut self) -> BindingWhenConfigurator<Interface> { - self.di_container - .bindings - .set::<Interface>(Box::new(TransientTypeProvider::<Implementation>::new())); + self.di_container.bindings.set::<Interface>( + None, + Box::new(TransientTypeProvider::<Implementation>::new()), + ); + + BindingWhenConfigurator::new(self.di_container) } /// Configures the binding to be in a singleton scope. /// /// # Errors /// Will return Err if resolving the implementation fails. - pub fn in_singleton_scope(&mut self) -> Result<(), BindingScopeConfiguratorError> + pub fn in_singleton_scope( + &mut self, + ) -> Result<BindingWhenConfigurator<Interface>, BindingScopeConfiguratorError> { let singleton: SingletonPtr<Implementation> = SingletonPtr::from( Implementation::resolve(self.di_container, Vec::new()) @@ -110,9 +168,9 @@ where self.di_container .bindings - .set::<Interface>(Box::new(SingletonProvider::new(singleton))); + .set::<Interface>(None, Box::new(SingletonProvider::new(singleton))); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } } @@ -152,7 +210,7 @@ where where Implementation: Injectable, { - if self.di_container.bindings.has::<Interface>() { + if self.di_container.bindings.has::<Interface>(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -178,13 +236,13 @@ where pub fn to_factory<Args, Return>( &mut self, factory_func: &'static dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>, - ) -> Result<(), BindingBuilderError> + ) -> Result<BindingWhenConfigurator<Interface>, BindingBuilderError> where Args: 'static, Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory<Args, Return>, { - if self.di_container.bindings.has::<Interface>() { + if self.di_container.bindings.has::<Interface>(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -192,13 +250,14 @@ where let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.set::<Interface>(Box::new( - crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( - factory_impl, + self.di_container.bindings.set::<Interface>( + None, + Box::new(crate::provider::FactoryProvider::new( + crate::ptr::FactoryPtr::new(factory_impl), )), - )); + ); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } /// Creates a binding of type `Interface` to a factory that takes no arguments @@ -213,11 +272,11 @@ where pub fn to_default_factory<Return>( &mut self, factory_func: &'static dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>, - ) -> Result<(), BindingBuilderError> + ) -> Result<BindingWhenConfigurator<Interface>, BindingBuilderError> where Return: 'static + ?Sized, { - if self.di_container.bindings.has::<Interface>() { + if self.di_container.bindings.has::<Interface>(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -225,13 +284,14 @@ where let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.set::<Interface>(Box::new( - crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( - factory_impl, + self.di_container.bindings.set::<Interface>( + None, + Box::new(crate::provider::FactoryProvider::new( + crate::ptr::FactoryPtr::new(factory_impl), )), - )); + ); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } } @@ -265,26 +325,53 @@ impl DIContainer /// # Errors /// Will return `Err` if: /// - No binding for `Interface` exists - /// - Resolving the binding for `Interface` fails - /// - Casting the binding for `Interface` fails + /// - Resolving the binding for fails + /// - Casting the binding for fails pub fn get<Interface>(&self) -> Result<SomePtr<Interface>, DIContainerError> where Interface: 'static + ?Sized, { - self.get_bound::<Interface>(Vec::new()) + self.get_bound::<Interface>(Vec::new(), None) + } + + /// Returns the type bound with `Interface` and the specified name. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` with name `name` exists + /// - Resolving the binding fails + /// - Casting the binding for fails + pub fn get_named<Interface>( + &self, + name: &'static str, + ) -> Result<SomePtr<Interface>, DIContainerError> + where + Interface: 'static + ?Sized, + { + self.get_bound::<Interface>(Vec::new(), Some(name)) } #[doc(hidden)] pub fn get_bound<Interface>( &self, dependency_history: Vec<&'static str>, + name: Option<&'static str>, ) -> Result<SomePtr<Interface>, DIContainerError> where Interface: 'static + ?Sized, { let binding_providable = - self.get_binding_providable::<Interface>(dependency_history)?; + self.get_binding_providable::<Interface>(name, dependency_history)?; + + Self::handle_binding_providable(binding_providable) + } + fn handle_binding_providable<Interface>( + binding_providable: Providable, + ) -> Result<SomePtr<Interface>, DIContainerError> + where + Interface: 'static + ?Sized, + { match binding_providable { Providable::Transient(transient_binding) => Ok(SomePtr::Transient( transient_binding.cast::<Interface>().map_err(|_| { @@ -318,13 +405,14 @@ impl DIContainer fn get_binding_providable<Interface>( &self, + name: Option<&'static str>, dependency_history: Vec<&'static str>, ) -> Result<Providable, DIContainerError> where Interface: 'static + ?Sized, { self.bindings - .get::<Interface>()? + .get::<Interface>(name)? .provide(self, dependency_history) .map_err(|err| DIContainerError::BindingResolveFailed { reason: err, @@ -494,6 +582,41 @@ mod tests } #[test] + fn can_bind_to_transient() -> Result<(), Box<dyn Error>> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()? + .in_transient_scope(); + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] + fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()? + .in_transient_scope() + .when_named("regular")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] fn can_bind_to_singleton() -> Result<(), Box<dyn Error>> { let mut di_container: DIContainer = DIContainer::new(); @@ -511,6 +634,24 @@ mod tests } #[test] + fn can_bind_to_singleton_when_named() -> Result<(), Box<dyn Error>> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<dyn subjects::IUserManager>() + .to::<subjects::UserManager>()? + .in_singleton_scope()? + .when_named("cool")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] #[cfg(feature = "factory")] fn can_bind_to_factory() -> Result<(), Box<dyn Error>> { @@ -534,6 +675,32 @@ mod tests } #[test] + #[cfg(feature = "factory")] + fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>> + { + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::<IUserManagerFactory>() + .to_factory(&|| { + let user_manager: TransientPtr<dyn subjects::IUserManager> = + TransientPtr::new(subjects::UserManager::new()); + + user_manager + })? + .when_named("awesome")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] fn can_get() -> Result<(), Box<dyn Error>> { mock! { @@ -561,9 +728,48 @@ mod tests di_container .bindings - .set::<dyn subjects::IUserManager>(Box::new(mock_provider)); + .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider)); + + di_container + .get::<dyn subjects::IUserManager>()? + .transient()?; + + Ok(()) + } + + #[test] + fn can_get_named() -> Result<(), Box<dyn Error>> + { + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<Providable, InjectableError>; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(Providable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + di_container + .bindings + .set::<dyn subjects::IUserManager>(Some("special"), Box::new(mock_provider)); - di_container.get::<dyn subjects::IUserManager>()?; + di_container + .get_named::<dyn subjects::IUserManager>("special")? + .transient()?; Ok(()) } @@ -598,7 +804,7 @@ mod tests di_container .bindings - .set::<dyn subjects::INumber>(Box::new(mock_provider)); + .set::<dyn subjects::INumber>(None, Box::new(mock_provider)); let first_number_rc = di_container.get::<dyn subjects::INumber>()?.singleton()?; @@ -613,6 +819,53 @@ mod tests } #[test] + fn can_get_singleton_named() -> Result<(), Box<dyn Error>> + { + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<Providable, InjectableError>; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = SingletonPtr::new(subjects::Number::new()); + + SingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; + + mock_provider + .expect_provide() + .returning_st(move |_, _| Ok(Providable::Singleton(singleton.clone()))); + + di_container + .bindings + .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider)); + + let first_number_rc = di_container + .get_named::<dyn subjects::INumber>("cool")? + .singleton()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container + .get_named::<dyn subjects::INumber>("cool")? + .singleton()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + + #[test] #[cfg(feature = "factory")] fn can_get_factory() -> Result<(), Box<dyn Error>> { @@ -688,9 +941,94 @@ mod tests di_container .bindings - .set::<IUserManagerFactory>(Box::new(mock_provider)); + .set::<IUserManagerFactory>(None, Box::new(mock_provider)); + + di_container.get::<IUserManagerFactory>()?.factory()?; - di_container.get::<IUserManagerFactory>()?; + Ok(()) + } + + #[test] + #[cfg(feature = "factory")] + fn can_get_factory_named() -> Result<(), Box<dyn Error>> + { + trait IUserManager + { + fn add_user(&mut self, user_id: i128); + + fn remove_user(&mut self, user_id: i128); + } + + struct UserManager + { + users: Vec<i128>, + } + + impl UserManager + { + fn new(users: Vec<i128>) -> Self + { + Self { users } + } + } + + impl IUserManager for UserManager + { + fn add_user(&mut self, user_id: i128) + { + self.users.push(user_id); + } + + fn remove_user(&mut self, user_id: i128) + { + let user_index = + self.users.iter().position(|user| *user == user_id).unwrap(); + + self.users.remove(user_index); + } + } + + use crate as syrette; + + #[crate::factory] + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>; + + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result<Providable, InjectableError>; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(Providable::Factory(crate::ptr::FactoryPtr::new( + CastableFactory::new(&|users| { + let user_manager: TransientPtr<dyn IUserManager> = + TransientPtr::new(UserManager::new(users)); + + user_manager + }), + ))) + }); + + di_container + .bindings + .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider)); + + di_container + .get_named::<IUserManagerFactory>("special")? + .factory()?; Ok(()) } diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs index 20d040f..d4b46f2 100644 --- a/src/di_container_binding_map.rs +++ b/src/di_container_binding_map.rs @@ -4,9 +4,16 @@ use ahash::AHashMap; use crate::{errors::di_container::DIContainerError, provider::IProvider}; +#[derive(Debug, PartialEq, Eq, Hash)] +struct DIContainerBindingKey +{ + type_id: TypeId, + name: Option<&'static str>, +} + pub struct DIContainerBindingMap { - bindings: AHashMap<TypeId, Box<dyn IProvider>>, + bindings: AHashMap<DIContainerBindingKey, Box<dyn IProvider>>, } impl DIContainerBindingMap @@ -18,7 +25,10 @@ impl DIContainerBindingMap } } - pub fn get<Interface>(&self) -> Result<&dyn IProvider, DIContainerError> + pub fn get<Interface>( + &self, + name: Option<&'static str>, + ) -> Result<&dyn IProvider, DIContainerError> where Interface: 'static + ?Sized, { @@ -26,27 +36,60 @@ impl DIContainerBindingMap Ok(self .bindings - .get(&interface_typeid) - .ok_or_else(|| DIContainerError::BindingNotFound(type_name::<Interface>()))? + .get(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) + .ok_or_else(|| DIContainerError::BindingNotFound { + interface: type_name::<Interface>(), + name, + })? .as_ref()) } - pub fn set<Interface>(&mut self, provider: Box<dyn IProvider>) + pub fn set<Interface>( + &mut self, + name: Option<&'static str>, + provider: Box<dyn IProvider>, + ) where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::<Interface>(); + + self.bindings.insert( + DIContainerBindingKey { + type_id: interface_typeid, + name, + }, + provider, + ); + } + + pub fn remove<Interface>( + &mut self, + name: Option<&'static str>, + ) -> Option<Box<dyn IProvider>> where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::<Interface>(); - self.bindings.insert(interface_typeid, provider); + self.bindings.remove(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } - pub fn has<Interface>(&self) -> bool + pub fn has<Interface>(&self, name: Option<&'static str>) -> bool where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::<Interface>(); - self.bindings.contains_key(&interface_typeid) + self.bindings.contains_key(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } /// Only used by tests in the `di_container` module. diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs index 65cd9d1..82a3d55 100644 --- a/src/errors/di_container.rs +++ b/src/errors/di_container.rs @@ -26,9 +26,19 @@ pub enum DIContainerError interface: &'static str, }, - /// No binding exists for a interface. - #[error("No binding exists for interface '{0}'")] - BindingNotFound(&'static str), + /// No binding exists for a interface (and optionally a name). + #[error( + "No binding exists for interface '{interface}' {}", + .name.map_or_else(String::new, |name| format!("with name '{}'", name)) + )] + BindingNotFound + { + /// The interface that doesn't have a binding. + interface: &'static str, + + /// The name of the binding if one exists. + name: Option<&'static str>, + }, } /// Error type for [`BindingBuilder`]. @@ -52,3 +62,14 @@ pub enum BindingScopeConfiguratorError #[error("Resolving the given singleton failed")] SingletonResolveFailed(#[from] InjectableError), } + +/// Error type for [`BindingWhenConfigurator`]. +/// +/// [`BindingWhenConfigurator`]: crate::di_container::BindingWhenConfigurator +#[derive(thiserror::Error, Debug)] +pub enum BindingWhenConfiguratorError +{ + /// A binding for a interface wasn't found. + #[error("A binding for interface '{0}' wasn't found'")] + BindingNotFound(&'static str), +} |