aboutsummaryrefslogtreecommitdiff
path: root/src/di_container.rs
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-07-28 20:38:33 +0200
committerHampusM <hampus@hampusmat.com>2022-07-31 12:17:51 +0200
commit94b8e52eeefdabab98dfdb5bf4c91f75d778150c (patch)
treebf57b9d717778c6be9798e47828c9bed45f28aa4 /src/di_container.rs
parent545e8efddf217f300b26b930f8345d8573c30ec7 (diff)
refactor: tidy up DI container internals
Diffstat (limited to 'src/di_container.rs')
-rw-r--r--src/di_container.rs199
1 files changed, 78 insertions, 121 deletions
diff --git a/src/di_container.rs b/src/di_container.rs
index e59d8f4..e18fc3a 100644
--- a/src/di_container.rs
+++ b/src/di_container.rs
@@ -1,18 +1,29 @@
-use std::any::{type_name, TypeId};
+use std::any::type_name;
use std::marker::PhantomData;
-use std::rc::Rc;
-use ahash::AHashMap;
-use error_stack::{Report, ResultExt};
+use error_stack::{report, Report, ResultExt};
#[cfg(feature = "factory")]
use crate::castable_factory::CastableFactory;
+use crate::di_container_binding_map::DIContainerBindingMap;
use crate::errors::di_container::{BindingBuilderError, DIContainerError};
use crate::interfaces::injectable::Injectable;
+use crate::libs::intertrait::cast::error::CastError;
use crate::libs::intertrait::cast::{CastBox, CastRc};
-use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider};
+use crate::provider::{Providable, SingletonProvider, TransientTypeProvider};
use crate::ptr::{SingletonPtr, TransientPtr};
+fn unable_to_cast_binding<Interface>(err: Report<CastError>) -> Report<DIContainerError>
+where
+ Interface: 'static + ?Sized,
+{
+ err.change_context(DIContainerError)
+ .attach_printable(format!(
+ "Unable to cast binding for interface '{}'",
+ type_name::<Interface>()
+ ))
+}
+
/// Binding builder for type `Interface` inside a [`DIContainer`].
pub struct BindingBuilder<'di_container_lt, Interface>
where
@@ -40,12 +51,9 @@ where
where
Implementation: Injectable,
{
- let interface_typeid = TypeId::of::<Interface>();
-
- self.di_container.bindings.insert(
- interface_typeid,
- Rc::new(TransientTypeProvider::<Implementation>::new()),
- );
+ self.di_container
+ .bindings
+ .set::<Interface>(Box::new(TransientTypeProvider::<Implementation>::new()));
}
/// Creates a binding of type `Interface` to a new singleton of type `Implementation`
@@ -59,8 +67,6 @@ where
where
Implementation: Injectable,
{
- let interface_typeid = TypeId::of::<Interface>();
-
let singleton: SingletonPtr<Implementation> = SingletonPtr::from(
Implementation::resolve(self.di_container)
.change_context(BindingBuilderError)?,
@@ -68,7 +74,7 @@ where
self.di_container
.bindings
- .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton)));
+ .set::<Interface>(Box::new(SingletonProvider::new(singleton)));
Ok(())
}
@@ -84,16 +90,13 @@ where
Return: 'static + ?Sized,
Interface: crate::interfaces::factory::IFactory<Args, Return>,
{
- let interface_typeid = TypeId::of::<Interface>();
-
let factory_impl = CastableFactory::new(factory_func);
- self.di_container.bindings.insert(
- interface_typeid,
- Rc::new(crate::provider::FactoryProvider::new(
- crate::ptr::FactoryPtr::new(factory_impl),
+ self.di_container.bindings.set::<Interface>(Box::new(
+ crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new(
+ factory_impl,
)),
- );
+ ));
}
}
@@ -144,7 +147,7 @@ where
/// ```
pub struct DIContainer
{
- bindings: AHashMap<TypeId, Rc<dyn IProvider>>,
+ bindings: DIContainerBindingMap,
}
impl DIContainer
@@ -154,7 +157,7 @@ impl DIContainer
pub fn new() -> Self
{
Self {
- bindings: AHashMap::new(),
+ bindings: DIContainerBindingMap::new(),
}
}
@@ -180,39 +183,18 @@ impl DIContainer
where
Interface: 'static + ?Sized,
{
- let interface_typeid = TypeId::of::<Interface>();
-
- let interface_name = type_name::<Interface>();
+ let binding_providable = self.get_binding_providable::<Interface>()?;
- let binding = self.bindings.get(&interface_typeid).ok_or_else(|| {
- Report::new(DIContainerError)
- .attach_printable(format!("No binding exists for {}", interface_name))
- })?;
-
- let binding_providable = binding
- .provide(self)
- .change_context(DIContainerError)
- .attach_printable(format!(
- "Failed to resolve binding for interface {}",
- interface_name
- ))?;
-
- match binding_providable {
- Providable::Transient(binding_injectable) => {
- let interface_result = binding_injectable.cast::<Interface>();
-
- match interface_result {
- Ok(interface) => Ok(interface),
- Err(_) => Err(Report::new(DIContainerError).attach_printable(
- format!("Unable to cast binding for {}", interface_name),
- )),
- }
- }
- _ => Err(Report::new(DIContainerError).attach_printable(format!(
- "Binding for {} is not injectable",
- interface_name
- ))),
+ if let Providable::Transient(binding_transient) = binding_providable {
+ return binding_transient
+ .cast::<Interface>()
+ .map_err(unable_to_cast_binding::<Interface>);
}
+
+ Err(report!(DIContainerError).attach_printable(format!(
+ "Binding for interface '{}' is not transient",
+ type_name::<Interface>()
+ )))
}
/// Returns the singleton instance bound with `Interface`.
@@ -229,39 +211,18 @@ impl DIContainer
where
Interface: 'static + ?Sized,
{
- let interface_typeid = TypeId::of::<Interface>();
-
- let interface_name = type_name::<Interface>();
+ let binding_providable = self.get_binding_providable::<Interface>()?;
- let binding = self.bindings.get(&interface_typeid).ok_or_else(|| {
- Report::new(DIContainerError)
- .attach_printable(format!("No binding exists for {}", interface_name))
- })?;
-
- let binding_providable = binding
- .provide(self)
- .change_context(DIContainerError)
- .attach_printable(format!(
- "Failed to resolve binding for interface {}",
- interface_name
- ))?;
-
- match binding_providable {
- Providable::Singleton(binding_singleton) => {
- let interface_result = binding_singleton.cast::<Interface>();
-
- match interface_result {
- Ok(interface) => Ok(interface),
- Err(_) => Err(Report::new(DIContainerError).attach_printable(
- format!("Unable to cast binding for {}", interface_name),
- )),
- }
- }
- _ => Err(Report::new(DIContainerError).attach_printable(format!(
- "Binding for {} is not a singleton",
- interface_name
- ))),
+ if let Providable::Singleton(binding_singleton) = binding_providable {
+ return binding_singleton
+ .cast::<Interface>()
+ .map_err(unable_to_cast_binding::<Interface>);
}
+
+ Err(report!(DIContainerError).attach_printable(format!(
+ "Binding for interface '{}' is not a singleton",
+ type_name::<Interface>()
+ )))
}
/// Returns the factory bound with factory type `Interface`.
@@ -279,39 +240,34 @@ impl DIContainer
where
Interface: 'static + ?Sized,
{
- let interface_typeid = TypeId::of::<Interface>();
+ let binding_providable = self.get_binding_providable::<Interface>()?;
- let interface_name = type_name::<Interface>();
+ if let Providable::Factory(binding_factory) = binding_providable {
+ return binding_factory
+ .cast::<Interface>()
+ .map_err(unable_to_cast_binding::<Interface>);
+ }
- let binding = self.bindings.get(&interface_typeid).ok_or_else(|| {
- Report::new(DIContainerError)
- .attach_printable(format!("No binding exists for {}", interface_name))
- })?;
+ Err(report!(DIContainerError).attach_printable(format!(
+ "Binding for interface '{}' is not a factory",
+ type_name::<Interface>()
+ )))
+ }
- let binding_providable = binding
+ fn get_binding_providable<Interface>(
+ &self,
+ ) -> error_stack::Result<Providable, DIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ self.bindings
+ .get::<Interface>()?
.provide(self)
.change_context(DIContainerError)
.attach_printable(format!(
- "Failed to resolve binding for interface {}",
- interface_name
- ))?;
-
- match binding_providable {
- Providable::Factory(binding_factory) => {
- let factory_result = binding_factory.cast::<Interface>();
-
- match factory_result {
- Ok(factory) => Ok(factory),
- Err(_) => Err(Report::new(DIContainerError).attach_printable(
- format!("Unable to cast binding for {}", interface_name),
- )),
- }
- }
- _ => Err(Report::new(DIContainerError).attach_printable(format!(
- "Binding for {} is not a factory",
- interface_name
- ))),
- }
+ "Failed to resolve binding for interface '{}'",
+ type_name::<Interface>()
+ ))
}
}
@@ -332,6 +288,7 @@ mod tests
use super::*;
use crate::errors::injectable::ResolveError;
+ use crate::provider::IProvider;
use crate::ptr::TransientPtr;
#[test]
@@ -376,11 +333,11 @@ mod tests
let mut di_container: DIContainer = DIContainer::new();
- assert_eq!(di_container.bindings.len(), 0);
+ assert_eq!(di_container.bindings.count(), 0);
di_container.bind::<dyn IUserManager>().to::<UserManager>();
- assert_eq!(di_container.bindings.len(), 1);
+ assert_eq!(di_container.bindings.count(), 1);
}
#[test]
@@ -425,13 +382,13 @@ mod tests
let mut di_container: DIContainer = DIContainer::new();
- assert_eq!(di_container.bindings.len(), 0);
+ assert_eq!(di_container.bindings.count(), 0);
di_container
.bind::<dyn IUserManager>()
.to_singleton::<UserManager>()?;
- assert_eq!(di_container.bindings.len(), 1);
+ assert_eq!(di_container.bindings.count(), 1);
Ok(())
}
@@ -475,7 +432,7 @@ mod tests
let mut di_container: DIContainer = DIContainer::new();
- assert_eq!(di_container.bindings.len(), 0);
+ assert_eq!(di_container.bindings.count(), 0);
di_container.bind::<IUserManagerFactory>().to_factory(&|| {
let user_manager: TransientPtr<dyn IUserManager> =
@@ -484,7 +441,7 @@ mod tests
user_manager
});
- assert_eq!(di_container.bindings.len(), 1);
+ assert_eq!(di_container.bindings.count(), 1);
}
#[test]
@@ -546,7 +503,7 @@ mod tests
di_container
.bindings
- .insert(TypeId::of::<dyn IUserManager>(), Rc::new(mock_provider));
+ .set::<dyn IUserManager>(Box::new(mock_provider));
di_container.get::<dyn IUserManager>()?;
@@ -635,7 +592,7 @@ mod tests
di_container
.bindings
- .insert(TypeId::of::<dyn INumber>(), Rc::new(mock_provider));
+ .set::<dyn INumber>(Box::new(mock_provider));
let first_number_rc = di_container.get_singleton::<dyn INumber>()?;
@@ -723,7 +680,7 @@ mod tests
di_container
.bindings
- .insert(TypeId::of::<IUserManagerFactory>(), Rc::new(mock_provider));
+ .set::<IUserManagerFactory>(Box::new(mock_provider));
di_container.get_factory::<IUserManagerFactory>()?;