aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-07-27 18:05:34 +0200
committerHampusM <hampus@hampusmat.com>2022-07-31 12:17:51 +0200
commit3388f857b32cf1893d7b54582c8fd16e4965550b (patch)
tree2d78d6ca563b071cc2d6e4b3e80c05d93e575737 /src
parent3fbf26181f1b4b9e594debb103fd347bd93240ea (diff)
feat: implement binding singletons
Diffstat (limited to 'src')
-rw-r--r--src/di_container.rs272
-rw-r--r--src/errors/di_container.rs13
-rw-r--r--src/provider.rs43
-rw-r--r--src/ptr.rs2
4 files changed, 297 insertions, 33 deletions
diff --git a/src/di_container.rs b/src/di_container.rs
index d1c757d..e59d8f4 100644
--- a/src/di_container.rs
+++ b/src/di_container.rs
@@ -7,11 +7,11 @@ use error_stack::{Report, ResultExt};
#[cfg(feature = "factory")]
use crate::castable_factory::CastableFactory;
-use crate::errors::di_container::DIContainerError;
+use crate::errors::di_container::{BindingBuilderError, DIContainerError};
use crate::interfaces::injectable::Injectable;
-use crate::libs::intertrait::cast::CastBox;
-use crate::provider::{IProvider, InjectableTypeProvider, Providable};
-use crate::ptr::TransientPtr;
+use crate::libs::intertrait::cast::{CastBox, CastRc};
+use crate::provider::{IProvider, Providable, SingletonProvider, TransientTypeProvider};
+use crate::ptr::{SingletonPtr, TransientPtr};
/// Binding builder for type `Interface` inside a [`DIContainer`].
pub struct BindingBuilder<'di_container_lt, Interface>
@@ -44,10 +44,35 @@ where
self.di_container.bindings.insert(
interface_typeid,
- Rc::new(InjectableTypeProvider::<Implementation>::new()),
+ Rc::new(TransientTypeProvider::<Implementation>::new()),
);
}
+ /// Creates a binding of type `Interface` to a new singleton of type `Implementation`
+ /// inside of the associated [`DIContainer`].
+ ///
+ /// # Errors
+ /// Will return Err if creating the singleton fails.
+ pub fn to_singleton<Implementation>(
+ &mut self,
+ ) -> error_stack::Result<(), BindingBuilderError>
+ where
+ Implementation: Injectable,
+ {
+ let interface_typeid = TypeId::of::<Interface>();
+
+ let singleton: SingletonPtr<Implementation> = SingletonPtr::from(
+ Implementation::resolve(self.di_container)
+ .change_context(BindingBuilderError)?,
+ );
+
+ self.di_container
+ .bindings
+ .insert(interface_typeid, Rc::new(SingletonProvider::new(singleton)));
+
+ Ok(())
+ }
+
/// Creates a binding of factory type `Interface` to a factory inside of the
/// associated [`DIContainer`].
#[cfg(feature = "factory")]
@@ -173,21 +198,69 @@ impl DIContainer
))?;
match binding_providable {
- Providable::Injectable(binding_injectable) => {
- let interface_box_result = binding_injectable.cast::<Interface>();
+ Providable::Transient(binding_injectable) => {
+ let interface_result = binding_injectable.cast::<Interface>();
- match interface_box_result {
- Ok(interface_box) => Ok(interface_box),
+ match interface_result {
+ Ok(interface) => Ok(interface),
Err(_) => Err(Report::new(DIContainerError).attach_printable(
format!("Unable to cast binding for {}", interface_name),
)),
}
}
- Providable::Factory(_) => Err(Report::new(DIContainerError)
- .attach_printable(format!(
- "Binding for {} is not injectable",
- interface_name
- ))),
+ _ => Err(Report::new(DIContainerError).attach_printable(format!(
+ "Binding for {} is not injectable",
+ interface_name
+ ))),
+ }
+ }
+
+ /// Returns the singleton instance bound with `Interface`.
+ ///
+ /// # Errors
+ /// Will return `Err` if:
+ /// - No binding for `Interface` exists
+ /// - Resolving the binding for `Interface` fails
+ /// - Casting the binding for `Interface` fails
+ /// - The binding for `Interface` is not a singleton
+ pub fn get_singleton<Interface>(
+ &self,
+ ) -> error_stack::Result<SingletonPtr<Interface>, DIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ let interface_typeid = TypeId::of::<Interface>();
+
+ let interface_name = type_name::<Interface>();
+
+ let binding = self.bindings.get(&interface_typeid).ok_or_else(|| {
+ Report::new(DIContainerError)
+ .attach_printable(format!("No binding exists for {}", interface_name))
+ })?;
+
+ let binding_providable = binding
+ .provide(self)
+ .change_context(DIContainerError)
+ .attach_printable(format!(
+ "Failed to resolve binding for interface {}",
+ interface_name
+ ))?;
+
+ match binding_providable {
+ Providable::Singleton(binding_singleton) => {
+ let interface_result = binding_singleton.cast::<Interface>();
+
+ match interface_result {
+ Ok(interface) => Ok(interface),
+ Err(_) => Err(Report::new(DIContainerError).attach_printable(
+ format!("Unable to cast binding for {}", interface_name),
+ )),
+ }
+ }
+ _ => Err(Report::new(DIContainerError).attach_printable(format!(
+ "Binding for {} is not a singleton",
+ interface_name
+ ))),
}
}
@@ -225,22 +298,19 @@ impl DIContainer
match binding_providable {
Providable::Factory(binding_factory) => {
- use crate::libs::intertrait::cast::CastRc;
+ let factory_result = binding_factory.cast::<Interface>();
- let factory_box_result = binding_factory.cast::<Interface>();
-
- match factory_box_result {
- Ok(interface_box) => Ok(interface_box),
+ match factory_result {
+ Ok(factory) => Ok(factory),
Err(_) => Err(Report::new(DIContainerError).attach_printable(
format!("Unable to cast binding for {}", interface_name),
)),
}
}
- Providable::Injectable(_) => Err(Report::new(DIContainerError)
- .attach_printable(format!(
- "Binding for {} is not a factory",
- interface_name
- ))),
+ _ => Err(Report::new(DIContainerError).attach_printable(format!(
+ "Binding for {} is not a factory",
+ interface_name
+ ))),
}
}
}
@@ -256,6 +326,8 @@ impl Default for DIContainer
#[cfg(test)]
mod tests
{
+ use std::fmt::Debug;
+
use mockall::mock;
use super::*;
@@ -312,6 +384,59 @@ mod tests
}
#[test]
+ fn can_bind_to_singleton() -> error_stack::Result<(), BindingBuilderError>
+ {
+ trait IUserManager
+ {
+ fn add_user(&self, user_id: i128);
+
+ fn remove_user(&self, user_id: i128);
+ }
+
+ struct UserManager {}
+
+ impl IUserManager for UserManager
+ {
+ fn add_user(&self, _user_id: i128)
+ {
+ // ...
+ }
+
+ fn remove_user(&self, _user_id: i128)
+ {
+ // ...
+ }
+ }
+
+ impl Injectable for UserManager
+ {
+ fn resolve(
+ _di_container: &DIContainer,
+ ) -> error_stack::Result<
+ TransientPtr<Self>,
+ crate::errors::injectable::ResolveError,
+ >
+ where
+ Self: Sized,
+ {
+ Ok(TransientPtr::new(Self {}))
+ }
+ }
+
+ let mut di_container: DIContainer = DIContainer::new();
+
+ assert_eq!(di_container.bindings.len(), 0);
+
+ di_container
+ .bind::<dyn IUserManager>()
+ .to_singleton::<UserManager>()?;
+
+ assert_eq!(di_container.bindings.len(), 1);
+
+ Ok(())
+ }
+
+ #[test]
#[cfg(feature = "factory")]
fn can_bind_to_factory()
{
@@ -416,9 +541,7 @@ mod tests
let mut mock_provider = MockProvider::new();
mock_provider.expect_provide().returning(|_| {
- Ok(Providable::Injectable(
- TransientPtr::new(UserManager::new()),
- ))
+ Ok(Providable::Transient(TransientPtr::new(UserManager::new())))
});
di_container
@@ -431,6 +554,101 @@ mod tests
}
#[test]
+ fn can_get_singleton() -> error_stack::Result<(), DIContainerError>
+ {
+ trait INumber
+ {
+ fn get(&self) -> i32;
+
+ fn set(&mut self, number: i32);
+ }
+
+ impl PartialEq for dyn INumber
+ {
+ fn eq(&self, other: &Self) -> bool
+ {
+ self.get() == other.get()
+ }
+ }
+
+ impl Debug for dyn INumber
+ {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ f.write_str(format!("{}", self.get()).as_str())
+ }
+ }
+
+ struct Number
+ {
+ num: i32,
+ }
+
+ use crate as syrette;
+ use crate::injectable;
+
+ #[injectable(INumber)]
+ impl Number
+ {
+ fn new() -> Self
+ {
+ Self { num: 0 }
+ }
+ }
+
+ impl INumber for Number
+ {
+ fn get(&self) -> i32
+ {
+ self.num
+ }
+
+ fn set(&mut self, number: i32)
+ {
+ self.num = number;
+ }
+ }
+
+ mock! {
+ Provider {}
+
+ impl IProvider for Provider
+ {
+ fn provide(
+ &self,
+ di_container: &DIContainer,
+ ) -> error_stack::Result<Providable, ResolveError>;
+ }
+ }
+
+ let mut di_container: DIContainer = DIContainer::new();
+
+ let mut mock_provider = MockProvider::new();
+
+ let mut singleton = SingletonPtr::new(Number::new());
+
+ SingletonPtr::get_mut(&mut singleton).unwrap().set(2820);
+
+ mock_provider
+ .expect_provide()
+ .returning_st(move |_| Ok(Providable::Singleton(singleton.clone())));
+
+ di_container
+ .bindings
+ .insert(TypeId::of::<dyn INumber>(), Rc::new(mock_provider));
+
+ let first_number_rc = di_container.get_singleton::<dyn INumber>()?;
+
+ assert_eq!(first_number_rc.get(), 2820);
+
+ let second_number_rc = di_container.get_singleton::<dyn INumber>()?;
+
+ assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref());
+
+ Ok(())
+ }
+
+ #[test]
#[cfg(feature = "factory")]
fn can_get_factory() -> error_stack::Result<(), DIContainerError>
{
diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs
index f2b8add..3b8c717 100644
--- a/src/errors/di_container.rs
+++ b/src/errors/di_container.rs
@@ -15,3 +15,16 @@ impl Display for DIContainerError
}
impl Context for DIContainerError {}
+
+#[derive(Debug)]
+pub struct BindingBuilderError;
+
+impl Display for BindingBuilderError
+{
+ fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result
+ {
+ fmt.write_str("A binding builder error has occurred")
+ }
+}
+
+impl Context for BindingBuilderError {}
diff --git a/src/provider.rs b/src/provider.rs
index eb5b5d8..2e832f8 100644
--- a/src/provider.rs
+++ b/src/provider.rs
@@ -4,14 +4,15 @@ use std::marker::PhantomData;
use crate::errors::injectable::ResolveError;
use crate::interfaces::any_factory::AnyFactory;
use crate::interfaces::injectable::Injectable;
-use crate::ptr::{FactoryPtr, TransientPtr};
+use crate::ptr::{FactoryPtr, SingletonPtr, TransientPtr};
use crate::DIContainer;
extern crate error_stack;
pub enum Providable
{
- Injectable(TransientPtr<dyn Injectable>),
+ Transient(TransientPtr<dyn Injectable>),
+ Singleton(SingletonPtr<dyn Injectable>),
#[allow(dead_code)]
Factory(FactoryPtr<dyn AnyFactory>),
}
@@ -24,14 +25,14 @@ pub trait IProvider
) -> error_stack::Result<Providable, ResolveError>;
}
-pub struct InjectableTypeProvider<InjectableType>
+pub struct TransientTypeProvider<InjectableType>
where
InjectableType: Injectable,
{
injectable_phantom: PhantomData<InjectableType>,
}
-impl<InjectableType> InjectableTypeProvider<InjectableType>
+impl<InjectableType> TransientTypeProvider<InjectableType>
where
InjectableType: Injectable,
{
@@ -43,7 +44,7 @@ where
}
}
-impl<InjectableType> IProvider for InjectableTypeProvider<InjectableType>
+impl<InjectableType> IProvider for TransientTypeProvider<InjectableType>
where
InjectableType: Injectable,
{
@@ -52,12 +53,42 @@ where
di_container: &DIContainer,
) -> error_stack::Result<Providable, ResolveError>
{
- Ok(Providable::Injectable(InjectableType::resolve(
+ Ok(Providable::Transient(InjectableType::resolve(
di_container,
)?))
}
}
+pub struct SingletonProvider<InjectableType>
+where
+ InjectableType: Injectable,
+{
+ singleton: SingletonPtr<InjectableType>,
+}
+
+impl<InjectableType> SingletonProvider<InjectableType>
+where
+ InjectableType: Injectable,
+{
+ pub fn new(singleton: SingletonPtr<InjectableType>) -> Self
+ {
+ Self { singleton }
+ }
+}
+
+impl<InjectableType> IProvider for SingletonProvider<InjectableType>
+where
+ InjectableType: Injectable,
+{
+ fn provide(
+ &self,
+ _di_container: &DIContainer,
+ ) -> error_stack::Result<Providable, ResolveError>
+ {
+ Ok(Providable::Singleton(self.singleton.clone()))
+ }
+}
+
#[cfg(feature = "factory")]
pub struct FactoryProvider
{
diff --git a/src/ptr.rs b/src/ptr.rs
index 62445fd..00c74f4 100644
--- a/src/ptr.rs
+++ b/src/ptr.rs
@@ -3,4 +3,6 @@ use std::rc::Rc;
pub type TransientPtr<Interface> = Box<Interface>;
+pub type SingletonPtr<Interface> = Rc<Interface>;
+
pub type FactoryPtr<FactoryInterface> = Rc<FactoryInterface>;