From 1c46b68581213ca8ae6200daa32f626b5389b4b0 Mon Sep 17 00:00:00 2001 From: HampusM Date: Thu, 25 Aug 2022 20:21:49 +0200 Subject: refactor!: make DI container have single get function BREAKING CHANGE: The DI container get_singleton & get_factory functions have been replaced by the get function now returning a enum --- Cargo.toml | 1 + examples/basic/main.rs | 4 +- examples/factory/bootstrap.rs | 18 ++--- examples/factory/main.rs | 10 ++- examples/generics/main.rs | 10 ++- examples/unbound/main.rs | 4 +- examples/with-3rd-party/main.rs | 2 +- macros/src/injectable_impl.rs | 83 ++++++++++---------- src/di_container.rs | 163 +++++++++++++--------------------------- src/errors/di_container.rs | 20 ++--- src/errors/mod.rs | 1 + src/errors/ptr.rs | 19 +++++ src/ptr.rs | 58 +++++++++++++- 13 files changed, 204 insertions(+), 189 deletions(-) create mode 100644 src/errors/ptr.rs diff --git a/Cargo.toml b/Cargo.toml index f1c6885..8d26077 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ ahash = "0.7.6" thiserror = "1.0.32" strum = "0.24.1" strum_macros = "0.24.3" +paste = "1.0.8" [dev_dependencies] mockall = "0.11.1" diff --git a/examples/basic/main.rs b/examples/basic/main.rs index 72f07c2..dbc9215 100644 --- a/examples/basic/main.rs +++ b/examples/basic/main.rs @@ -18,11 +18,11 @@ fn main() -> Result<(), Box> let di_container = bootstrap()?; - let dog = di_container.get_singleton::()?; + let dog = di_container.get::()?.singleton()?; dog.woof(); - let human = di_container.get::()?; + let human = di_container.get::()?.transient()?; human.make_pets_make_sounds(); diff --git a/examples/factory/bootstrap.rs b/examples/factory/bootstrap.rs index b752764..ad8c4d3 100644 --- a/examples/factory/bootstrap.rs +++ b/examples/factory/bootstrap.rs @@ -1,3 +1,5 @@ +use std::error::Error; + use syrette::ptr::TransientPtr; use syrette::DIContainer; @@ -9,24 +11,22 @@ use crate::interfaces::user_manager::IUserManager; use crate::user::User; use crate::user_manager::UserManager; -pub fn bootstrap() -> DIContainer +pub fn bootstrap() -> Result> { let mut di_container: DIContainer = DIContainer::new(); di_container .bind::() - .to::() - .unwrap(); + .to::()?; - di_container - .bind::() - .to_factory(&|name, date_of_birth, password| { + di_container.bind::().to_factory( + &|name, date_of_birth, password| { let user: TransientPtr = TransientPtr::new(User::new(name, date_of_birth, password)); user - }) - .unwrap(); + }, + )?; - di_container + Ok(di_container) } diff --git a/examples/factory/main.rs b/examples/factory/main.rs index bf3d43b..0f1a97b 100644 --- a/examples/factory/main.rs +++ b/examples/factory/main.rs @@ -7,21 +7,25 @@ mod interfaces; mod user; mod user_manager; +use std::error::Error; + use bootstrap::bootstrap; use crate::interfaces::user_manager::IUserManager; -fn main() +fn main() -> Result<(), Box> { println!("Hello, world!"); - let di_container = bootstrap(); + let di_container = bootstrap()?; - let mut user_manager = di_container.get::().unwrap(); + let mut user_manager = di_container.get::()?.transient()?; user_manager.fill_with_users(); println!("Printing user information"); user_manager.print_users(); + + Ok(()) } diff --git a/examples/generics/main.rs b/examples/generics/main.rs index 9442641..f491aa0 100644 --- a/examples/generics/main.rs +++ b/examples/generics/main.rs @@ -2,18 +2,22 @@ mod bootstrap; mod interfaces; mod printer; +use std::error::Error; + use bootstrap::bootstrap; use interfaces::printer::IPrinter; -fn main() +fn main() -> Result<(), Box> { let di_container = bootstrap(); - let string_printer = di_container.get::>().unwrap(); + let string_printer = di_container.get::>()?.transient()?; string_printer.print("Hello there".to_string()); - let int_printer = di_container.get::>().unwrap(); + let int_printer = di_container.get::>()?.transient()?; int_printer.print(2782028); + + Ok(()) } diff --git a/examples/unbound/main.rs b/examples/unbound/main.rs index 031a691..e9a8feb 100644 --- a/examples/unbound/main.rs +++ b/examples/unbound/main.rs @@ -19,11 +19,11 @@ fn main() -> Result<(), Box> let di_container = bootstrap()?; - let dog = di_container.get_singleton::()?; + let dog = di_container.get::()?.singleton()?; dog.woof(); - let human = di_container.get::()?; + let human = di_container.get::()?.transient()?; human.make_pets_make_sounds(); diff --git a/examples/with-3rd-party/main.rs b/examples/with-3rd-party/main.rs index dd4c21f..e48c78f 100644 --- a/examples/with-3rd-party/main.rs +++ b/examples/with-3rd-party/main.rs @@ -17,7 +17,7 @@ fn main() -> Result<(), Box> let di_container = bootstrap()?; - let ninja = di_container.get::()?; + let ninja = di_container.get::()?.transient()?; ninja.throw_shuriken(); diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index 62bf057..d74acb3 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -103,12 +103,7 @@ impl InjectableImpl #maybe_prevent_circular_deps return Ok(syrette::ptr::TransientPtr::new(Self::new( - #(#get_dep_method_calls - .map_err(|err| InjectableError::ResolveFailed { - reason: Box::new(err), - affected: self_type_name - })? - ),* + #(#get_dep_method_calls),* ))); } } @@ -117,53 +112,59 @@ impl InjectableImpl fn create_get_dep_method_calls( dependency_types: &[DependencyType], - ) -> Vec + ) -> Vec { dependency_types .iter() .filter_map(|dep_type| match &dep_type.interface { - Type::TraitObject(dep_type_trait) => parse_str::( - format!( - "{}.get_{}::<{}>({}.clone())", - DI_CONTAINER_VAR_NAME, - if dep_type.ptr == "SingletonPtr" { - "singleton_with_history" - } else { - "with_history" - }, - dep_type_trait.to_token_stream(), - DEPENDENCY_HISTORY_VAR_NAME + Type::TraitObject(dep_type_trait) => { + let method_call = parse_str::( + format!( + "{}.get_bound::<{}>({}.clone())", + DI_CONTAINER_VAR_NAME, + dep_type_trait.to_token_stream(), + DEPENDENCY_HISTORY_VAR_NAME + ) + .as_str(), ) - .as_str(), - ) - .ok(), + .ok()?; + + Some((method_call, dep_type)) + + /* + */ + } Type::Path(dep_type_path) => { let dep_type_path_str = Self::path_to_string(&dep_type_path.path); - if dep_type_path_str.ends_with("Factory") { - parse_str( - format!( - "{}.get_factory::<{}>()", - DI_CONTAINER_VAR_NAME, dep_type_path_str - ) - .as_str(), - ) - .ok() - } else { - parse_str( - format!( - "{}.get_with_history::<{}>({}.clone())", - DI_CONTAINER_VAR_NAME, - dep_type_path_str, - DEPENDENCY_HISTORY_VAR_NAME - ) - .as_str(), + let method_call: ExprMethodCall = parse_str( + format!( + "{}.get_bound::<{}>({}.clone())", + DI_CONTAINER_VAR_NAME, + dep_type_path_str, + DEPENDENCY_HISTORY_VAR_NAME ) - .ok() - } + .as_str(), + ) + .ok()?; + + Some((method_call, dep_type)) } &_ => None, }) + .map(|(method_call, dep_type)| { + let ptr_name = dep_type.ptr.to_string(); + + let to_ptr = + format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase()); + + quote! { + #method_call.map_err(|err| InjectableError::ResolveFailed { + reason: Box::new(err), + affected: self_type_name + })?.#to_ptr().unwrap() + } + }) .collect() } diff --git a/src/di_container.rs b/src/di_container.rs index cc2a930..1be570b 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -59,7 +59,7 @@ use crate::errors::di_container::{ use crate::interfaces::injectable::Injectable; use crate::libs::intertrait::cast::{CastBox, CastRc}; use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; -use crate::ptr::{SingletonPtr, TransientPtr}; +use crate::ptr::{SingletonPtr, SomePtr}; /// Scope configurator for a binding for type 'Interface' inside a [`DIContainer`]. pub struct BindingScopeConfigurator<'di_container, Interface, Implementation> @@ -177,7 +177,7 @@ where #[cfg(feature = "factory")] pub fn to_factory( &mut self, - factory_func: &'static dyn Fn>, + factory_func: &'static dyn Fn>, ) -> Result<(), BindingBuilderError> where Args: 'static, @@ -212,7 +212,7 @@ where #[cfg(feature = "factory")] pub fn to_default_factory( &mut self, - factory_func: &'static dyn Fn<(), Output = TransientPtr>, + factory_func: &'static dyn Fn<(), Output = crate::ptr::TransientPtr>, ) -> Result<(), BindingBuilderError> where Return: 'static + ?Sized, @@ -260,127 +260,67 @@ impl DIContainer BindingBuilder::::new(self) } - /// Returns a new instance of the type bound with `Interface`. + /// Returns the type bound with `Interface`. /// /// # Errors /// Will return `Err` if: /// - No binding for `Interface` exists /// - Resolving the binding for `Interface` fails /// - Casting the binding for `Interface` fails - /// - The binding for `Interface` is not transient - pub fn get(&self) -> Result, DIContainerError> + pub fn get(&self) -> Result, DIContainerError> where Interface: 'static + ?Sized, { - self.get_with_history::(Vec::new()) - } - - /// Returns the singleton instance bound with `Interface`. - /// - /// # Errors - /// Will return `Err` if: - /// - No binding for `Interface` exists - /// - Resolving the binding for `Interface` fails - /// - Casting the binding for `Interface` fails - /// - The binding for `Interface` is not a singleton - pub fn get_singleton( - &self, - ) -> Result, DIContainerError> - where - Interface: 'static + ?Sized, - { - self.get_singleton_with_history(Vec::new()) - } - - /// Returns the factory bound with factory type `Interface`. - /// - /// *This function is only available if Syrette is built with the "factory" feature.* - /// - /// # Errors - /// Will return `Err` if: - /// - No binding for `Interface` exists - /// - Resolving the binding for `Interface` fails - /// - Casting the binding for `Interface` fails - /// - The binding for `Interface` is not a factory - #[cfg(feature = "factory")] - pub fn get_factory( - &self, - ) -> Result, DIContainerError> - where - Interface: 'static + ?Sized, - { - let binding_providable = self.get_binding_providable::(Vec::new())?; - - if let Providable::Factory(binding_factory) = binding_providable { - return binding_factory - .cast::() - .map_err(|_| DIContainerError::CastFailed(type_name::())); - } - - Err(DIContainerError::WrongBindingType { - interface: type_name::(), - expected: "factory", - found: binding_providable.to_string().to_lowercase(), - }) - } - - #[doc(hidden)] - pub fn get_with_history( - &self, - dependency_history: Vec<&'static str>, - ) -> Result, DIContainerError> - where - Interface: 'static + ?Sized, - { - let binding_providable = - self.get_binding_providable::(dependency_history)?; - - if let Providable::Transient(binding_transient) = binding_providable { - return binding_transient - .cast::() - .map_err(|_| DIContainerError::CastFailed(type_name::())); - } - - #[cfg(feature = "factory")] - if let Providable::Factory(binding_factory) = binding_providable { - use crate::interfaces::factory::IFactory; - - let factory = binding_factory - .cast::>() - .map_err(|_| DIContainerError::CastFailed(type_name::()))?; - - return Ok(factory()); - } - - Err(DIContainerError::WrongBindingType { - interface: type_name::(), - expected: "transient", - found: binding_providable.to_string().to_lowercase(), - }) + self.get_bound::(Vec::new()) } #[doc(hidden)] - pub fn get_singleton_with_history( + pub fn get_bound( &self, dependency_history: Vec<&'static str>, - ) -> Result, DIContainerError> + ) -> Result, DIContainerError> where Interface: 'static + ?Sized, { let binding_providable = self.get_binding_providable::(dependency_history)?; - if let Providable::Singleton(binding_singleton) = binding_providable { - return binding_singleton - .cast::() - .map_err(|_| DIContainerError::CastFailed(type_name::())); + match binding_providable { + Providable::Transient(transient_binding) => Ok(SomePtr::Transient( + transient_binding.cast::().map_err(|_| { + DIContainerError::CastFailed(type_name::()) + })?, + )), + Providable::Singleton(singleton_binding) => Ok(SomePtr::Singleton( + singleton_binding.cast::().map_err(|_| { + DIContainerError::CastFailed(type_name::()) + })?, + )), + #[cfg(feature = "factory")] + Providable::Factory(factory_binding) => { + match factory_binding.clone().cast::() { + Ok(factory) => Ok(SomePtr::Factory(factory)), + Err(_err) => { + use crate::interfaces::factory::IFactory; + + let default_factory = factory_binding + .cast::>() + .map_err(|_| { + DIContainerError::CastFailed(type_name::()) + })?; + + Ok(SomePtr::Transient(default_factory())) + } + } + } + #[cfg(not(feature = "factory"))] + Providable::Factory(_) => { + return Err(DIContainerError::CantHandleFactoryBinding(type_name::< + Interface, + >( + ))); + } } - - Err(DIContainerError::WrongBindingType { - interface: type_name::(), - expected: "singleton", - found: binding_providable.to_string().to_lowercase(), - }) } fn get_binding_providable( @@ -545,7 +485,7 @@ mod tests } #[test] - fn can_bind_to() -> Result<(), BindingBuilderError> + fn can_bind_to() -> Result<(), Box> { let mut di_container: DIContainer = DIContainer::new(); @@ -579,7 +519,7 @@ mod tests #[test] #[cfg(feature = "factory")] - fn can_bind_to_factory() -> Result<(), BindingBuilderError> + fn can_bind_to_factory() -> Result<(), Box> { type IUserManagerFactory = dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; @@ -601,7 +541,7 @@ mod tests } #[test] - fn can_get() -> Result<(), DIContainerError> + fn can_get() -> Result<(), Box> { mock! { Provider {} @@ -636,7 +576,7 @@ mod tests } #[test] - fn can_get_singleton() -> Result<(), DIContainerError> + fn can_get_singleton() -> Result<(), Box> { mock! { Provider {} @@ -667,11 +607,12 @@ mod tests .bindings .set::(Box::new(mock_provider)); - let first_number_rc = di_container.get_singleton::()?; + let first_number_rc = di_container.get::()?.singleton()?; assert_eq!(first_number_rc.get(), 2820); - let second_number_rc = di_container.get_singleton::()?; + let second_number_rc = + di_container.get::()?.singleton()?; assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); @@ -680,7 +621,7 @@ mod tests #[test] #[cfg(feature = "factory")] - fn can_get_factory() -> Result<(), DIContainerError> + fn can_get_factory() -> Result<(), Box> { trait IUserManager { @@ -756,7 +697,7 @@ mod tests .bindings .set::(Box::new(mock_provider)); - di_container.get_factory::()?; + di_container.get::()?; Ok(()) } diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs index 98c2be4..4a74b5d 100644 --- a/src/errors/di_container.rs +++ b/src/errors/di_container.rs @@ -14,20 +14,6 @@ pub enum DIContainerError #[error("Unable to cast binding for interface '{0}'")] CastFailed(&'static str), - /// Wrong binding type. - #[error("Wrong binding type for interface '{interface}'. Expected a {expected}. Found a {found}")] - WrongBindingType - { - /// The affected bound interface. - interface: &'static str, - - /// The expected binding type. - expected: &'static str, - - /// The found binding type. - found: String, - }, - /// Failed to resolve a binding for a interface. #[error("Failed to resolve binding for interface '{interface}'")] BindingResolveFailed @@ -43,6 +29,10 @@ pub enum DIContainerError /// No binding exists for a interface. #[error("No binding exists for interface '{0}'")] BindingNotFound(&'static str), + + /// The binding for a interface is a factory but the factory feature isn't enabled. + #[error("The binding for interface '{0}' is a factory but the factory feature isn't enabled")] + CantHandleFactoryBinding(&'static str), } /// Error type for [`BindingBuilder`]. @@ -58,7 +48,7 @@ pub enum BindingBuilderError /// Error type for [`BindingScopeConfigurator`]. /// -/// [`BindingBuilder`]: crate::di_container::BindingScopeConfigurator +/// [`BindingScopeConfigurator`]: crate::di_container::BindingScopeConfigurator #[derive(thiserror::Error, Debug)] pub enum BindingScopeConfiguratorError { diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 5f628d6..7d66ddf 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -2,3 +2,4 @@ pub mod di_container; pub mod injectable; +pub mod ptr; diff --git a/src/errors/ptr.rs b/src/errors/ptr.rs new file mode 100644 index 0000000..e0c3d05 --- /dev/null +++ b/src/errors/ptr.rs @@ -0,0 +1,19 @@ +//! Smart pointer alias errors. + +/// Error type for [`SomePtr`]. +/// +/// [`SomePtr`]: crate::ptr::SomePtr +#[derive(thiserror::Error, Debug)] +pub enum SomePtrError +{ + /// Tried to get as a wrong smart pointer type. + #[error("Wrong smart pointer type. Expected {expected}, found {found}")] + WrongPtrType + { + /// The expected smart pointer type. + expected: &'static str, + + /// The found smart pointer type. + found: &'static str, + }, +} diff --git a/src/ptr.rs b/src/ptr.rs index 082edf2..08c3788 100644 --- a/src/ptr.rs +++ b/src/ptr.rs @@ -3,11 +3,65 @@ //! Smart pointer type aliases. use std::rc::Rc; -/// A smart pointer unique to the holder. +use paste::paste; + +use crate::errors::ptr::SomePtrError; + +/// A smart pointer for a interface in the transient scope. pub type TransientPtr = Box; -/// A smart pointer to a shared resource. +/// A smart pointer to a interface in the singleton scope. pub type SingletonPtr = Rc; /// A smart pointer to a factory. pub type FactoryPtr = Rc; + +/// Some smart pointer. +#[derive(strum_macros::IntoStaticStr)] +pub enum SomePtr +where + Interface: 'static + ?Sized, +{ + /// A smart pointer to a interface in the transient scope. + Transient(TransientPtr), + + /// A smart pointer to a interface in the singleton scope. + Singleton(SingletonPtr), + + /// A smart pointer to a factory. + Factory(FactoryPtr), +} + +macro_rules! create_as_variant_fn { + ($variant: ident) => { + paste! { + #[doc = + "Returns as " [<$variant:lower>] ".\n" + "\n" + "# Errors\n" + "Will return Err if it's not a " [<$variant:lower>] "." + ] + pub fn [<$variant:lower>](self) -> Result<[<$variant Ptr>], SomePtrError> + { + if let SomePtr::$variant(ptr) = self { + return Ok(ptr); + } + + + Err(SomePtrError::WrongPtrType { + expected: stringify!($variant), + found: self.into() + }) + } + } + }; +} + +impl SomePtr +where + Interface: 'static + ?Sized, +{ + create_as_variant_fn!(Transient); + create_as_variant_fn!(Singleton); + create_as_variant_fn!(Factory); +} -- cgit v1.2.3-18-g5258