diff options
author | HampusM <hampus@hampusmat.com> | 2022-08-02 14:31:31 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-08-02 14:31:31 +0200 |
commit | 826592eac2601e9fcd5aabb17482b4816ed7ab88 (patch) | |
tree | 3c3ba5436ca76e0738e4cc8deefef3025fd5d4bd /src | |
parent | 163cd3cedd398f5676edbcb3249dd958d3e97aca (diff) |
feat: add detection and prevention of circular dependencies
Diffstat (limited to 'src')
-rw-r--r-- | src/di_container.rs | 88 | ||||
-rw-r--r-- | src/interfaces/injectable.rs | 1 | ||||
-rw-r--r-- | src/provider.rs | 7 |
3 files changed, 65 insertions, 31 deletions
diff --git a/src/di_container.rs b/src/di_container.rs index 9509bd8..22ae3ba 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -113,7 +113,7 @@ where Implementation: Injectable, { let singleton: SingletonPtr<Implementation> = SingletonPtr::from( - Implementation::resolve(self.di_container) + Implementation::resolve(self.di_container, Vec::new()) .change_context(BindingBuilderError)?, ); @@ -186,18 +186,7 @@ impl DIContainer where Interface: 'static + ?Sized, { - let binding_providable = self.get_binding_providable::<Interface>()?; - - if let Providable::Transient(binding_transient) = binding_providable { - return binding_transient - .cast::<Interface>() - .map_err(unable_to_cast_binding::<Interface>); - } - - Err(report!(DIContainerError).attach_printable(format!( - "Binding for interface '{}' is not transient", - type_name::<Interface>() - ))) + self.get_with_history::<Interface>(Vec::new()) } /// Returns the singleton instance bound with `Interface`. @@ -214,18 +203,7 @@ impl DIContainer where Interface: 'static + ?Sized, { - let binding_providable = self.get_binding_providable::<Interface>()?; - - if let Providable::Singleton(binding_singleton) = binding_providable { - return binding_singleton - .cast::<Interface>() - .map_err(unable_to_cast_binding::<Interface>); - } - - Err(report!(DIContainerError).attach_printable(format!( - "Binding for interface '{}' is not a singleton", - type_name::<Interface>() - ))) + self.get_singleton_with_history(Vec::new()) } /// Returns the factory bound with factory type `Interface`. @@ -245,7 +223,7 @@ impl DIContainer where Interface: 'static + ?Sized, { - let binding_providable = self.get_binding_providable::<Interface>()?; + let binding_providable = self.get_binding_providable::<Interface>(Vec::new())?; if let Providable::Factory(binding_factory) = binding_providable { return binding_factory @@ -259,15 +237,62 @@ impl DIContainer ))) } + #[doc(hidden)] + pub fn get_with_history<Interface>( + &self, + dependency_history: Vec<&'static str>, + ) -> error_stack::Result<TransientPtr<Interface>, DIContainerError> + where + Interface: 'static + ?Sized, + { + let binding_providable = + self.get_binding_providable::<Interface>(dependency_history)?; + + if let Providable::Transient(binding_transient) = binding_providable { + return binding_transient + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); + } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not transient", + type_name::<Interface>() + ))) + } + + #[doc(hidden)] + pub fn get_singleton_with_history<Interface>( + &self, + dependency_history: Vec<&'static str>, + ) -> error_stack::Result<SingletonPtr<Interface>, DIContainerError> + where + Interface: 'static + ?Sized, + { + let binding_providable = + self.get_binding_providable::<Interface>(dependency_history)?; + + if let Providable::Singleton(binding_singleton) = binding_providable { + return binding_singleton + .cast::<Interface>() + .map_err(unable_to_cast_binding::<Interface>); + } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a singleton", + type_name::<Interface>() + ))) + } + fn get_binding_providable<Interface>( &self, + dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, DIContainerError> where Interface: 'static + ?Sized, { self.bindings .get::<Interface>()? - .provide(self) + .provide(self, dependency_history) .change_context(DIContainerError) .attach_printable(format!( "Failed to resolve binding for interface '{}'", @@ -325,6 +350,7 @@ mod tests { fn resolve( _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result< TransientPtr<Self>, crate::errors::injectable::ResolveError, @@ -374,6 +400,7 @@ mod tests { fn resolve( _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result< TransientPtr<Self>, crate::errors::injectable::ResolveError, @@ -494,6 +521,7 @@ mod tests fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, ResolveError>; } } @@ -502,7 +530,7 @@ mod tests let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_| { + mock_provider.expect_provide().returning(|_, _| { Ok(Providable::Transient(TransientPtr::new(UserManager::new()))) }); @@ -579,6 +607,7 @@ mod tests fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, ResolveError>; } } @@ -593,7 +622,7 @@ mod tests mock_provider .expect_provide() - .returning_st(move |_| Ok(Providable::Singleton(singleton.clone()))); + .returning_st(move |_, _| Ok(Providable::Singleton(singleton.clone()))); di_container .bindings @@ -664,6 +693,7 @@ mod tests fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, ResolveError>; } } diff --git a/src/interfaces/injectable.rs b/src/interfaces/injectable.rs index 31cd21b..e6e4ced 100644 --- a/src/interfaces/injectable.rs +++ b/src/interfaces/injectable.rs @@ -13,6 +13,7 @@ pub trait Injectable: CastFrom /// Will return `Err` if resolving the dependencies fails. fn resolve( di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result<TransientPtr<Self>, ResolveError> where Self: Sized; diff --git a/src/provider.rs b/src/provider.rs index 2e832f8..e12a12a 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -7,8 +7,6 @@ use crate::interfaces::injectable::Injectable; use crate::ptr::{FactoryPtr, SingletonPtr, TransientPtr}; use crate::DIContainer; -extern crate error_stack; - pub enum Providable { Transient(TransientPtr<dyn Injectable>), @@ -22,6 +20,7 @@ pub trait IProvider fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, ResolveError>; } @@ -51,10 +50,12 @@ where fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, ResolveError> { Ok(Providable::Transient(InjectableType::resolve( di_container, + dependency_history, )?)) } } @@ -83,6 +84,7 @@ where fn provide( &self, _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, ResolveError> { Ok(Providable::Singleton(self.singleton.clone())) @@ -110,6 +112,7 @@ impl IProvider for FactoryProvider fn provide( &self, _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result<Providable, ResolveError> { Ok(Providable::Factory(self.factory.clone())) |