From e0f90a8e384615c79d7d51c66d19294d75e79391 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 27 Aug 2022 23:41:41 +0200 Subject: feat: implement named bindings --- Cargo.toml | 1 + README.md | 1 + examples/named/bootstrap.rs | 29 +++ examples/named/interfaces/mod.rs | 2 + examples/named/interfaces/ninja.rs | 4 + examples/named/interfaces/weapon.rs | 4 + examples/named/katana.rs | 22 ++ examples/named/main.rs | 25 +++ examples/named/ninja.rs | 45 ++++ examples/named/shuriken.rs | 22 ++ macros/src/dependency.rs | 81 ++++++++ macros/src/dependency_type.rs | 40 ---- macros/src/injectable_impl.rs | 200 +++++++----------- macros/src/lib.rs | 107 +++++++++- macros/src/named_attr_input.rs | 21 ++ macros/src/util/item_impl.rs | 12 +- macros/src/util/mod.rs | 1 + macros/src/util/syn_path.rs | 22 ++ src/di_container.rs | 404 +++++++++++++++++++++++++++++++++--- src/di_container_binding_map.rs | 59 +++++- src/errors/di_container.rs | 27 ++- 21 files changed, 905 insertions(+), 224 deletions(-) create mode 100644 examples/named/bootstrap.rs create mode 100644 examples/named/interfaces/mod.rs create mode 100644 examples/named/interfaces/ninja.rs create mode 100644 examples/named/interfaces/weapon.rs create mode 100644 examples/named/katana.rs create mode 100644 examples/named/main.rs create mode 100644 examples/named/ninja.rs create mode 100644 examples/named/shuriken.rs create mode 100644 macros/src/dependency.rs delete mode 100644 macros/src/dependency_type.rs create mode 100644 macros/src/named_attr_input.rs create mode 100644 macros/src/util/syn_path.rs diff --git a/Cargo.toml b/Cargo.toml index 8d26077..b3aa027 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ paste = "1.0.8" [dev_dependencies] mockall = "0.11.1" +anyhow = "1.0.62" third-party-lib = { path = "./examples/with-3rd-party/third-party-lib" } [workspace] diff --git a/README.md b/README.md index e07f30c..d61614b 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ From the [syrette Wikipedia article](https://en.wikipedia.org/wiki/Syrette). - Supports generic implementations & generic interface traits - Binding singletons - Injection of third-party structs & traits +- Named bindings ## Optional features - `factory`. Binding factories (Rust nightly required) diff --git a/examples/named/bootstrap.rs b/examples/named/bootstrap.rs new file mode 100644 index 0000000..b5fa39d --- /dev/null +++ b/examples/named/bootstrap.rs @@ -0,0 +1,29 @@ +use anyhow::Result; +use syrette::DIContainer; + +use crate::interfaces::ninja::INinja; +use crate::interfaces::weapon::IWeapon; +use crate::katana::Katana; +use crate::ninja::Ninja; +use crate::shuriken::Shuriken; + +pub fn bootstrap() -> Result +{ + let mut di_container: DIContainer = DIContainer::new(); + + di_container + .bind::() + .to::()? + .in_transient_scope() + .when_named("strong")?; + + di_container + .bind::() + .to::()? + .in_transient_scope() + .when_named("weak")?; + + di_container.bind::().to::()?; + + Ok(di_container) +} diff --git a/examples/named/interfaces/mod.rs b/examples/named/interfaces/mod.rs new file mode 100644 index 0000000..6a0108d --- /dev/null +++ b/examples/named/interfaces/mod.rs @@ -0,0 +1,2 @@ +pub mod ninja; +pub mod weapon; diff --git a/examples/named/interfaces/ninja.rs b/examples/named/interfaces/ninja.rs new file mode 100644 index 0000000..2d378c8 --- /dev/null +++ b/examples/named/interfaces/ninja.rs @@ -0,0 +1,4 @@ +pub trait INinja +{ + fn use_weapons(&self); +} diff --git a/examples/named/interfaces/weapon.rs b/examples/named/interfaces/weapon.rs new file mode 100644 index 0000000..7848a0f --- /dev/null +++ b/examples/named/interfaces/weapon.rs @@ -0,0 +1,4 @@ +pub trait IWeapon +{ + fn use_it(&self); +} diff --git a/examples/named/katana.rs b/examples/named/katana.rs new file mode 100644 index 0000000..a03af6d --- /dev/null +++ b/examples/named/katana.rs @@ -0,0 +1,22 @@ +use syrette::injectable; + +use crate::interfaces::weapon::IWeapon; + +pub struct Katana {} + +#[injectable(IWeapon)] +impl Katana +{ + pub fn new() -> Self + { + Self {} + } +} + +impl IWeapon for Katana +{ + fn use_it(&self) + { + println!("Used katana!"); + } +} diff --git a/examples/named/main.rs b/examples/named/main.rs new file mode 100644 index 0000000..5411a12 --- /dev/null +++ b/examples/named/main.rs @@ -0,0 +1,25 @@ +#![deny(clippy::all)] +#![deny(clippy::pedantic)] +#![allow(clippy::module_name_repetitions)] + +mod bootstrap; +mod interfaces; +mod katana; +mod ninja; +mod shuriken; + +use anyhow::Result; + +use crate::bootstrap::bootstrap; +use crate::interfaces::ninja::INinja; + +fn main() -> Result<()> +{ + let di_container = bootstrap()?; + + let ninja = di_container.get::()?.transient()?; + + ninja.use_weapons(); + + Ok(()) +} diff --git a/examples/named/ninja.rs b/examples/named/ninja.rs new file mode 100644 index 0000000..2069f14 --- /dev/null +++ b/examples/named/ninja.rs @@ -0,0 +1,45 @@ +use syrette::injectable; +use syrette::ptr::TransientPtr; + +use crate::interfaces::ninja::INinja; +use crate::interfaces::weapon::IWeapon; + +pub struct Ninja +{ + strong_weapon: TransientPtr, + weak_weapon: TransientPtr, +} + +#[injectable(INinja)] +impl Ninja +{ + pub fn new( + #[rustfmt::skip] // Prevent rustfmt from turning this into a single line + #[syrette::named("strong")] + strong_weapon: TransientPtr, + + #[rustfmt::skip] // Prevent rustfmt from turning this into a single line + #[named("weak")] + weak_weapon: TransientPtr, + ) -> Self + { + Self { + strong_weapon, + weak_weapon, + } + } +} + +impl INinja for Ninja +{ + fn use_weapons(&self) + { + println!("Ninja is using his strong weapon!"); + + self.strong_weapon.use_it(); + + println!("Ninja is using his weak weapon!"); + + self.weak_weapon.use_it(); + } +} diff --git a/examples/named/shuriken.rs b/examples/named/shuriken.rs new file mode 100644 index 0000000..c50aeac --- /dev/null +++ b/examples/named/shuriken.rs @@ -0,0 +1,22 @@ +use syrette::injectable; + +use crate::interfaces::weapon::IWeapon; + +pub struct Shuriken {} + +#[injectable(IWeapon)] +impl Shuriken +{ + pub fn new() -> Self + { + Self {} + } +} + +impl IWeapon for Shuriken +{ + fn use_it(&self) + { + println!("Used shuriken!"); + } +} diff --git a/macros/src/dependency.rs b/macros/src/dependency.rs new file mode 100644 index 0000000..d20af90 --- /dev/null +++ b/macros/src/dependency.rs @@ -0,0 +1,81 @@ +use std::error::Error; + +use proc_macro2::Ident; +use syn::{parse2, FnArg, GenericArgument, LitStr, PathArguments, Type}; + +use crate::named_attr_input::NamedAttrInput; +use crate::util::syn_path::syn_path_to_string; + +pub struct Dependency +{ + pub interface: Type, + pub ptr: Ident, + pub name: Option, +} + +impl Dependency +{ + pub fn build(new_method_arg: &FnArg) -> Result> + { + let typed_new_method_arg = match new_method_arg { + FnArg::Typed(typed_arg) => Ok(typed_arg), + FnArg::Receiver(_) => Err("Unexpected self argument in 'new' method"), + }?; + + let ptr_type_path = match typed_new_method_arg.ty.as_ref() { + Type::Path(arg_type_path) => Ok(arg_type_path), + Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { + Type::Path(arg_type_path) => Ok(arg_type_path), + &_ => Err("Unexpected reference to non-path type"), + }, + &_ => Err("Expected a path or a reference type"), + }?; + + let ptr_path_segment = ptr_type_path.path.segments.last().map_or_else( + || Err("Expected pointer type path to have a last segment"), + Ok, + )?; + + let ptr = ptr_path_segment.ident.clone(); + + let ptr_path_generic_args = &match &ptr_path_segment.arguments { + PathArguments::AngleBracketed(generic_args) => Ok(generic_args), + &_ => Err("Expected pointer type to have a generic type argument"), + }? + .args; + + let interface = if let Some(GenericArgument::Type(interface)) = + ptr_path_generic_args.first() + { + Ok(interface.clone()) + } else { + Err("Expected pointer type to have a generic type argument") + }?; + + let arg_attrs = &typed_new_method_arg.attrs; + + let opt_named_attr = arg_attrs.iter().find(|attr| { + attr.path.get_ident().map_or_else( + || false, + |attr_ident| attr_ident.to_string().as_str() == "named", + ) || syn_path_to_string(&attr.path) == "syrette::named" + }); + + let opt_named_attr_tokens = opt_named_attr.map(|attr| &attr.tokens); + + let opt_named_attr_input = + if let Some(named_attr_tokens) = opt_named_attr_tokens { + Some(parse2::(named_attr_tokens.clone()).map_err( + |err| format!("Invalid input for 'named' attribute. {}", err), + )?) + } else { + None + }; + + Ok(Self { + interface, + ptr, + name: opt_named_attr_input.map(|named_attr_input| named_attr_input.name), + }) + } +} diff --git a/macros/src/dependency_type.rs b/macros/src/dependency_type.rs deleted file mode 100644 index 35f810e..0000000 --- a/macros/src/dependency_type.rs +++ /dev/null @@ -1,40 +0,0 @@ -use proc_macro2::Ident; -use syn::{GenericArgument, PathArguments, Type, TypePath}; - -pub struct DependencyType -{ - pub interface: Type, - pub ptr: Ident, -} - -impl DependencyType -{ - pub fn from_type_path(type_path: &TypePath) -> Option - { - // Assume the type path has a last segment. - let last_path_segment = type_path.path.segments.last().unwrap(); - - let ptr = &last_path_segment.ident; - - match &last_path_segment.arguments { - PathArguments::AngleBracketed(angle_bracketed_generic_args) => { - let generic_args = &angle_bracketed_generic_args.args; - - let opt_first_generic_arg = generic_args.first(); - - // Assume a first generic argument exists because TransientPtr, - // SingletonPtr and FactoryPtr requires one - let first_generic_arg = opt_first_generic_arg.as_ref().unwrap(); - - match first_generic_arg { - GenericArgument::Type(first_generic_arg_type) => Some(Self { - interface: first_generic_arg_type.clone(), - ptr: ptr.clone(), - }), - &_ => None, - } - } - &_ => None, - } - } -} diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs index d74acb3..6edcab3 100644 --- a/macros/src/injectable_impl.rs +++ b/macros/src/injectable_impl.rs @@ -1,20 +1,20 @@ +use std::error::Error; + use quote::{format_ident, quote, ToTokens}; use syn::parse::{Parse, ParseStream}; use syn::Generics; -use syn::{ - parse_str, punctuated::Punctuated, token::Comma, ExprMethodCall, FnArg, ItemImpl, - Path, Type, TypePath, -}; +use syn::{parse_str, ExprMethodCall, FnArg, ItemImpl, Type}; -use crate::dependency_type::DependencyType; -use crate::util::item_impl::find_impl_method_by_name; +use crate::dependency::Dependency; +use crate::util::item_impl::find_impl_method_by_name_mut; +use crate::util::syn_path::syn_path_to_string; const DI_CONTAINER_VAR_NAME: &str = "di_container"; const DEPENDENCY_HISTORY_VAR_NAME: &str = "dependency_history"; pub struct InjectableImpl { - pub dependency_types: Vec, + pub dependencies: Vec, pub self_type: Type, pub generics: Generics, pub original_impl: ItemImpl, @@ -24,13 +24,13 @@ impl Parse for InjectableImpl { fn parse(input: ParseStream) -> syn::Result { - let impl_parsed_input = input.parse::()?; + let mut impl_parsed_input = input.parse::()?; - let dependency_types = Self::get_dependency_types(&impl_parsed_input) + let dependencies = Self::build_dependencies(&mut impl_parsed_input) .map_err(|err| input.error(err))?; Ok(Self { - dependency_types, + dependencies, self_type: impl_parsed_input.self_ty.as_ref().clone(), generics: impl_parsed_input.generics.clone(), original_impl: impl_parsed_input, @@ -43,7 +43,7 @@ impl InjectableImpl pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream { let Self { - dependency_types, + dependencies, self_type, generics, original_impl, @@ -52,7 +52,7 @@ impl InjectableImpl let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME); let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME); - let get_dep_method_calls = Self::create_get_dep_method_calls(dependency_types); + let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies); let maybe_doc_hidden = if no_doc_hidden { quote! {} @@ -111,46 +111,38 @@ impl InjectableImpl } fn create_get_dep_method_calls( - dependency_types: &[DependencyType], + dependencies: &[Dependency], ) -> Vec { - dependency_types + dependencies .iter() - .filter_map(|dep_type| match &dep_type.interface { - Type::TraitObject(dep_type_trait) => { - let method_call = parse_str::( - format!( - "{}.get_bound::<{}>({}.clone())", - DI_CONTAINER_VAR_NAME, - dep_type_trait.to_token_stream(), - DEPENDENCY_HISTORY_VAR_NAME - ) - .as_str(), - ) - .ok()?; - - Some((method_call, dep_type)) - - /* - */ - } - Type::Path(dep_type_path) => { - let dep_type_path_str = Self::path_to_string(&dep_type_path.path); - - let method_call: ExprMethodCall = parse_str( - format!( - "{}.get_bound::<{}>({}.clone())", - DI_CONTAINER_VAR_NAME, - dep_type_path_str, - DEPENDENCY_HISTORY_VAR_NAME + .filter_map(|dependency| { + let dep_interface_str = match &dependency.interface { + Type::TraitObject(interface_trait) => { + Some(interface_trait.to_token_stream().to_string()) + } + Type::Path(path_interface) => { + Some(syn_path_to_string(&path_interface.path)) + } + &_ => None, + }?; + + let method_call = parse_str::( + format!( + "{}.get_bound::<{}>({}.clone(), {})", + DI_CONTAINER_VAR_NAME, + dep_interface_str, + DEPENDENCY_HISTORY_VAR_NAME, + dependency.name.as_ref().map_or_else( + || "None".to_string(), + |name| format!("Some(\"{}\")", name.value()) ) - .as_str(), ) - .ok()?; + .as_str(), + ) + .ok()?; - Some((method_call, dep_type)) - } - &_ => None, + Some((method_call, dependency)) }) .map(|(method_call, dep_type)| { let ptr_name = dep_type.ptr.to_string(); @@ -168,95 +160,47 @@ impl InjectableImpl .collect() } - #[allow(clippy::match_wildcard_for_single_variants)] - fn get_has_fn_args_self(fn_args: &Punctuated) -> bool - { - fn_args.iter().any(|arg| match arg { - FnArg::Receiver(_) => true, - &_ => false, - }) - } - - fn get_fn_arg_type_paths(fn_args: &Punctuated) -> Vec<&TypePath> - { - fn_args - .iter() - .filter_map(|arg| match arg { - FnArg::Typed(typed_fn_arg) => match typed_fn_arg.ty.as_ref() { - Type::Path(arg_type_path) => Some(arg_type_path), - Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { - Type::Path(arg_type_path) => Some(arg_type_path), - &_ => None, - }, - &_ => None, - }, - FnArg::Receiver(_receiver_fn_arg) => None, - }) - .collect() - } - - fn path_to_string(path: &Path) -> String - { - path.segments - .pairs() - .fold(String::new(), |mut acc, segment_pair| { - let segment_ident = &segment_pair.value().ident; - - acc.push_str(segment_ident.to_string().as_str()); - - let opt_colon_two = segment_pair.punct(); - - match opt_colon_two { - Some(colon_two) => { - acc.push_str(colon_two.to_token_stream().to_string().as_str()); - } - None => {} - } - - acc - }) - } - - fn is_type_path_ptr(type_path: &TypePath) -> bool + fn build_dependencies( + item_impl: &mut ItemImpl, + ) -> Result, Box> { - let arg_type_path_string = Self::path_to_string(&type_path.path); - - arg_type_path_string == "TransientPtr" - || arg_type_path_string == "ptr::TransientPtr" - || arg_type_path_string == "syrrete::ptr::TransientPtr" - || arg_type_path_string == "SingletonPtr" - || arg_type_path_string == "ptr::SingletonPtr" - || arg_type_path_string == "syrrete::ptr::SingletonPtr" - || arg_type_path_string == "FactoryPtr" - || arg_type_path_string == "ptr::FactoryPtr" - || arg_type_path_string == "syrrete::ptr::FactoryPtr" - } - - fn get_dependency_types( - item_impl: &ItemImpl, - ) -> Result, &'static str> - { - let new_method_impl_item = find_impl_method_by_name(item_impl, "new") + let new_method_impl_item = find_impl_method_by_name_mut(item_impl, "new") .map_or_else(|| Err("Missing a 'new' method"), Ok)?; - let new_method_args = &new_method_impl_item.sig.inputs; + let new_method_args = &mut new_method_impl_item.sig.inputs; + + let dependencies: Result, _> = + new_method_args.iter().map(Dependency::build).collect(); + + for arg in new_method_args { + let typed_arg = if let FnArg::Typed(typed_arg) = arg { + typed_arg + } else { + continue; + }; + + let attrs_to_remove: Vec<_> = typed_arg + .attrs + .iter() + .enumerate() + .filter_map(|(index, attr)| { + if syn_path_to_string(&attr.path).as_str() == "syrette::named" { + return Some(index); + } - if Self::get_has_fn_args_self(new_method_args) { - return Err("Unexpected self argument in 'new' method"); - } + if attr.path.get_ident()?.to_string().as_str() == "named" { + return Some(index); + } - let new_method_arg_type_paths = Self::get_fn_arg_type_paths(new_method_args); + None + }) + .collect(); - if new_method_arg_type_paths - .iter() - .any(|arg_type_path| !Self::is_type_path_ptr(arg_type_path)) - { - return Err("All argument types in 'new' method must ptr types"); + for attr_index in attrs_to_remove { + typed_arg.attrs.remove(attr_index); + } } - Ok(new_method_arg_type_paths - .iter() - .filter_map(|arg_type_path| DependencyType::from_type_path(arg_type_path)) - .collect()) + dependencies } } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 9b97be6..c7157c8 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -9,11 +9,12 @@ use quote::quote; use syn::{parse, parse_macro_input}; mod declare_interface_args; -mod dependency_type; +mod dependency; mod factory_type_alias; mod injectable_impl; mod injectable_macro_args; mod libs; +mod named_attr_input; mod util; use declare_interface_args::DeclareInterfaceArgs; @@ -38,15 +39,60 @@ use libs::intertrait_macros::gen_caster::generate_caster; /// declare the interface with the [`declare_interface!`] macro or use /// the [`di_container_bind`] macro to create a DI container binding. /// -/// # Example +/// # Attributes +/// Attributes specific to impls with this attribute macro. +/// +/// ### Named +/// Used inside the `new` method before a dependency argument. Declares the name of the +/// dependency. Should be given the name quoted inside parenthesis. +/// +/// The [`macro@named`] ghost attribute macro can be used for intellisense and +/// autocompletion for this attribute. +/// +/// For example: /// ``` -/// use syrette::injectable; +/// # use syrette::ptr::TransientPtr; +/// # use syrette::injectable; +/// # +/// # trait IArmor {} +/// # +/// # trait IKnight {} +/// # +/// # struct Knight +/// # { +/// # tough_armor: TransientPtr, +/// # light_armor: TransientPtr, +/// # } +/// # +/// #[injectable(IKnight)] +/// impl Knight +/// { +/// pub fn new( +/// #[named("tough")] +/// tough_armor: TransientPtr, /// -/// struct PasswordManager {} +/// #[named("light")] +/// light_armor: TransientPtr +/// ) -> Self +/// { +/// Self { tough_armor, light_armor } +/// } +/// } +/// # +/// # impl IKnight for Knight {} +/// ``` /// +/// # Example +/// ``` +/// # use syrette::injectable; +/// # +/// # struct PasswordManager {} +/// # /// #[injectable] -/// impl PasswordManager { -/// pub fn new() -> Self { +/// impl PasswordManager +/// { +/// pub fn new() -> Self +/// { /// Self {} /// } /// } @@ -191,3 +237,52 @@ pub fn declare_interface(input: TokenStream) -> TokenStream generate_caster(&implementation, &interface).into() } + +/// Declares the name of a dependency. +/// +/// This macro attribute doesn't actually do anything. It only exists for the +/// convenience of having intellisense and autocompletion. +/// You might as well just use `named` if you don't care about that. +/// +/// Only means something inside a `new` method inside a impl with +/// the [`macro@injectable`] macro attribute. +/// +/// # Examples +/// ``` +/// # use syrette::ptr::TransientPtr; +/// # use syrette::injectable; +/// # +/// # trait INinja {} +/// # trait IWeapon {} +/// # +/// # struct Ninja +/// # { +/// # strong_weapon: TransientPtr, +/// # weak_weapon: TransientPtr, +/// # } +/// # +/// #[injectable(INinja)] +/// impl Ninja +/// { +/// pub fn new( +/// #[syrette::named("strong")] +/// strong_weapon: TransientPtr, +/// +/// #[syrette::named("weak")] +/// weak_weapon: TransientPtr, +/// ) -> Self +/// { +/// Self { +/// strong_weapon, +/// weak_weapon, +/// } +/// } +/// } +/// # +/// # impl INinja for Ninja {} +/// ``` +#[proc_macro_attribute] +pub fn named(_: TokenStream, _: TokenStream) -> TokenStream +{ + TokenStream::new() +} diff --git a/macros/src/named_attr_input.rs b/macros/src/named_attr_input.rs new file mode 100644 index 0000000..5f7123c --- /dev/null +++ b/macros/src/named_attr_input.rs @@ -0,0 +1,21 @@ +use syn::parse::Parse; +use syn::{parenthesized, LitStr}; + +pub struct NamedAttrInput +{ + pub name: LitStr, +} + +impl Parse for NamedAttrInput +{ + fn parse(input: syn::parse::ParseStream) -> syn::Result + { + let content; + + parenthesized!(content in input); + + Ok(Self { + name: content.parse()?, + }) + } +} diff --git a/macros/src/util/item_impl.rs b/macros/src/util/item_impl.rs index 271ae2f..4bd7492 100644 --- a/macros/src/util/item_impl.rs +++ b/macros/src/util/item_impl.rs @@ -1,13 +1,13 @@ use syn::{ImplItem, ImplItemMethod, ItemImpl}; -pub fn find_impl_method_by_name<'item_impl>( - item_impl: &'item_impl ItemImpl, +pub fn find_impl_method_by_name_mut<'item_impl>( + item_impl: &'item_impl mut ItemImpl, method_name: &'static str, -) -> Option<&'item_impl ImplItemMethod> +) -> Option<&'item_impl mut ImplItemMethod> { - let impl_items = &item_impl.items; + let impl_items = &mut item_impl.items; - impl_items.iter().find_map(|impl_item| match impl_item { + impl_items.iter_mut().find_map(|impl_item| match impl_item { ImplItem::Method(method_item) => { if method_item.sig.ident == method_name { Some(method_item) @@ -15,6 +15,6 @@ pub fn find_impl_method_by_name<'item_impl>( None } } - &_ => None, + &mut _ => None, }) } diff --git a/macros/src/util/mod.rs b/macros/src/util/mod.rs index fc7b2c6..4f2a594 100644 --- a/macros/src/util/mod.rs +++ b/macros/src/util/mod.rs @@ -1,2 +1,3 @@ pub mod item_impl; pub mod iterator_ext; +pub mod syn_path; diff --git a/macros/src/util/syn_path.rs b/macros/src/util/syn_path.rs new file mode 100644 index 0000000..15653bf --- /dev/null +++ b/macros/src/util/syn_path.rs @@ -0,0 +1,22 @@ +#![allow(clippy::module_name_repetitions)] +use quote::ToTokens; +use syn::punctuated::Pair; + +pub fn syn_path_to_string(path: &syn::Path) -> String +{ + path.segments + .pairs() + .map(Pair::into_tuple) + .map(|(segment, opt_punct)| { + let segment_ident = &segment.ident; + + format!( + "{}{}", + segment_ident, + opt_punct.map_or_else(String::new, |punct| punct + .to_token_stream() + .to_string()) + ) + }) + .collect() +} diff --git a/src/di_container.rs b/src/di_container.rs index 85b0e7a..9d54261 100644 --- a/src/di_container.rs +++ b/src/di_container.rs @@ -54,13 +54,66 @@ use std::marker::PhantomData; use crate::castable_factory::CastableFactory; use crate::di_container_binding_map::DIContainerBindingMap; use crate::errors::di_container::{ - BindingBuilderError, BindingScopeConfiguratorError, DIContainerError, + BindingBuilderError, BindingScopeConfiguratorError, BindingWhenConfiguratorError, + DIContainerError, }; use crate::interfaces::injectable::Injectable; use crate::libs::intertrait::cast::{CastBox, CastRc}; use crate::provider::{Providable, SingletonProvider, TransientTypeProvider}; use crate::ptr::{SingletonPtr, SomePtr}; +/// When configurator for a binding for type 'Interface' inside a [`DIContainer`]. +pub struct BindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + di_container: &'di_container mut DIContainer, + interface_phantom: PhantomData, +} + +impl<'di_container, Interface> BindingWhenConfigurator<'di_container, Interface> +where + Interface: 'static + ?Sized, +{ + fn new(di_container: &'di_container mut DIContainer) -> Self + { + Self { + di_container, + interface_phantom: PhantomData, + } + } + + /// Configures the binding to have a name. + /// + /// # Errors + /// Will return Err if no binding for the interface already exists. + pub fn when_named( + &mut self, + name: &'static str, + ) -> Result<(), BindingWhenConfiguratorError> + { + let binding = self + .di_container + .bindings + .remove::(None) + .map_or_else( + || { + Err(BindingWhenConfiguratorError::BindingNotFound(type_name::< + Interface, + >( + ))) + }, + Ok, + )?; + + self.di_container + .bindings + .set::(Some(name), binding); + + Ok(()) + } +} + /// Scope configurator for a binding for type 'Interface' inside a [`DIContainer`]. pub struct BindingScopeConfigurator<'di_container, Interface, Implementation> where @@ -90,18 +143,23 @@ where /// Configures the binding to be in a transient scope. /// /// This is the default. - pub fn in_transient_scope(&mut self) + pub fn in_transient_scope(&mut self) -> BindingWhenConfigurator { - self.di_container - .bindings - .set::(Box::new(TransientTypeProvider::::new())); + self.di_container.bindings.set::( + None, + Box::new(TransientTypeProvider::::new()), + ); + + BindingWhenConfigurator::new(self.di_container) } /// Configures the binding to be in a singleton scope. /// /// # Errors /// Will return Err if resolving the implementation fails. - pub fn in_singleton_scope(&mut self) -> Result<(), BindingScopeConfiguratorError> + pub fn in_singleton_scope( + &mut self, + ) -> Result, BindingScopeConfiguratorError> { let singleton: SingletonPtr = SingletonPtr::from( Implementation::resolve(self.di_container, Vec::new()) @@ -110,9 +168,9 @@ where self.di_container .bindings - .set::(Box::new(SingletonProvider::new(singleton))); + .set::(None, Box::new(SingletonProvider::new(singleton))); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } } @@ -152,7 +210,7 @@ where where Implementation: Injectable, { - if self.di_container.bindings.has::() { + if self.di_container.bindings.has::(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -178,13 +236,13 @@ where pub fn to_factory( &mut self, factory_func: &'static dyn Fn>, - ) -> Result<(), BindingBuilderError> + ) -> Result, BindingBuilderError> where Args: 'static, Return: 'static + ?Sized, Interface: crate::interfaces::factory::IFactory, { - if self.di_container.bindings.has::() { + if self.di_container.bindings.has::(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -192,13 +250,14 @@ where let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.set::(Box::new( - crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( - factory_impl, + self.di_container.bindings.set::( + None, + Box::new(crate::provider::FactoryProvider::new( + crate::ptr::FactoryPtr::new(factory_impl), )), - )); + ); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } /// Creates a binding of type `Interface` to a factory that takes no arguments @@ -213,11 +272,11 @@ where pub fn to_default_factory( &mut self, factory_func: &'static dyn Fn<(), Output = crate::ptr::TransientPtr>, - ) -> Result<(), BindingBuilderError> + ) -> Result, BindingBuilderError> where Return: 'static + ?Sized, { - if self.di_container.bindings.has::() { + if self.di_container.bindings.has::(None) { return Err(BindingBuilderError::BindingAlreadyExists(type_name::< Interface, >())); @@ -225,13 +284,14 @@ where let factory_impl = CastableFactory::new(factory_func); - self.di_container.bindings.set::(Box::new( - crate::provider::FactoryProvider::new(crate::ptr::FactoryPtr::new( - factory_impl, + self.di_container.bindings.set::( + None, + Box::new(crate::provider::FactoryProvider::new( + crate::ptr::FactoryPtr::new(factory_impl), )), - )); + ); - Ok(()) + Ok(BindingWhenConfigurator::new(self.di_container)) } } @@ -265,26 +325,53 @@ impl DIContainer /// # Errors /// Will return `Err` if: /// - No binding for `Interface` exists - /// - Resolving the binding for `Interface` fails - /// - Casting the binding for `Interface` fails + /// - Resolving the binding for fails + /// - Casting the binding for fails pub fn get(&self) -> Result, DIContainerError> where Interface: 'static + ?Sized, { - self.get_bound::(Vec::new()) + self.get_bound::(Vec::new(), None) + } + + /// Returns the type bound with `Interface` and the specified name. + /// + /// # Errors + /// Will return `Err` if: + /// - No binding for `Interface` with name `name` exists + /// - Resolving the binding fails + /// - Casting the binding for fails + pub fn get_named( + &self, + name: &'static str, + ) -> Result, DIContainerError> + where + Interface: 'static + ?Sized, + { + self.get_bound::(Vec::new(), Some(name)) } #[doc(hidden)] pub fn get_bound( &self, dependency_history: Vec<&'static str>, + name: Option<&'static str>, ) -> Result, DIContainerError> where Interface: 'static + ?Sized, { let binding_providable = - self.get_binding_providable::(dependency_history)?; + self.get_binding_providable::(name, dependency_history)?; + + Self::handle_binding_providable(binding_providable) + } + fn handle_binding_providable( + binding_providable: Providable, + ) -> Result, DIContainerError> + where + Interface: 'static + ?Sized, + { match binding_providable { Providable::Transient(transient_binding) => Ok(SomePtr::Transient( transient_binding.cast::().map_err(|_| { @@ -318,13 +405,14 @@ impl DIContainer fn get_binding_providable( &self, + name: Option<&'static str>, dependency_history: Vec<&'static str>, ) -> Result where Interface: 'static + ?Sized, { self.bindings - .get::()? + .get::(name)? .provide(self, dependency_history) .map_err(|err| DIContainerError::BindingResolveFailed { reason: err, @@ -493,6 +581,41 @@ mod tests Ok(()) } + #[test] + fn can_bind_to_transient() -> Result<(), Box> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to::()? + .in_transient_scope(); + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + + #[test] + fn can_bind_to_transient_when_named() -> Result<(), Box> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to::()? + .in_transient_scope() + .when_named("regular")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + #[test] fn can_bind_to_singleton() -> Result<(), Box> { @@ -510,6 +633,24 @@ mod tests Ok(()) } + #[test] + fn can_bind_to_singleton_when_named() -> Result<(), Box> + { + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to::()? + .in_singleton_scope()? + .when_named("cool")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + #[test] #[cfg(feature = "factory")] fn can_bind_to_factory() -> Result<(), Box> @@ -533,6 +674,32 @@ mod tests Ok(()) } + #[test] + #[cfg(feature = "factory")] + fn can_bind_to_factory_when_named() -> Result<(), Box> + { + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(), dyn subjects::IUserManager>; + + let mut di_container: DIContainer = DIContainer::new(); + + assert_eq!(di_container.bindings.count(), 0); + + di_container + .bind::() + .to_factory(&|| { + let user_manager: TransientPtr = + TransientPtr::new(subjects::UserManager::new()); + + user_manager + })? + .when_named("awesome")?; + + assert_eq!(di_container.bindings.count(), 1); + + Ok(()) + } + #[test] fn can_get() -> Result<(), Box> { @@ -561,9 +728,48 @@ mod tests di_container .bindings - .set::(Box::new(mock_provider)); + .set::(None, Box::new(mock_provider)); + + di_container + .get::()? + .transient()?; + + Ok(()) + } + + #[test] + fn can_get_named() -> Result<(), Box> + { + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(Providable::Transient(TransientPtr::new( + subjects::UserManager::new(), + ))) + }); + + di_container + .bindings + .set::(Some("special"), Box::new(mock_provider)); - di_container.get::()?; + di_container + .get_named::("special")? + .transient()?; Ok(()) } @@ -598,7 +804,7 @@ mod tests di_container .bindings - .set::(Box::new(mock_provider)); + .set::(None, Box::new(mock_provider)); let first_number_rc = di_container.get::()?.singleton()?; @@ -612,6 +818,53 @@ mod tests Ok(()) } + #[test] + fn can_get_singleton_named() -> Result<(), Box> + { + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + let mut singleton = SingletonPtr::new(subjects::Number::new()); + + SingletonPtr::get_mut(&mut singleton).unwrap().num = 2820; + + mock_provider + .expect_provide() + .returning_st(move |_, _| Ok(Providable::Singleton(singleton.clone()))); + + di_container + .bindings + .set::(Some("cool"), Box::new(mock_provider)); + + let first_number_rc = di_container + .get_named::("cool")? + .singleton()?; + + assert_eq!(first_number_rc.get(), 2820); + + let second_number_rc = di_container + .get_named::("cool")? + .singleton()?; + + assert_eq!(first_number_rc.as_ref(), second_number_rc.as_ref()); + + Ok(()) + } + #[test] #[cfg(feature = "factory")] fn can_get_factory() -> Result<(), Box> @@ -688,9 +941,94 @@ mod tests di_container .bindings - .set::(Box::new(mock_provider)); + .set::(None, Box::new(mock_provider)); + + di_container.get::()?.factory()?; - di_container.get::()?; + Ok(()) + } + + #[test] + #[cfg(feature = "factory")] + fn can_get_factory_named() -> Result<(), Box> + { + trait IUserManager + { + fn add_user(&mut self, user_id: i128); + + fn remove_user(&mut self, user_id: i128); + } + + struct UserManager + { + users: Vec, + } + + impl UserManager + { + fn new(users: Vec) -> Self + { + Self { users } + } + } + + impl IUserManager for UserManager + { + fn add_user(&mut self, user_id: i128) + { + self.users.push(user_id); + } + + fn remove_user(&mut self, user_id: i128) + { + let user_index = + self.users.iter().position(|user| *user == user_id).unwrap(); + + self.users.remove(user_index); + } + } + + use crate as syrette; + + #[crate::factory] + type IUserManagerFactory = + dyn crate::interfaces::factory::IFactory<(Vec,), dyn IUserManager>; + + mock! { + Provider {} + + impl IProvider for Provider + { + fn provide( + &self, + di_container: &DIContainer, + dependency_history: Vec<&'static str>, + ) -> Result; + } + } + + let mut di_container: DIContainer = DIContainer::new(); + + let mut mock_provider = MockProvider::new(); + + mock_provider.expect_provide().returning(|_, _| { + Ok(Providable::Factory(crate::ptr::FactoryPtr::new( + CastableFactory::new(&|users| { + let user_manager: TransientPtr = + TransientPtr::new(UserManager::new(users)); + + user_manager + }), + ))) + }); + + di_container + .bindings + .set::(Some("special"), Box::new(mock_provider)); + + di_container + .get_named::("special")? + .factory()?; Ok(()) } diff --git a/src/di_container_binding_map.rs b/src/di_container_binding_map.rs index 20d040f..d4b46f2 100644 --- a/src/di_container_binding_map.rs +++ b/src/di_container_binding_map.rs @@ -4,9 +4,16 @@ use ahash::AHashMap; use crate::{errors::di_container::DIContainerError, provider::IProvider}; +#[derive(Debug, PartialEq, Eq, Hash)] +struct DIContainerBindingKey +{ + type_id: TypeId, + name: Option<&'static str>, +} + pub struct DIContainerBindingMap { - bindings: AHashMap>, + bindings: AHashMap>, } impl DIContainerBindingMap @@ -18,7 +25,10 @@ impl DIContainerBindingMap } } - pub fn get(&self) -> Result<&dyn IProvider, DIContainerError> + pub fn get( + &self, + name: Option<&'static str>, + ) -> Result<&dyn IProvider, DIContainerError> where Interface: 'static + ?Sized, { @@ -26,27 +36,60 @@ impl DIContainerBindingMap Ok(self .bindings - .get(&interface_typeid) - .ok_or_else(|| DIContainerError::BindingNotFound(type_name::()))? + .get(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) + .ok_or_else(|| DIContainerError::BindingNotFound { + interface: type_name::(), + name, + })? .as_ref()) } - pub fn set(&mut self, provider: Box) + pub fn set( + &mut self, + name: Option<&'static str>, + provider: Box, + ) where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::(); + + self.bindings.insert( + DIContainerBindingKey { + type_id: interface_typeid, + name, + }, + provider, + ); + } + + pub fn remove( + &mut self, + name: Option<&'static str>, + ) -> Option> where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::(); - self.bindings.insert(interface_typeid, provider); + self.bindings.remove(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } - pub fn has(&self) -> bool + pub fn has(&self, name: Option<&'static str>) -> bool where Interface: 'static + ?Sized, { let interface_typeid = TypeId::of::(); - self.bindings.contains_key(&interface_typeid) + self.bindings.contains_key(&DIContainerBindingKey { + type_id: interface_typeid, + name, + }) } /// Only used by tests in the `di_container` module. diff --git a/src/errors/di_container.rs b/src/errors/di_container.rs index 65cd9d1..82a3d55 100644 --- a/src/errors/di_container.rs +++ b/src/errors/di_container.rs @@ -26,9 +26,19 @@ pub enum DIContainerError interface: &'static str, }, - /// No binding exists for a interface. - #[error("No binding exists for interface '{0}'")] - BindingNotFound(&'static str), + /// No binding exists for a interface (and optionally a name). + #[error( + "No binding exists for interface '{interface}' {}", + .name.map_or_else(String::new, |name| format!("with name '{}'", name)) + )] + BindingNotFound + { + /// The interface that doesn't have a binding. + interface: &'static str, + + /// The name of the binding if one exists. + name: Option<&'static str>, + }, } /// Error type for [`BindingBuilder`]. @@ -52,3 +62,14 @@ pub enum BindingScopeConfiguratorError #[error("Resolving the given singleton failed")] SingletonResolveFailed(#[from] InjectableError), } + +/// Error type for [`BindingWhenConfigurator`]. +/// +/// [`BindingWhenConfigurator`]: crate::di_container::BindingWhenConfigurator +#[derive(thiserror::Error, Debug)] +pub enum BindingWhenConfiguratorError +{ + /// A binding for a interface wasn't found. + #[error("A binding for interface '{0}' wasn't found'")] + BindingNotFound(&'static str), +} -- cgit v1.2.3-18-g5258