aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-08-27 23:41:41 +0200
committerHampusM <hampus@hampusmat.com>2022-08-27 23:41:41 +0200
commite0f90a8e384615c79d7d51c66d19294d75e79391 (patch)
treef3df3d1cd92f7d4a978feaa5a9a5f773dd0901ee /src
parentd4078c84a83d121a4e3492955359cedb3b404476 (diff)
feat: implement named bindings
Diffstat (limited to 'src')
-rw-r--r--src/di_container.rs404
-rw-r--r--src/di_container_binding_map.rs59
-rw-r--r--src/errors/di_container.rs27
3 files changed, 446 insertions, 44 deletions
diff --git a/src/di_container.rs b/src/di_container.rs
index 85b0e7a..9d54261 100644
--- a/src/di_container.rs
+++ b/src/di_container.rs
@@ -54,13 +54,66 @@ use std::marker::PhantomData;
use crate::castable_factory::CastableFactory;
use crate::di_container_binding_map::DIContainerBindingMap;
use crate::errors::di_container::{
- BindingBuilderError, BindingScopeConfiguratorError, DIContainerError,
+ BindingBuilderError, BindingScopeConfiguratorError, BindingWhenConfiguratorError,
+ DIContainerError,
};
use crate::interfaces::injectable::Injectable;
use crate::libs::intertrait::cast::{CastBox, CastRc};
use crate::provider::{Providable, SingletonProvider, TransientTypeProvider};
use crate::ptr::{SingletonPtr, SomePtr};
+/// When configurator for a binding for type 'Interface' inside a [`DIContainer`].
+pub struct BindingWhenConfigurator<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ di_container: &'di_container mut DIContainer,
+ interface_phantom: PhantomData<Interface>,
+}
+
+impl<'di_container, Interface> BindingWhenConfigurator<'di_container, Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ fn new(di_container: &'di_container mut DIContainer) -> Self
+ {
+ Self {
+ di_container,
+ interface_phantom: PhantomData,
+ }
+ }
+
+ /// Configures the binding to have a name.
+ ///
+ /// # Errors
+ /// Will return Err if no binding for the interface already exists.
+ pub fn when_named(
+ &mut self,
+ name: &'static str,
+ ) -> Result<(), BindingWhenConfiguratorError>
+ {
+ let binding = self
+ .di_container
+ .bindings
+ .remove::<Interface>(None)
+ .map_or_else(
+ || {
+ Err(BindingWhenConfiguratorError::BindingNotFound(type_name::<
+ Interface,
+ >(
+ )))
+ },
+ Ok,
+ )?;
+
+ self.di_container
+ .bindings
+ .set::<Interface>(Some(name), binding);
+
+ Ok(())
+ }
+}
+
/// Scope configurator for a binding for type 'Interface' inside a [`DIContainer`].
pub struct BindingScopeConfigurator<'di_container, Interface, Implementation>
where
@@ -90,18 +143,23 @@ where
/// Configures the binding to be in a transient scope.
///
/// This is the default.
- pub fn in_transient_scope(&mut self)
+ pub fn in_transient_scope(&mut self) -> BindingWhenConfigurator<Interface>
{
- self.di_container
- .bindings
- .set::<Interface>(Box::new(TransientTypeProvider::<Implementation>::new()));
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(TransientTypeProvider::<Implementation>::new()),
+ );
+
+ BindingWhenConfigurator::new(self.di_container)
}
/// Configures the binding to be in a singleton scope.
///
/// # Errors
/// Will return Err if resolving the implementation fails.
- pub fn in_singleton_scope(&mut self) -> Result<(), BindingScopeConfiguratorError>
+ pub fn in_singleton_scope(
+ &mut self,
+ ) -> Result<BindingWhenConfigurator<Interface>, BindingScopeConfiguratorError>
{
let singleton: SingletonPtr<Implementation> = SingletonPtr::from(
Implementation::resolve(self.di_container, Vec::new())
@@ -110,9 +168,9 @@ where
self.di_container
.bindings
- .set::<Interface>(Box::new(SingletonProvider::new(singleton)));
+ .set::<Interface>(None, Box::new(SingletonProvider::new(singleton)));
- Ok(())
+ Ok(BindingWhenConfigurator::new(self.di_container))
}
}
@@ -152,7 +210,7 @@ where
where
Implementation: Injectable,
{
- if self.di_container.bindings.has::<Interface>() {
+ if self.di_container.bindings.has::<Interface>(None) {
return Err(BindingBuilderError::BindingAlreadyExists(type_name::<
Interface,
>()));
@@ -178,13 +236,13 @@ where
pub fn to_factory<Args, Return>(
&mut self,
factory_func: &'static dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>,
- ) -> Result<(), BindingBuilderError>
+ ) -> Result<BindingWhenConfigurator<Interface>, BindingBuilderError>
where
Args: 'static,
Return: 'static + ?Sized,
Interface: crate::interfaces::factory::IFactory<Args, Return>,
{
- if self.di_container.bindings.has::<Interface>() {
+ if self.di_container.bindings.has::<Interface>(None) {
return Err(BindingBuilderError::BindingAlreadyExists(type_name::<
Interface,
>()));
@@ -192,13 +250,14 @@ where
let factory_impl = CastableFactory::new(factory_func);
- self.di_container.bindings.set::<Interface>(Box::new(
- crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new(
- factory_impl,
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(crate::provider::FactoryProvider::new(
+ crate::ptr::FactoryPtr::new(factory_impl),
)),
- ));
+ );
- Ok(())
+ Ok(BindingWhenConfigurator::new(self.di_container))
}
/// Creates a binding of type `Interface` to a factory that takes no arguments
@@ -213,11 +272,11 @@ where
pub fn to_default_factory<Return>(
&mut self,
factory_func: &'static dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>,
- ) -> Result<(), BindingBuilderError>
+ ) -> Result<BindingWhenConfigurator<Interface>, BindingBuilderError>
where
Return: 'static + ?Sized,
{
- if self.di_container.bindings.has::<Interface>() {
+ if self.di_container.bindings.has::<Interface>(None) {
return Err(BindingBuilderError::BindingAlreadyExists(type_name::<
Interface,
>()));
@@ -225,13 +284,14 @@ where
let factory_impl = CastableFactory::new(factory_func);
- self.di_container.bindings.set::<Interface>(Box::new(
- crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new(
- factory_impl,
+ self.di_container.bindings.set::<Interface>(
+ None,
+ Box::new(crate::provider::FactoryProvider::new(
+ crate::ptr::FactoryPtr::new(factory_impl),
)),
- ));
+ );
- Ok(())
+ Ok(BindingWhenConfigurator::new(self.di_container))
}
}
@@ -265,26 +325,53 @@ impl DIContainer
/// # Errors
/// Will return `Err` if:
/// - No binding for `Interface` exists
- /// - Resolving the binding for `Interface` fails
- /// - Casting the binding for `Interface` fails
+ /// - Resolving the binding for fails
+ /// - Casting the binding for fails
pub fn get<Interface>(&self) -> Result<SomePtr<Interface>, DIContainerError>
where
Interface: 'static + ?Sized,
{
- self.get_bound::<Interface>(Vec::new())
+ self.get_bound::<Interface>(Vec::new(), None)
+ }
+
+ /// Returns the type bound with `Interface` and the specified name.
+ ///
+ /// # Errors
+ /// Will return `Err` if:
+ /// - No binding for `Interface` with name `name` exists
+ /// - Resolving the binding fails
+ /// - Casting the binding for fails
+ pub fn get_named<Interface>(
+ &self,
+ name: &'static str,
+ ) -> Result<SomePtr<Interface>, DIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.get_bound::<Interface>(Vec::new(), Some(name))
}
#[doc(hidden)]
pub fn get_bound<Interface>(
&self,
dependency_history: Vec<&'static str>,
+ name: Option<&'static str>,
) -> Result<SomePtr<Interface>, DIContainerError>
where
Interface: 'static + ?Sized,
{
let binding_providable =
- self.get_binding_providable::<Interface>(dependency_history)?;
+ self.get_binding_providable::<Interface>(name, dependency_history)?;
+
+ Self::handle_binding_providable(binding_providable)
+ }
+ fn handle_binding_providable<Interface>(
+ binding_providable: Providable,
+ ) -> Result<SomePtr<Interface>, DIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
match binding_providable {
Providable::Transient(transient_binding) => Ok(SomePtr::Transient(
transient_binding.cast::<Interface>().map_err(|_| {
@@ -318,13 +405,14 @@ impl DIContainer
fn get_binding_providable<Interface>(
&self,
+ name: Option<&'static str>,
dependency_history: Vec<&'static str>,
) -> Result<Providable, DIContainerError>
where
Interface: 'static + ?Sized,
{
self.bindings
- .get::<Interface>()?
+ .get::<Interface>(name)?
.provide(self, dependency_history)
.map_err(|err| DIContainerError::BindingResolveFailed {
reason: err,
@@ -494,6 +582,41 @@ mod tests
}
#[test]
+ fn can_bind_to_transient() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: DIContainer = DIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_transient_scope();
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
+ fn can_bind_to_transient_when_named() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: DIContainer = DIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_transient_scope()
+ .when_named("regular")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
fn can_bind_to_singleton() -> Result<(), Box<dyn Error>>
{
let mut di_container: DIContainer = DIContainer::new();
@@ -511,6 +634,24 @@ mod tests
}
#[test]
+ fn can_bind_to_singleton_when_named() -> Result<(), Box<dyn Error>>
+ {
+ let mut di_container: DIContainer = DIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<dyn subjects::IUserManager>()
+ .to::<subjects::UserManager>()?
+ .in_singleton_scope()?
+ .when_named("cool")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
#[cfg(feature = "factory")]
fn can_bind_to_factory() -> Result<(), Box<dyn Error>>
{
@@ -534,6 +675,32 @@ mod tests
}
#[test]
+ #[cfg(feature = "factory")]
+ fn can_bind_to_factory_when_named() -> Result<(), Box<dyn Error>>
+ {
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
+
+ let mut di_container: DIContainer = DIContainer::new();
+
+ assert_eq!(di_container.bindings.count(), 0);
+
+ di_container
+ .bind::<IUserManagerFactory>()
+ .to_factory(&|| {
+ let user_manager: TransientPtr<dyn subjects::IUserManager> =
+ TransientPtr::new(subjects::UserManager::new());
+
+ user_manager
+ })?
+ .when_named("awesome")?;
+
+ assert_eq!(di_container.bindings.count(), 1);
+
+ Ok(())
+ }
+
+ #[test]
fn can_get() -> Result<(), Box<dyn Error>>
{
mock! {
@@ -561,9 +728,48 @@ mod tests
di_container
.bindings
- .set::<dyn subjects::IUserManager>(Box::new(mock_provider));
+ .set::<dyn subjects::IUserManager>(None, Box::new(mock_provider));
+
+ di_container
+ .get::<dyn subjects::IUserManager>()?
+ .transient()?;
+
+ Ok(())
+ }
+
+ #[test]
+ fn can_get_named() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ impl IProvider for Provider
+ {
+ fn provide(
+ &self,
+ di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<Providable, InjectableError>;
+ }
+ }
+
+ let mut di_container: DIContainer = DIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(Providable::Transient(TransientPtr::new(
+ subjects::UserManager::new(),
+ )))
+ });
+
+ di_container
+ .bindings
+ .set::<dyn subjects::IUserManager>(Some("special"), Box::new(mock_provider));
- di_container.get::<dyn subjects::IUserManager>()?;
+ di_container
+ .get_named::<dyn subjects::IUserManager>("special")?
+ .transient()?;
Ok(())
}
@@ -598,7 +804,7 @@ mod tests
di_container
.bindings
- .set::<dyn subjects::INumber>(Box::new(mock_provider));
+ .set::<dyn subjects::INumber>(None, Box::new(mock_provider));
let first_number_rc = di_container.get::<dyn subjects::INumber>()?.singleton()?;
@@ -613,6 +819,53 @@ mod tests
}
#[test]
+ fn can_get_singleton_named() -> Result<(), Box<dyn Error>>
+ {
+ mock! {
+ Provider {}
+
+ impl IProvider for Provider
+ {
+ fn provide(
+ &self,
+ di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<Providable, InjectableError>;
+ }
+ }
+
+ let mut di_container: DIContainer = DIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ let mut singleton = SingletonPtr::new(subjects::Number::new());
+
+ SingletonPtr::get_mut(&mut singleton).unwrap().num = 2820;
+
+ mock_provider
+ .expect_provide()
+ .returning_st(move |_, _| Ok(Providable::Singleton(singleton.clone())));
+
+ di_container
+ .bindings
+ .set::<dyn subjects::INumber>(Some("cool"), Box::new(mock_provider));
+
+ let first_number_rc = di_container
+ .get_named::<dyn subjects::INumber>("cool")?
+ .singleton()?;
+
+ assert_eq!(first_number_rc.get(), 2820);
+
+ let second_number_rc = di_container
+ .get_named::<dyn subjects::INumber>("cool")?
+ .singleton()?;
+
+ assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref());
+
+ Ok(())
+ }
+
+ #[test]
#[cfg(feature = "factory")]
fn can_get_factory() -> Result<(), Box<dyn Error>>
{
@@ -688,9 +941,94 @@ mod tests
di_container
.bindings
- .set::<IUserManagerFactory>(Box::new(mock_provider));
+ .set::<IUserManagerFactory>(None, Box::new(mock_provider));
+
+ di_container.get::<IUserManagerFactory>()?.factory()?;
- di_container.get::<IUserManagerFactory>()?;
+ Ok(())
+ }
+
+ #[test]
+ #[cfg(feature = "factory")]
+ fn can_get_factory_named() -> Result<(), Box<dyn Error>>
+ {
+ trait IUserManager
+ {
+ fn add_user(&mut self, user_id: i128);
+
+ fn remove_user(&mut self, user_id: i128);
+ }
+
+ struct UserManager
+ {
+ users: Vec<i128>,
+ }
+
+ impl UserManager
+ {
+ fn new(users: Vec<i128>) -> Self
+ {
+ Self { users }
+ }
+ }
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&mut self, user_id: i128)
+ {
+ self.users.push(user_id);
+ }
+
+ fn remove_user(&mut self, user_id: i128)
+ {
+ let user_index =
+ self.users.iter().position(|user| *user == user_id).unwrap();
+
+ self.users.remove(user_index);
+ }
+ }
+
+ use crate as syrette;
+
+ #[crate::factory]
+ type IUserManagerFactory =
+ dyn crate::interfaces::factory::IFactory<(Vec<i128>,), dyn IUserManager>;
+
+ mock! {
+ Provider {}
+
+ impl IProvider for Provider
+ {
+ fn provide(
+ &self,
+ di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
+ ) -> Result<Providable, InjectableError>;
+ }
+ }
+
+ let mut di_container: DIContainer = DIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ mock_provider.expect_provide().returning(|_, _| {
+ Ok(Providable::Factory(crate::ptr::FactoryPtr::new(
+ CastableFactory::new(&|users| {
+ let user_manager: TransientPtr<dyn IUserManager> =
+ TransientPtr::new(UserManager::new(users));
+
+ user_manager
+ }),
+ )))
+ });
+
+ di_container
+ .bindings
+ .set::<IUserManagerFactory>(Some("special"), Box::new(mock_provider));
+
+ di_container
+ .get_named::<IUserManagerFactory>("special")?
+ .factory()?;
Ok(())
}
diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs
index 20d040f..d4b46f2 100644
--- a/src/di_container_binding_map.rs
+++ b/src/di_container_binding_map.rs
@@ -4,9 +4,16 @@ use ahash::AHashMap;
use crate::{errors::di_container::DIContainerError, provider::IProvider};
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct DIContainerBindingKey
+{
+ type_id: TypeId,
+ name: Option<&'static str>,
+}
+
pub struct DIContainerBindingMap
{
- bindings: AHashMap<TypeId, Box<dyn IProvider>>,
+ bindings: AHashMap<DIContainerBindingKey, Box<dyn IProvider>>,
}
impl DIContainerBindingMap
@@ -18,7 +25,10 @@ impl DIContainerBindingMap
}
}
- pub fn get<Interface>(&self) -> Result<&dyn IProvider, DIContainerError>
+ pub fn get<Interface>(
+ &self,
+ name: Option<&'static str>,
+ ) -> Result<&dyn IProvider, DIContainerError>
where
Interface: 'static + ?Sized,
{
@@ -26,27 +36,60 @@ impl DIContainerBindingMap
Ok(self
.bindings
- .get(&interface_typeid)
- .ok_or_else(|| DIContainerError::BindingNotFound(type_name::<Interface>()))?
+ .get(&DIContainerBindingKey {
+ type_id: interface_typeid,
+ name,
+ })
+ .ok_or_else(|| DIContainerError::BindingNotFound {
+ interface: type_name::<Interface>(),
+ name,
+ })?
.as_ref())
}
- pub fn set<Interface>(&mut self, provider: Box<dyn IProvider>)
+ pub fn set<Interface>(
+ &mut self,
+ name: Option<&'static str>,
+ provider: Box<dyn IProvider>,
+ ) where
+ Interface: 'static + ?Sized,
+ {
+ let interface_typeid = TypeId::of::<Interface>();
+
+ self.bindings.insert(
+ DIContainerBindingKey {
+ type_id: interface_typeid,
+ name,
+ },
+ provider,
+ );
+ }
+
+ pub fn remove<Interface>(
+ &mut self,
+ name: Option<&'static str>,
+ ) -> Option<Box<dyn IProvider>>
where
Interface: 'static + ?Sized,
{
let interface_typeid = TypeId::of::<Interface>();
- self.bindings.insert(interface_typeid, provider);
+ self.bindings.remove(&DIContainerBindingKey {
+ type_id: interface_typeid,
+ name,
+ })
}
- pub fn has<Interface>(&self) -> bool
+ pub fn has<Interface>(&self, name: Option<&'static str>) -> bool
where
Interface: 'static + ?Sized,
{
let interface_typeid = TypeId::of::<Interface>();
- self.bindings.contains_key(&interface_typeid)
+ self.bindings.contains_key(&DIContainerBindingKey {
+ type_id: interface_typeid,
+ name,
+ })
}
/// Only used by tests in the `di_container` module.
diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs
index 65cd9d1..82a3d55 100644
--- a/src/errors/di_container.rs
+++ b/src/errors/di_container.rs
@@ -26,9 +26,19 @@ pub enum DIContainerError
interface: &'static str,
},
- /// No binding exists for a interface.
- #[error("No binding exists for interface '{0}'")]
- BindingNotFound(&'static str),
+ /// No binding exists for a interface (and optionally a name).
+ #[error(
+ "No binding exists for interface '{interface}' {}",
+ .name.map_or_else(String::new, |name| format!("with name '{}'", name))
+ )]
+ BindingNotFound
+ {
+ /// The interface that doesn't have a binding.
+ interface: &'static str,
+
+ /// The name of the binding if one exists.
+ name: Option<&'static str>,
+ },
}
/// Error type for [`BindingBuilder`].
@@ -52,3 +62,14 @@ pub enum BindingScopeConfiguratorError
#[error("Resolving the given singleton failed")]
SingletonResolveFailed(#[from] InjectableError),
}
+
+/// Error type for [`BindingWhenConfigurator`].
+///
+/// [`BindingWhenConfigurator`]: crate::di_container::BindingWhenConfigurator
+#[derive(thiserror::Error, Debug)]
+pub enum BindingWhenConfiguratorError
+{
+ /// A binding for a interface wasn't found.
+ #[error("A binding for interface '{0}' wasn't found'")]
+ BindingNotFound(&'static str),
+}