From 826592eac2601e9fcd5aabb17482b4816ed7ab88 Mon Sep 17 00:00:00 2001 From: HampusM Date: Tue, 2 Aug 2022 14:31:31 +0200 Subject: feat: add detection and prevention of circular dependencies --- README.md | 1 + macros/src/injectable_impl.rs | 97 +++++++++++++++++++++++++++---------------- src/di_container.rs | 88 ++++++++++++++++++++++++++------------- src/interfaces/injectable.rs | 1 + src/provider.rs | 7 +++- 5 files changed, 127 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 4c9eec0..4febe5b 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ From the [syrette Wikipedia article](https://en.wikipedia.org/wiki/Syrette). - Enforces the use of interface traits - Supports generic implementations & generic interface traits - Binding singletons +- Detection and prevention of circular dependencies ## Optional features - `factory`. Binding factories (Rust nightly required) diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 227a8c6..b24749c 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -49,7 +49,14 @@ impl InjectableImpl let di_container_var: Ident = parse_str(DI_CONTAINER_VAR_NAME).unwrap(); - let get_dependencies = Self::_create_get_dependencies(dependency_types); + let get_dep_method_names = Self::_create_get_dep_method_names(dependency_types); + + let get_dependencies = get_dep_method_names.iter().map(|get_dep_method_name| { + parse_str::( + format!("{}(dependency_history.clone())", get_dep_method_name).as_str(), + ) + .unwrap() + }); let maybe_doc_hidden = if no_doc_hidden { quote! {} @@ -65,20 +72,52 @@ impl InjectableImpl #maybe_doc_hidden impl #generics syrette::interfaces::injectable::Injectable for #self_type { fn resolve( - #di_container_var: &syrette::DIContainer + #di_container_var: &syrette::DIContainer, + mut dependency_history: Vec<&'static str>, ) -> syrette::libs::error_stack::Result< syrette::ptr::TransientPtr, syrette::errors::injectable::ResolveError> { - use syrette::libs::error_stack::ResultExt; + use syrette::libs::error_stack::{ResultExt, report}; + use syrette::errors::injectable::ResolveError; + + let self_type_name = std::any::type_name::<#self_type>(); + + if dependency_history.contains(&self_type_name) { + dependency_history.push(self_type_name); + + let dependency_trace = dependency_history + .iter() + .map(|dep| { + if dep == &self_type_name { + format!("\x1b[1m{}\x1b[22m", dep) + } else { + dep.to_string() + } + }) + .collect::>() + .join(" -> "); + + return Err( + report!(ResolveError) + .attach_printable( + format!( + "Detected circular dependencies. {}", + dependency_trace.clone(), + ) + ) + ); + } + + dependency_history.push(self_type_name); return Ok(syrette::ptr::TransientPtr::new(Self::new( #(#get_dependencies - .change_context(syrette::errors::injectable::ResolveError) + .change_context(ResolveError) .attach_printable( format!( "Unable to resolve a dependency of {}", - std::any::type_name::<#self_type>() + self_type_name ) )? ),* @@ -88,48 +127,34 @@ impl InjectableImpl } } - fn _create_get_dependencies( - dependency_types: &[DependencyType], - ) -> Vec + fn _create_get_dep_method_names(dependency_types: &[DependencyType]) -> Vec { dependency_types .iter() .filter_map(|dep_type| match &dep_type.interface { - Type::TraitObject(dep_type_trait) => Some( - parse_str( - format!( - "{}.get{}::<{}>()", - DI_CONTAINER_VAR_NAME, - if dep_type.ptr == "SingletonPtr" { - "_singleton" - } else { - "" - }, - dep_type_trait.to_token_stream() - ) - .as_str(), - ) - .unwrap(), - ), + Type::TraitObject(dep_type_trait) => Some(format!( + "{}.get_{}::<{}>", + DI_CONTAINER_VAR_NAME, + if dep_type.ptr == "SingletonPtr" { + "singleton_with_history" + } else { + "with_history" + }, + dep_type_trait.to_token_stream() + )), Type::Path(dep_type_path) => { let dep_type_path_str = Self::_path_to_string(&dep_type_path.path); let get_method_name = if dep_type_path_str.ends_with("Factory") { - "get_factory" + "factory" } else { - "get" + "with_history" }; - Some( - parse_str( - format!( - "{}.{}::<{}>()", - DI_CONTAINER_VAR_NAME, get_method_name, dep_type_path_str - ) - .as_str(), - ) - .unwrap(), - ) + Some(format!( + "get_{}.{}::<{}>", + DI_CONTAINER_VAR_NAME, get_method_name, dep_type_path_str + )) } &_ => None, }) diff --git a/src/di_container.rs b/src/di_container.rs index 9509bd8..22ae3ba 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -113,7 +113,7 @@ where Implementation: Injectable, { let singleton: SingletonPtr = SingletonPtr::from( - Implementation::resolve(self.di_container) + Implementation::resolve(self.di_container, Vec::new()) .change_context(BindingBuilderError)?, ); @@ -186,18 +186,7 @@ impl DIContainer where Interface: 'static + ?Sized, { - let binding_providable = self.get_binding_providable::()?; - - if let Providable::Transient(binding_transient) = binding_providable { - return binding_transient - .cast::() - .map_err(unable_to_cast_binding::); - } - - Err(report!(DIContainerError).attach_printable(format!( - "Binding for interface '{}' is not transient", - type_name::() - ))) + self.get_with_history::(Vec::new()) } /// Returns the singleton instance bound with `Interface`. @@ -214,18 +203,7 @@ impl DIContainer where Interface: 'static + ?Sized, { - let binding_providable = self.get_binding_providable::()?; - - if let Providable::Singleton(binding_singleton) = binding_providable { - return binding_singleton - .cast::() - .map_err(unable_to_cast_binding::); - } - - Err(report!(DIContainerError).attach_printable(format!( - "Binding for interface '{}' is not a singleton", - type_name::() - ))) + self.get_singleton_with_history(Vec::new()) } /// Returns the factory bound with factory type `Interface`. @@ -245,7 +223,7 @@ impl DIContainer where Interface: 'static + ?Sized, { - let binding_providable = self.get_binding_providable::()?; + let binding_providable = self.get_binding_providable::(Vec::new())?; if let Providable::Factory(binding_factory) = binding_providable { return binding_factory @@ -259,15 +237,62 @@ impl DIContainer ))) } + #[doc(hidden)] + pub fn get_with_history( + &self, + dependency_history: Vec<&'static str>, + ) -> error_stack::Result, DIContainerError> + where + Interface: 'static + ?Sized, + { + let binding_providable = + self.get_binding_providable::(dependency_history)?; + + if let Providable::Transient(binding_transient) = binding_providable { + return binding_transient + .cast::() + .map_err(unable_to_cast_binding::); + } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not transient", + type_name::() + ))) + } + + #[doc(hidden)] + pub fn get_singleton_with_history( + &self, + dependency_history: Vec<&'static str>, + ) -> error_stack::Result, DIContainerError> + where + Interface: 'static + ?Sized, + { + let binding_providable = + self.get_binding_providable::(dependency_history)?; + + if let Providable::Singleton(binding_singleton) = binding_providable { + return binding_singleton + .cast::() + .map_err(unable_to_cast_binding::); + } + + Err(report!(DIContainerError).attach_printable(format!( + "Binding for interface '{}' is not a singleton", + type_name::() + ))) + } + fn get_binding_providable( &self, + dependency_history: Vec<&'static str>, ) -> error_stack::Result where Interface: 'static + ?Sized, { self.bindings .get::()? - .provide(self) + .provide(self, dependency_history) .change_context(DIContainerError) .attach_printable(format!( "Failed to resolve binding for interface '{}'", @@ -325,6 +350,7 @@ mod tests { fn resolve( _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result< TransientPtr, crate::errors::injectable::ResolveError, @@ -374,6 +400,7 @@ mod tests { fn resolve( _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result< TransientPtr, crate::errors::injectable::ResolveError, @@ -494,6 +521,7 @@ mod tests fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result; } } @@ -502,7 +530,7 @@ mod tests let mut mock_provider = MockProvider::new(); - mock_provider.expect_provide().returning(|_| { + mock_provider.expect_provide().returning(|_, _| { Ok(Providable::Transient(TransientPtr::new(UserManager::new()))) }); @@ -579,6 +607,7 @@ mod tests fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result; } } @@ -593,7 +622,7 @@ mod tests mock_provider .expect_provide() - .returning_st(move |_| Ok(Providable::Singleton(singleton.clone()))); + .returning_st(move |_, _| Ok(Providable::Singleton(singleton.clone()))); di_container .bindings @@ -664,6 +693,7 @@ mod tests fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result; } } diff --git a/src/interfaces/injectable.rs b/src/interfaces/injectable.rs index 31cd21b..e6e4ced 100644 --- a/src/interfaces/injectable.rs +++ b/src/interfaces/injectable.rs @@ -13,6 +13,7 @@ pub trait Injectable: CastFrom /// Will return `Err` if resolving the dependencies fails. fn resolve( di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result, ResolveError> where Self: Sized; diff --git a/src/provider.rs b/src/provider.rs index 2e832f8..e12a12a 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -7,8 +7,6 @@ use crate::interfaces::injectable::Injectable; use crate::ptr::{FactoryPtr, SingletonPtr, TransientPtr}; use crate::DIContainer; -extern crate error_stack; - pub enum Providable { Transient(TransientPtr), @@ -22,6 +20,7 @@ pub trait IProvider fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result; } @@ -51,10 +50,12 @@ where fn provide( &self, di_container: &DIContainer, + dependency_history: Vec<&'static str>, ) -> error_stack::Result { Ok(Providable::Transient(InjectableType::resolve( di_container, + dependency_history, )?)) } } @@ -83,6 +84,7 @@ where fn provide( &self, _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result { Ok(Providable::Singleton(self.singleton.clone())) @@ -110,6 +112,7 @@ impl IProvider for FactoryProvider fn provide( &self, _di_container: &DIContainer, + _dependency_history: Vec<&'static str>, ) -> error_stack::Result { Ok(Providable::Factory(self.factory.clone())) -- cgit v1.2.3-18-g5258