aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-08-02 14:31:31 +0200
committerHampusM <hampus@hampusmat.com>2022-08-02 14:31:31 +0200
commit826592eac2601e9fcd5aabb17482b4816ed7ab88 (patch)
tree3c3ba5436ca76e0738e4cc8deefef3025fd5d4bd
parent163cd3cedd398f5676edbcb3249dd958d3e97aca (diff)
feat: add detection and prevention of circular dependencies
-rw-r--r--README.md1
-rw-r--r--macros/src/injectable_impl.rs97
-rw-r--r--src/di_container.rs88
-rw-r--r--src/interfaces/injectable.rs1
-rw-r--r--src/provider.rs7
5 files changed, 127 insertions, 67 deletions
diff --git a/README.md b/README.md
index 4c9eec0..4febe5b 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ From the [syrette Wikipedia article](https://en.wikipedia.org/wiki/Syrette).
- Enforces the use of interface traits
- Supports generic implementations & generic interface traits
- Binding singletons
+- Detection and prevention of circular dependencies
## Optional features
- `factory`. Binding factories (Rust nightly required)
diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs
index 227a8c6..b24749c 100644
--- a/macros/src/injectable_impl.rs
+++ b/macros/src/injectable_impl.rs
@@ -49,7 +49,14 @@ impl InjectableImpl
let di_container_var: Ident = parse_str(DI_CONTAINER_VAR_NAME).unwrap();
- let get_dependencies = Self::_create_get_dependencies(dependency_types);
+ let get_dep_method_names = Self::_create_get_dep_method_names(dependency_types);
+
+ let get_dependencies = get_dep_method_names.iter().map(|get_dep_method_name| {
+ parse_str::<ExprMethodCall>(
+ format!("{}(dependency_history.clone())", get_dep_method_name).as_str(),
+ )
+ .unwrap()
+ });
let maybe_doc_hidden = if no_doc_hidden {
quote! {}
@@ -65,20 +72,52 @@ impl InjectableImpl
#maybe_doc_hidden
impl #generics syrette::interfaces::injectable::Injectable for #self_type {
fn resolve(
- #di_container_var: &syrette::DIContainer
+ #di_container_var: &syrette::DIContainer,
+ mut dependency_history: Vec<&'static str>,
) -> syrette::libs::error_stack::Result<
syrette::ptr::TransientPtr<Self>,
syrette::errors::injectable::ResolveError>
{
- use syrette::libs::error_stack::ResultExt;
+ use syrette::libs::error_stack::{ResultExt, report};
+ use syrette::errors::injectable::ResolveError;
+
+ let self_type_name = std::any::type_name::<#self_type>();
+
+ if dependency_history.contains(&self_type_name) {
+ dependency_history.push(self_type_name);
+
+ let dependency_trace = dependency_history
+ .iter()
+ .map(|dep| {
+ if dep == &self_type_name {
+ format!("\x1b[1m{}\x1b[22m", dep)
+ } else {
+ dep.to_string()
+ }
+ })
+ .collect::<Vec<_>>()
+ .join(" -> ");
+
+ return Err(
+ report!(ResolveError)
+ .attach_printable(
+ format!(
+ "Detected circular dependencies. {}",
+ dependency_trace.clone(),
+ )
+ )
+ );
+ }
+
+ dependency_history.push(self_type_name);
return Ok(syrette::ptr::TransientPtr::new(Self::new(
#(#get_dependencies
- .change_context(syrette::errors::injectable::ResolveError)
+ .change_context(ResolveError)
.attach_printable(
format!(
"Unable to resolve a dependency of {}",
- std::any::type_name::<#self_type>()
+ self_type_name
)
)?
),*
@@ -88,48 +127,34 @@ impl InjectableImpl
}
}
- fn _create_get_dependencies(
- dependency_types: &[DependencyType],
- ) -> Vec<ExprMethodCall>
+ fn _create_get_dep_method_names(dependency_types: &[DependencyType]) -> Vec<String>
{
dependency_types
.iter()
.filter_map(|dep_type| match &dep_type.interface {
- Type::TraitObject(dep_type_trait) => Some(
- parse_str(
- format!(
- "{}.get{}::<{}>()",
- DI_CONTAINER_VAR_NAME,
- if dep_type.ptr == "SingletonPtr" {
- "_singleton"
- } else {
- ""
- },
- dep_type_trait.to_token_stream()
- )
- .as_str(),
- )
- .unwrap(),
- ),
+ Type::TraitObject(dep_type_trait) => Some(format!(
+ "{}.get_{}::<{}>",
+ DI_CONTAINER_VAR_NAME,
+ if dep_type.ptr == "SingletonPtr" {
+ "singleton_with_history"
+ } else {
+ "with_history"
+ },
+ dep_type_trait.to_token_stream()
+ )),
Type::Path(dep_type_path) => {
let dep_type_path_str = Self::_path_to_string(&dep_type_path.path);
let get_method_name = if dep_type_path_str.ends_with("Factory") {
- "get_factory"
+ "factory"
} else {
- "get"
+ "with_history"
};
- Some(
- parse_str(
- format!(
- "{}.{}::<{}>()",
- DI_CONTAINER_VAR_NAME, get_method_name, dep_type_path_str
- )
- .as_str(),
- )
- .unwrap(),
- )
+ Some(format!(
+ "get_{}.{}::<{}>",
+ DI_CONTAINER_VAR_NAME, get_method_name, dep_type_path_str
+ ))
}
&_ => None,
})
diff --git a/src/di_container.rs b/src/di_container.rs
index 9509bd8..22ae3ba 100644
--- a/src/di_container.rs
+++ b/src/di_container.rs
@@ -113,7 +113,7 @@ where
Implementation: Injectable,
{
let singleton: SingletonPtr<Implementation> = SingletonPtr::from(
- Implementation::resolve(self.di_container)
+ Implementation::resolve(self.di_container, Vec::new())
.change_context(BindingBuilderError)?,
);
@@ -186,18 +186,7 @@ impl DIContainer
where
Interface: 'static + ?Sized,
{
- let binding_providable = self.get_binding_providable::<Interface>()?;
-
- if let Providable::Transient(binding_transient) = binding_providable {
- return binding_transient
- .cast::<Interface>()
- .map_err(unable_to_cast_binding::<Interface>);
- }
-
- Err(report!(DIContainerError).attach_printable(format!(
- "Binding for interface '{}' is not transient",
- type_name::<Interface>()
- )))
+ self.get_with_history::<Interface>(Vec::new())
}
/// Returns the singleton instance bound with `Interface`.
@@ -214,18 +203,7 @@ impl DIContainer
where
Interface: 'static + ?Sized,
{
- let binding_providable = self.get_binding_providable::<Interface>()?;
-
- if let Providable::Singleton(binding_singleton) = binding_providable {
- return binding_singleton
- .cast::<Interface>()
- .map_err(unable_to_cast_binding::<Interface>);
- }
-
- Err(report!(DIContainerError).attach_printable(format!(
- "Binding for interface '{}' is not a singleton",
- type_name::<Interface>()
- )))
+ self.get_singleton_with_history(Vec::new())
}
/// Returns the factory bound with factory type `Interface`.
@@ -245,7 +223,7 @@ impl DIContainer
where
Interface: 'static + ?Sized,
{
- let binding_providable = self.get_binding_providable::<Interface>()?;
+ let binding_providable = self.get_binding_providable::<Interface>(Vec::new())?;
if let Providable::Factory(binding_factory) = binding_providable {
return binding_factory
@@ -259,15 +237,62 @@ impl DIContainer
)))
}
+ #[doc(hidden)]
+ pub fn get_with_history<Interface>(
+ &self,
+ dependency_history: Vec<&'static str>,
+ ) -> error_stack::Result<TransientPtr<Interface>, DIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ let binding_providable =
+ self.get_binding_providable::<Interface>(dependency_history)?;
+
+ if let Providable::Transient(binding_transient) = binding_providable {
+ return binding_transient
+ .cast::<Interface>()
+ .map_err(unable_to_cast_binding::<Interface>);
+ }
+
+ Err(report!(DIContainerError).attach_printable(format!(
+ "Binding for interface '{}' is not transient",
+ type_name::<Interface>()
+ )))
+ }
+
+ #[doc(hidden)]
+ pub fn get_singleton_with_history<Interface>(
+ &self,
+ dependency_history: Vec<&'static str>,
+ ) -> error_stack::Result<SingletonPtr<Interface>, DIContainerError>
+ where
+ Interface: 'static + ?Sized,
+ {
+ let binding_providable =
+ self.get_binding_providable::<Interface>(dependency_history)?;
+
+ if let Providable::Singleton(binding_singleton) = binding_providable {
+ return binding_singleton
+ .cast::<Interface>()
+ .map_err(unable_to_cast_binding::<Interface>);
+ }
+
+ Err(report!(DIContainerError).attach_printable(format!(
+ "Binding for interface '{}' is not a singleton",
+ type_name::<Interface>()
+ )))
+ }
+
fn get_binding_providable<Interface>(
&self,
+ dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, DIContainerError>
where
Interface: 'static + ?Sized,
{
self.bindings
.get::<Interface>()?
- .provide(self)
+ .provide(self, dependency_history)
.change_context(DIContainerError)
.attach_printable(format!(
"Failed to resolve binding for interface '{}'",
@@ -325,6 +350,7 @@ mod tests
{
fn resolve(
_di_container: &DIContainer,
+ _dependency_history: Vec<&'static str>,
) -> error_stack::Result<
TransientPtr<Self>,
crate::errors::injectable::ResolveError,
@@ -374,6 +400,7 @@ mod tests
{
fn resolve(
_di_container: &DIContainer,
+ _dependency_history: Vec<&'static str>,
) -> error_stack::Result<
TransientPtr<Self>,
crate::errors::injectable::ResolveError,
@@ -494,6 +521,7 @@ mod tests
fn provide(
&self,
di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, ResolveError>;
}
}
@@ -502,7 +530,7 @@ mod tests
let mut mock_provider = MockProvider::new();
- mock_provider.expect_provide().returning(|_| {
+ mock_provider.expect_provide().returning(|_, _| {
Ok(Providable::Transient(TransientPtr::new(UserManager::new())))
});
@@ -579,6 +607,7 @@ mod tests
fn provide(
&self,
di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, ResolveError>;
}
}
@@ -593,7 +622,7 @@ mod tests
mock_provider
.expect_provide()
- .returning_st(move |_| Ok(Providable::Singleton(singleton.clone())));
+ .returning_st(move |_, _| Ok(Providable::Singleton(singleton.clone())));
di_container
.bindings
@@ -664,6 +693,7 @@ mod tests
fn provide(
&self,
di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, ResolveError>;
}
}
diff --git a/src/interfaces/injectable.rs b/src/interfaces/injectable.rs
index 31cd21b..e6e4ced 100644
--- a/src/interfaces/injectable.rs
+++ b/src/interfaces/injectable.rs
@@ -13,6 +13,7 @@ pub trait Injectable: CastFrom
/// Will return `Err` if resolving the dependencies fails.
fn resolve(
di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
) -> error_stack::Result<TransientPtr<Self>, ResolveError>
where
Self: Sized;
diff --git a/src/provider.rs b/src/provider.rs
index 2e832f8..e12a12a 100644
--- a/src/provider.rs
+++ b/src/provider.rs
@@ -7,8 +7,6 @@ use crate::interfaces::injectable::Injectable;
use crate::ptr::{FactoryPtr, SingletonPtr, TransientPtr};
use crate::DIContainer;
-extern crate error_stack;
-
pub enum Providable
{
Transient(TransientPtr<dyn Injectable>),
@@ -22,6 +20,7 @@ pub trait IProvider
fn provide(
&self,
di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, ResolveError>;
}
@@ -51,10 +50,12 @@ where
fn provide(
&self,
di_container: &DIContainer,
+ dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, ResolveError>
{
Ok(Providable::Transient(InjectableType::resolve(
di_container,
+ dependency_history,
)?))
}
}
@@ -83,6 +84,7 @@ where
fn provide(
&self,
_di_container: &DIContainer,
+ _dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, ResolveError>
{
Ok(Providable::Singleton(self.singleton.clone()))
@@ -110,6 +112,7 @@ impl IProvider for FactoryProvider
fn provide(
&self,
_di_container: &DIContainer,
+ _dependency_history: Vec<&'static str>,
) -> error_stack::Result<Providable, ResolveError>
{
Ok(Providable::Factory(self.factory.clone()))