aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml1
-rw-r--r--examples/basic/main.rs4
-rw-r--r--examples/factory/bootstrap.rs18
-rw-r--r--examples/factory/main.rs10
-rw-r--r--examples/generics/main.rs10
-rw-r--r--examples/unbound/main.rs4
-rw-r--r--examples/with-3rd-party/main.rs2
-rw-r--r--macros/src/injectable_impl.rs83
-rw-r--r--src/di_container.rs163
-rw-r--r--src/errors/di_container.rs20
-rw-r--r--src/errors/mod.rs1
-rw-r--r--src/errors/ptr.rs19
-rw-r--r--src/ptr.rs58
13 files changed, 204 insertions, 189 deletions
diff --git a/Cargo.toml b/Cargo.toml
index f1c6885..8d26077 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ ahash = "0.7.6"
thiserror = "1.0.32"
strum = "0.24.1"
strum_macros = "0.24.3"
+paste = "1.0.8"
[dev_dependencies]
mockall = "0.11.1"
diff --git a/examples/basic/main.rs b/examples/basic/main.rs
index 72f07c2..dbc9215 100644
--- a/examples/basic/main.rs
+++ b/examples/basic/main.rs
@@ -18,11 +18,11 @@ fn main() -> Result<(), Box<dyn Error>>
let di_container = bootstrap()?;
- let dog = di_container.get_singleton::<dyn IDog>()?;
+ let dog = di_container.get::<dyn IDog>()?.singleton()?;
dog.woof();
- let human = di_container.get::<dyn IHuman>()?;
+ let human = di_container.get::<dyn IHuman>()?.transient()?;
human.make_pets_make_sounds();
diff --git a/examples/factory/bootstrap.rs b/examples/factory/bootstrap.rs
index b752764..ad8c4d3 100644
--- a/examples/factory/bootstrap.rs
+++ b/examples/factory/bootstrap.rs
@@ -1,3 +1,5 @@
+use std::error::Error;
+
use syrette::ptr::TransientPtr;
use syrette::DIContainer;
@@ -9,24 +11,22 @@ use crate::interfaces::user_manager::IUserManager;
use crate::user::User;
use crate::user_manager::UserManager;
-pub fn bootstrap() -> DIContainer
+pub fn bootstrap() -> Result<DIContainer, Box<dyn Error>>
{
let mut di_container: DIContainer = DIContainer::new();
di_container
.bind::<dyn IUserManager>()
- .to::<UserManager>()
- .unwrap();
+ .to::<UserManager>()?;
- di_container
- .bind::<IUserFactory>()
- .to_factory(&|name, date_of_birth, password| {
+ di_container.bind::<IUserFactory>().to_factory(
+ &|name, date_of_birth, password| {
let user: TransientPtr<dyn IUser> =
TransientPtr::new(User::new(name, date_of_birth, password));
user
- })
- .unwrap();
+ },
+ )?;
- di_container
+ Ok(di_container)
}
diff --git a/examples/factory/main.rs b/examples/factory/main.rs
index bf3d43b..0f1a97b 100644
--- a/examples/factory/main.rs
+++ b/examples/factory/main.rs
@@ -7,21 +7,25 @@ mod interfaces;
mod user;
mod user_manager;
+use std::error::Error;
+
use bootstrap::bootstrap;
use crate::interfaces::user_manager::IUserManager;
-fn main()
+fn main() -> Result<(), Box<dyn Error>>
{
println!("Hello, world!");
- let di_container = bootstrap();
+ let di_container = bootstrap()?;
- let mut user_manager = di_container.get::<dyn IUserManager>().unwrap();
+ let mut user_manager = di_container.get::<dyn IUserManager>()?.transient()?;
user_manager.fill_with_users();
println!("Printing user information");
user_manager.print_users();
+
+ Ok(())
}
diff --git a/examples/generics/main.rs b/examples/generics/main.rs
index 9442641..f491aa0 100644
--- a/examples/generics/main.rs
+++ b/examples/generics/main.rs
@@ -2,18 +2,22 @@ mod bootstrap;
mod interfaces;
mod printer;
+use std::error::Error;
+
use bootstrap::bootstrap;
use interfaces::printer::IPrinter;
-fn main()
+fn main() -> Result<(), Box<dyn Error>>
{
let di_container = bootstrap();
- let string_printer = di_container.get::<dyn IPrinter<String>>().unwrap();
+ let string_printer = di_container.get::<dyn IPrinter<String>>()?.transient()?;
string_printer.print("Hello there".to_string());
- let int_printer = di_container.get::<dyn IPrinter<i32>>().unwrap();
+ let int_printer = di_container.get::<dyn IPrinter<i32>>()?.transient()?;
int_printer.print(2782028);
+
+ Ok(())
}
diff --git a/examples/unbound/main.rs b/examples/unbound/main.rs
index 031a691..e9a8feb 100644
--- a/examples/unbound/main.rs
+++ b/examples/unbound/main.rs
@@ -19,11 +19,11 @@ fn main() -> Result<(), Box<dyn Error>>
let di_container = bootstrap()?;
- let dog = di_container.get_singleton::<dyn IDog>()?;
+ let dog = di_container.get::<dyn IDog>()?.singleton()?;
dog.woof();
- let human = di_container.get::<dyn IHuman>()?;
+ let human = di_container.get::<dyn IHuman>()?.transient()?;
human.make_pets_make_sounds();
diff --git a/examples/with-3rd-party/main.rs b/examples/with-3rd-party/main.rs
index dd4c21f..e48c78f 100644
--- a/examples/with-3rd-party/main.rs
+++ b/examples/with-3rd-party/main.rs
@@ -17,7 +17,7 @@ fn main() -> Result<(), Box<dyn Error>>
let di_container = bootstrap()?;
- let ninja = di_container.get::<dyn INinja>()?;
+ let ninja = di_container.get::<dyn INinja>()?.transient()?;
ninja.throw_shuriken();
diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs
index 62bf057..d74acb3 100644
--- a/macros/src/injectable_impl.rs
+++ b/macros/src/injectable_impl.rs
@@ -103,12 +103,7 @@ impl InjectableImpl
#maybe_prevent_circular_deps
return Ok(syrette::ptr::TransientPtr::new(Self::new(
- #(#get_dep_method_calls
- .map_err(|err| InjectableError::ResolveFailed {
- reason: Box::new(err),
- affected: self_type_name
- })?
- ),*
+ #(#get_dep_method_calls),*
)));
}
}
@@ -117,53 +112,59 @@ impl InjectableImpl
fn create_get_dep_method_calls(
dependency_types: &[DependencyType],
- ) -> Vec<ExprMethodCall>
+ ) -> Vec<proc_macro2::TokenStream>
{
dependency_types
.iter()
.filter_map(|dep_type| match &dep_type.interface {
- Type::TraitObject(dep_type_trait) => parse_str::<ExprMethodCall>(
- format!(
- "{}.get_{}::<{}>({}.clone())",
- DI_CONTAINER_VAR_NAME,
- if dep_type.ptr == "SingletonPtr" {
- "singleton_with_history"
- } else {
- "with_history"
- },
- dep_type_trait.to_token_stream(),
- DEPENDENCY_HISTORY_VAR_NAME
+ Type::TraitObject(dep_type_trait) => {
+ let method_call = parse_str::<ExprMethodCall>(
+ format!(
+ "{}.get_bound::<{}>({}.clone())",
+ DI_CONTAINER_VAR_NAME,
+ dep_type_trait.to_token_stream(),
+ DEPENDENCY_HISTORY_VAR_NAME
+ )
+ .as_str(),
)
- .as_str(),
- )
- .ok(),
+ .ok()?;
+
+ Some((method_call, dep_type))
+
+ /*
+ */
+ }
Type::Path(dep_type_path) => {
let dep_type_path_str = Self::path_to_string(&dep_type_path.path);
- if dep_type_path_str.ends_with("Factory") {
- parse_str(
- format!(
- "{}.get_factory::<{}>()",
- DI_CONTAINER_VAR_NAME, dep_type_path_str
- )
- .as_str(),
- )
- .ok()
- } else {
- parse_str(
- format!(
- "{}.get_with_history::<{}>({}.clone())",
- DI_CONTAINER_VAR_NAME,
- dep_type_path_str,
- DEPENDENCY_HISTORY_VAR_NAME
- )
- .as_str(),
+ let method_call: ExprMethodCall = parse_str(
+ format!(
+ "{}.get_bound::<{}>({}.clone())",
+ DI_CONTAINER_VAR_NAME,
+ dep_type_path_str,
+ DEPENDENCY_HISTORY_VAR_NAME
)
- .ok()
- }
+ .as_str(),
+ )
+ .ok()?;
+
+ Some((method_call, dep_type))
}
&_ => None,
})
+ .map(|(method_call, dep_type)| {
+ let ptr_name = dep_type.ptr.to_string();
+
+ let to_ptr =
+ format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase());
+
+ quote! {
+ #method_call.map_err(|err| InjectableError::ResolveFailed {
+ reason: Box::new(err),
+ affected: self_type_name
+ })?.#to_ptr().unwrap()
+ }
+ })
.collect()
}
diff --git a/src/di_container.rs b/src/di_container.rs
index cc2a930..1be570b 100644
--- a/src/di_container.rs
+++ b/src/di_container.rs
@@ -59,7 +59,7 @@ use crate::errors::di_container::{
use crate::interfaces::injectable::Injectable;
use crate::libs::intertrait::cast::{CastBox, CastRc};
use crate::provider::{Providable, SingletonProvider, TransientTypeProvider};
-use crate::ptr::{SingletonPtr, TransientPtr};
+use crate::ptr::{SingletonPtr, SomePtr};
/// Scope configurator for a binding for type 'Interface' inside a [`DIContainer`].
pub struct BindingScopeConfigurator<'di_container, Interface, Implementation>
@@ -177,7 +177,7 @@ where
#[cfg(feature = "factory")]
pub fn to_factory<Args, Return>(
&mut self,
- factory_func: &'static dyn Fn<Args, Output = TransientPtr<Return>>,
+ factory_func: &'static dyn Fn<Args, Output = crate::ptr::TransientPtr<Return>>,
) -> Result<(), BindingBuilderError>
where
Args: 'static,
@@ -212,7 +212,7 @@ where
#[cfg(feature = "factory")]
pub fn to_default_factory<Return>(
&mut self,
- factory_func: &'static dyn Fn<(), Output = TransientPtr<Return>>,
+ factory_func: &'static dyn Fn<(), Output = crate::ptr::TransientPtr<Return>>,
) -> Result<(), BindingBuilderError>
where
Return: 'static + ?Sized,
@@ -260,127 +260,67 @@ impl DIContainer
BindingBuilder::<Interface>::new(self)
}
- /// Returns a new instance of the type bound with `Interface`.
+ /// Returns the type bound with `Interface`.
///
/// # Errors
/// Will return `Err` if:
/// - No binding for `Interface` exists
/// - Resolving the binding for `Interface` fails
/// - Casting the binding for `Interface` fails
- /// - The binding for `Interface` is not transient
- pub fn get<Interface>(&self) -> Result<TransientPtr<Interface>, DIContainerError>
+ pub fn get<Interface>(&self) -> Result<SomePtr<Interface>, DIContainerError>
where
Interface: 'static + ?Sized,
{
- self.get_with_history::<Interface>(Vec::new())
- }
-
- /// Returns the singleton instance bound with `Interface`.
- ///
- /// # Errors
- /// Will return `Err` if:
- /// - No binding for `Interface` exists
- /// - Resolving the binding for `Interface` fails
- /// - Casting the binding for `Interface` fails
- /// - The binding for `Interface` is not a singleton
- pub fn get_singleton<Interface>(
- &self,
- ) -> Result<SingletonPtr<Interface>, DIContainerError>
- where
- Interface: 'static + ?Sized,
- {
- self.get_singleton_with_history(Vec::new())
- }
-
- /// Returns the factory bound with factory type `Interface`.
- ///
- /// *This function is only available if Syrette is built with the "factory" feature.*
- ///
- /// # Errors
- /// Will return `Err` if:
- /// - No binding for `Interface` exists
- /// - Resolving the binding for `Interface` fails
- /// - Casting the binding for `Interface` fails
- /// - The binding for `Interface` is not a factory
- #[cfg(feature = "factory")]
- pub fn get_factory<Interface>(
- &self,
- ) -> Result<crate::ptr::FactoryPtr<Interface>, DIContainerError>
- where
- Interface: 'static + ?Sized,
- {
- let binding_providable = self.get_binding_providable::<Interface>(Vec::new())?;
-
- if let Providable::Factory(binding_factory) = binding_providable {
- return binding_factory
- .cast::<Interface>()
- .map_err(|_| DIContainerError::CastFailed(type_name::<Interface>()));
- }
-
- Err(DIContainerError::WrongBindingType {
- interface: type_name::<Interface>(),
- expected: "factory",
- found: binding_providable.to_string().to_lowercase(),
- })
- }
-
- #[doc(hidden)]
- pub fn get_with_history<Interface>(
- &self,
- dependency_history: Vec<&'static str>,
- ) -> Result<TransientPtr<Interface>, DIContainerError>
- where
- Interface: 'static + ?Sized,
- {
- let binding_providable =
- self.get_binding_providable::<Interface>(dependency_history)?;
-
- if let Providable::Transient(binding_transient) = binding_providable {
- return binding_transient
- .cast::<Interface>()
- .map_err(|_| DIContainerError::CastFailed(type_name::<Interface>()));
- }
-
- #[cfg(feature = "factory")]
- if let Providable::Factory(binding_factory) = binding_providable {
- use crate::interfaces::factory::IFactory;
-
- let factory = binding_factory
- .cast::<dyn IFactory<(), Interface>>()
- .map_err(|_| DIContainerError::CastFailed(type_name::<Interface>()))?;
-
- return Ok(factory());
- }
-
- Err(DIContainerError::WrongBindingType {
- interface: type_name::<Interface>(),
- expected: "transient",
- found: binding_providable.to_string().to_lowercase(),
- })
+ self.get_bound::<Interface>(Vec::new())
}
#[doc(hidden)]
- pub fn get_singleton_with_history<Interface>(
+ pub fn get_bound<Interface>(
&self,
dependency_history: Vec<&'static str>,
- ) -> Result<SingletonPtr<Interface>, DIContainerError>
+ ) -> Result<SomePtr<Interface>, DIContainerError>
where
Interface: 'static + ?Sized,
{
let binding_providable =
self.get_binding_providable::<Interface>(dependency_history)?;
- if let Providable::Singleton(binding_singleton) = binding_providable {
- return binding_singleton
- .cast::<Interface>()
- .map_err(|_| DIContainerError::CastFailed(type_name::<Interface>()));
+ match binding_providable {
+ Providable::Transient(transient_binding) => Ok(SomePtr::Transient(
+ transient_binding.cast::<Interface>().map_err(|_| {
+ DIContainerError::CastFailed(type_name::<Interface>())
+ })?,
+ )),
+ Providable::Singleton(singleton_binding) => Ok(SomePtr::Singleton(
+ singleton_binding.cast::<Interface>().map_err(|_| {
+ DIContainerError::CastFailed(type_name::<Interface>())
+ })?,
+ )),
+ #[cfg(feature = "factory")]
+ Providable::Factory(factory_binding) => {
+ match factory_binding.clone().cast::<Interface>() {
+ Ok(factory) => Ok(SomePtr::Factory(factory)),
+ Err(_err) => {
+ use crate::interfaces::factory::IFactory;
+
+ let default_factory = factory_binding
+ .cast::<dyn IFactory<(), Interface>>()
+ .map_err(|_| {
+ DIContainerError::CastFailed(type_name::<Interface>())
+ })?;
+
+ Ok(SomePtr::Transient(default_factory()))
+ }
+ }
+ }
+ #[cfg(not(feature = "factory"))]
+ Providable::Factory(_) => {
+ return Err(DIContainerError::CantHandleFactoryBinding(type_name::<
+ Interface,
+ >(
+ )));
+ }
}
-
- Err(DIContainerError::WrongBindingType {
- interface: type_name::<Interface>(),
- expected: "singleton",
- found: binding_providable.to_string().to_lowercase(),
- })
}
fn get_binding_providable<Interface>(
@@ -545,7 +485,7 @@ mod tests
}
#[test]
- fn can_bind_to() -> Result<(), BindingBuilderError>
+ fn can_bind_to() -> Result<(), Box<dyn Error>>
{
let mut di_container: DIContainer = DIContainer::new();
@@ -579,7 +519,7 @@ mod tests
#[test]
#[cfg(feature = "factory")]
- fn can_bind_to_factory() -> Result<(), BindingBuilderError>
+ fn can_bind_to_factory() -> Result<(), Box<dyn Error>>
{
type IUserManagerFactory =
dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>;
@@ -601,7 +541,7 @@ mod tests
}
#[test]
- fn can_get() -> Result<(), DIContainerError>
+ fn can_get() -> Result<(), Box<dyn Error>>
{
mock! {
Provider {}
@@ -636,7 +576,7 @@ mod tests
}
#[test]
- fn can_get_singleton() -> Result<(), DIContainerError>
+ fn can_get_singleton() -> Result<(), Box<dyn Error>>
{
mock! {
Provider {}
@@ -667,11 +607,12 @@ mod tests
.bindings
.set::<dyn subjects::INumber>(Box::new(mock_provider));
- let first_number_rc = di_container.get_singleton::<dyn subjects::INumber>()?;
+ let first_number_rc = di_container.get::<dyn subjects::INumber>()?.singleton()?;
assert_eq!(first_number_rc.get(), 2820);
- let second_number_rc = di_container.get_singleton::<dyn subjects::INumber>()?;
+ let second_number_rc =
+ di_container.get::<dyn subjects::INumber>()?.singleton()?;
assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref());
@@ -680,7 +621,7 @@ mod tests
#[test]
#[cfg(feature = "factory")]
- fn can_get_factory() -> Result<(), DIContainerError>
+ fn can_get_factory() -> Result<(), Box<dyn Error>>
{
trait IUserManager
{
@@ -756,7 +697,7 @@ mod tests
.bindings
.set::<IUserManagerFactory>(Box::new(mock_provider));
- di_container.get_factory::<IUserManagerFactory>()?;
+ di_container.get::<IUserManagerFactory>()?;
Ok(())
}
diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs
index 98c2be4..4a74b5d 100644
--- a/src/errors/di_container.rs
+++ b/src/errors/di_container.rs
@@ -14,20 +14,6 @@ pub enum DIContainerError
#[error("Unable to cast binding for interface '{0}'")]
CastFailed(&'static str),
- /// Wrong binding type.
- #[error("Wrong binding type for interface '{interface}'. Expected a {expected}. Found a {found}")]
- WrongBindingType
- {
- /// The affected bound interface.
- interface: &'static str,
-
- /// The expected binding type.
- expected: &'static str,
-
- /// The found binding type.
- found: String,
- },
-
/// Failed to resolve a binding for a interface.
#[error("Failed to resolve binding for interface '{interface}'")]
BindingResolveFailed
@@ -43,6 +29,10 @@ pub enum DIContainerError
/// No binding exists for a interface.
#[error("No binding exists for interface '{0}'")]
BindingNotFound(&'static str),
+
+ /// The binding for a interface is a factory but the factory feature isn't enabled.
+ #[error("The binding for interface '{0}' is a factory but the factory feature isn't enabled")]
+ CantHandleFactoryBinding(&'static str),
}
/// Error type for [`BindingBuilder`].
@@ -58,7 +48,7 @@ pub enum BindingBuilderError
/// Error type for [`BindingScopeConfigurator`].
///
-/// [`BindingBuilder`]: crate::di_container::BindingScopeConfigurator
+/// [`BindingScopeConfigurator`]: crate::di_container::BindingScopeConfigurator
#[derive(thiserror::Error, Debug)]
pub enum BindingScopeConfiguratorError
{
diff --git a/src/errors/mod.rs b/src/errors/mod.rs
index 5f628d6..7d66ddf 100644
--- a/src/errors/mod.rs
+++ b/src/errors/mod.rs
@@ -2,3 +2,4 @@
pub mod di_container;
pub mod injectable;
+pub mod ptr;
diff --git a/src/errors/ptr.rs b/src/errors/ptr.rs
new file mode 100644
index 0000000..e0c3d05
--- /dev/null
+++ b/src/errors/ptr.rs
@@ -0,0 +1,19 @@
+//! Smart pointer alias errors.
+
+/// Error type for [`SomePtr`].
+///
+/// [`SomePtr`]: crate::ptr::SomePtr
+#[derive(thiserror::Error, Debug)]
+pub enum SomePtrError
+{
+ /// Tried to get as a wrong smart pointer type.
+ #[error("Wrong smart pointer type. Expected {expected}, found {found}")]
+ WrongPtrType
+ {
+ /// The expected smart pointer type.
+ expected: &'static str,
+
+ /// The found smart pointer type.
+ found: &'static str,
+ },
+}
diff --git a/src/ptr.rs b/src/ptr.rs
index 082edf2..08c3788 100644
--- a/src/ptr.rs
+++ b/src/ptr.rs
@@ -3,11 +3,65 @@
//! Smart pointer type aliases.
use std::rc::Rc;
-/// A smart pointer unique to the holder.
+use paste::paste;
+
+use crate::errors::ptr::SomePtrError;
+
+/// A smart pointer for a interface in the transient scope.
pub type TransientPtr<Interface> = Box<Interface>;
-/// A smart pointer to a shared resource.
+/// A smart pointer to a interface in the singleton scope.
pub type SingletonPtr<Interface> = Rc<Interface>;
/// A smart pointer to a factory.
pub type FactoryPtr<FactoryInterface> = Rc<FactoryInterface>;
+
+/// Some smart pointer.
+#[derive(strum_macros::IntoStaticStr)]
+pub enum SomePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ /// A smart pointer to a interface in the transient scope.
+ Transient(TransientPtr<Interface>),
+
+ /// A smart pointer to a interface in the singleton scope.
+ Singleton(SingletonPtr<Interface>),
+
+ /// A smart pointer to a factory.
+ Factory(FactoryPtr<Interface>),
+}
+
+macro_rules! create_as_variant_fn {
+ ($variant: ident) => {
+ paste! {
+ #[doc =
+ "Returns as " [<$variant:lower>] ".\n"
+ "\n"
+ "# Errors\n"
+ "Will return Err if it's not a " [<$variant:lower>] "."
+ ]
+ pub fn [<$variant:lower>](self) -> Result<[<$variant Ptr>]<Interface>, SomePtrError>
+ {
+ if let SomePtr::$variant(ptr) = self {
+ return Ok(ptr);
+ }
+
+
+ Err(SomePtrError::WrongPtrType {
+ expected: stringify!($variant),
+ found: self.into()
+ })
+ }
+ }
+ };
+}
+
+impl<Interface> SomePtr<Interface>
+where
+ Interface: 'static + ?Sized,
+{
+ create_as_variant_fn!(Transient);
+ create_as_variant_fn!(Singleton);
+ create_as_variant_fn!(Factory);
+}