diff options
author | HampusM <hampus@hampusmat.com> | 2022-07-16 12:02:54 +0200 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2022-07-16 12:02:54 +0200 |
commit | 05be92b334af1beab3e7a3f2ee7626eb26c47e22 (patch) | |
tree | 43883a89985b29721961f2001a88db9985bd3485 | |
parent | 5129384fc0b6f51d315fd528d7769dd638018b88 (diff) |
feat: add binding factories to DI container
-rw-r--r-- | example/src/main.rs | 43 | ||||
-rw-r--r-- | syrette/src/castable_factory.rs | 72 | ||||
-rw-r--r-- | syrette/src/di_container.rs | 132 | ||||
-rw-r--r-- | syrette/src/interfaces/factory.rs | 7 | ||||
-rw-r--r-- | syrette/src/interfaces/mod.rs | 1 | ||||
-rw-r--r-- | syrette/src/lib.rs | 3 | ||||
-rw-r--r-- | syrette/src/libs/intertrait/cast_rc.rs | 34 | ||||
-rw-r--r-- | syrette/src/libs/intertrait/mod.rs | 1 | ||||
-rw-r--r-- | syrette/src/provider.rs | 48 | ||||
-rw-r--r-- | syrette_macros/src/lib.rs | 149 |
10 files changed, 441 insertions, 49 deletions
diff --git a/example/src/main.rs b/example/src/main.rs index 1f23179..28f29b4 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -1,4 +1,8 @@ +use std::rc::Rc; + use syrette::errors::di_container::DIContainerError; +use syrette::factory; +use syrette::interfaces::factory::IFactory; use syrette::{injectable, DIContainer}; trait IDog @@ -54,6 +58,32 @@ trait ICow fn moo(&self); } +struct Cow +{ + _moo_cnt: i32, +} + +impl Cow +{ + fn new(moo_cnt: i32) -> Self + { + Self { _moo_cnt: moo_cnt } + } +} + +impl ICow for Cow +{ + fn moo(&self) + { + for _ in 0..self._moo_cnt { + println!("Moo"); + } + } +} + +#[factory] +type CowFactory = dyn IFactory<(i32,), dyn ICow>; + trait IHuman { fn make_pets_make_sounds(&self); @@ -63,16 +93,18 @@ struct Human { _dog: Box<dyn IDog>, _cat: Box<dyn ICat>, + _cow_factory: Rc<CowFactory>, } #[injectable(IHuman)] impl Human { - fn new(dog: Box<dyn IDog>, cat: Box<dyn ICat>) -> Self + fn new(dog: Box<dyn IDog>, cat: Box<dyn ICat>, cow_factory: Rc<CowFactory>) -> Self { Self { _dog: dog, _cat: cat, + _cow_factory: cow_factory, } } } @@ -88,6 +120,10 @@ impl IHuman for Human println!("Hi kitty!"); self._cat.meow(); + + let cow: Box<dyn ICow> = (self._cow_factory)(3); + + cow.moo(); } } @@ -101,6 +137,11 @@ fn main() -> error_stack::Result<(), DIContainerError> di_container.bind::<dyn ICat>().to::<Cat>(); di_container.bind::<dyn IHuman>().to::<Human>(); + di_container.bind::<CowFactory>().to_factory(&|moo_cnt| { + let cow: Box<dyn ICow> = Box::new(Cow::new(moo_cnt)); + cow + }); + let dog = di_container.get::<dyn IDog>()?; dog.woof(); diff --git a/syrette/src/castable_factory.rs b/syrette/src/castable_factory.rs new file mode 100644 index 0000000..8713ec4 --- /dev/null +++ b/syrette/src/castable_factory.rs @@ -0,0 +1,72 @@ +use crate::interfaces::factory::IFactory; +use crate::libs::intertrait::CastFrom; + +pub trait AnyFactory: CastFrom {} + +pub struct CastableFactory<Args, Return> +where + Args: 'static, + Return: 'static + ?Sized, +{ + _func: &'static dyn Fn<Args, Output = Box<Return>>, +} + +impl<Args, Return> CastableFactory<Args, Return> +where + Args: 'static, + Return: 'static + ?Sized, +{ + pub fn new(func: &'static dyn Fn<Args, Output = Box<Return>>) -> Self + { + Self { _func: func } + } +} + +impl<Args, Return> IFactory<Args, Return> for CastableFactory<Args, Return> +where + Args: 'static, + Return: 'static + ?Sized, +{ +} + +impl<Args, Return> Fn<Args> for CastableFactory<Args, Return> +where + Args: 'static, + Return: 'static + ?Sized, +{ + extern "rust-call" fn call(&self, args: Args) -> Self::Output + { + self._func.call(args) + } +} + +impl<Args, Return> FnMut<Args> for CastableFactory<Args, Return> +where + Args: 'static, + Return: 'static + ?Sized, +{ + extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output + { + self.call(args) + } +} + +impl<Args, Return> FnOnce<Args> for CastableFactory<Args, Return> +where + Args: 'static, + Return: 'static + ?Sized, +{ + type Output = Box<Return>; + + extern "rust-call" fn call_once(self, args: Args) -> Self::Output + { + self.call(args) + } +} + +impl<Args, Return> AnyFactory for CastableFactory<Args, Return> +where + Args: 'static, + Return: 'static + ?Sized, +{ +} diff --git a/syrette/src/di_container.rs b/syrette/src/di_container.rs index 32d53f2..53c4287 100644 --- a/syrette/src/di_container.rs +++ b/syrette/src/di_container.rs @@ -5,23 +5,26 @@ use std::rc::Rc; use error_stack::{Report, ResultExt}; +use crate::castable_factory::CastableFactory; use crate::errors::di_container::DIContainerError; +use crate::interfaces::factory::IFactory; use crate::interfaces::injectable::Injectable; use crate::libs::intertrait::cast_box::CastBox; -use crate::provider::{IInjectableTypeProvider, InjectableTypeProvider}; +use crate::libs::intertrait::cast_rc::CastRc; +use crate::provider::{FactoryProvider, IProvider, InjectableTypeProvider, Providable}; -/// Binding builder for `InterfaceTrait` in a [`DIContainer`]. -pub struct BindingBuilder<'a, InterfaceTrait> +/// Binding builder for type `Interface` inside a [`DIContainer`]. +pub struct BindingBuilder<'a, Interface> where - InterfaceTrait: 'static + ?Sized, + Interface: 'static + ?Sized, { _di_container: &'a mut DIContainer, - _phantom_data: PhantomData<InterfaceTrait>, + _phantom_data: PhantomData<Interface>, } -impl<'a, InterfaceTrait> BindingBuilder<'a, InterfaceTrait> +impl<'a, Interface> BindingBuilder<'a, Interface> where - InterfaceTrait: 'static + ?Sized, + Interface: 'static + ?Sized, { fn new(di_container: &'a mut DIContainer) -> Self { @@ -31,19 +34,39 @@ where } } - /// Creates a binding of `InterfaceTrait` to type `Implementation` inside of the + /// Creates a binding of type `Interface` to type `Implementation` inside of the /// associated [`DIContainer`]. pub fn to<Implementation>(&mut self) where Implementation: Injectable, { - let interface_typeid = TypeId::of::<InterfaceTrait>(); + let interface_typeid = TypeId::of::<Interface>(); self._di_container._bindings.insert( interface_typeid, Rc::new(InjectableTypeProvider::<Implementation>::new()), ); } + + /// Creates a binding of factory type `Interface` to a factory inside of the + /// associated [`DIContainer`]. + pub fn to_factory<Args, Return>( + &mut self, + factory_func: &'static dyn Fn<Args, Output = Box<Return>>, + ) where + Args: 'static, + Return: 'static + ?Sized, + Interface: IFactory<Args, Return>, + { + let interface_typeid = TypeId::of::<Interface>(); + + let factory_impl = CastableFactory::new(factory_func); + + self._di_container._bindings.insert( + interface_typeid, + Rc::new(FactoryProvider::new(Rc::new(factory_impl))), + ); + } } /// Dependency injection container. @@ -56,7 +79,7 @@ where /// ``` pub struct DIContainer { - _bindings: HashMap<TypeId, Rc<dyn IInjectableTypeProvider>>, + _bindings: HashMap<TypeId, Rc<dyn IProvider>>, } impl<'a> DIContainer @@ -69,46 +92,95 @@ impl<'a> DIContainer } } - /// Returns a new [`BindingBuilder`] for the given interface trait. - pub fn bind<InterfaceTrait>(&'a mut self) -> BindingBuilder<InterfaceTrait> + /// Returns a new [`BindingBuilder`] for the given interface. + pub fn bind<Interface>(&'a mut self) -> BindingBuilder<Interface> where - InterfaceTrait: 'static + ?Sized, + Interface: 'static + ?Sized, { - BindingBuilder::<InterfaceTrait>::new(self) + BindingBuilder::<Interface>::new(self) } - /// Returns the value bound with `InterfaceTrait`. - pub fn get<InterfaceTrait>( - &self, - ) -> error_stack::Result<Box<InterfaceTrait>, DIContainerError> + /// Returns a new instance of the type bound with `Interface`. + pub fn get<Interface>(&self) -> error_stack::Result<Box<Interface>, DIContainerError> where - InterfaceTrait: 'static + ?Sized, + Interface: 'static + ?Sized, { - let interface_typeid = TypeId::of::<InterfaceTrait>(); + let interface_typeid = TypeId::of::<Interface>(); - let interface_name = type_name::<InterfaceTrait>(); + let interface_name = type_name::<Interface>(); let binding = self._bindings.get(&interface_typeid).ok_or_else(|| { Report::new(DIContainerError) .attach_printable(format!("No binding exists for {}", interface_name)) })?; - let binding_injectable = binding + let binding_providable = binding .provide(self) .change_context(DIContainerError) .attach_printable(format!( - "Failed to resolve interface {}", + "Failed to resolve binding for interface {}", interface_name ))?; - let interface_box_result = binding_injectable.cast::<InterfaceTrait>(); + match binding_providable { + Providable::Injectable(binding_injectable) => { + let interface_box_result = binding_injectable.cast::<Interface>(); + + match interface_box_result { + Ok(interface_box) => Ok(interface_box), + Err(_) => Err(Report::new(DIContainerError).attach_printable( + format!("Unable to cast binding for {}", interface_name), + )), + } + } + Providable::Factory(_) => Err(Report::new(DIContainerError) + .attach_printable(format!( + "Binding for {} is not injectable", + interface_name + ))), + } + } + + /// Returns the factory bound with factory type `Interface`. + pub fn get_factory<Interface>( + &self, + ) -> error_stack::Result<Rc<Interface>, DIContainerError> + where + Interface: 'static + ?Sized, + { + let interface_typeid = TypeId::of::<Interface>(); + + let interface_name = type_name::<Interface>(); + + let binding = self._bindings.get(&interface_typeid).ok_or_else(|| { + Report::new(DIContainerError) + .attach_printable(format!("No binding exists for {}", interface_name)) + })?; + + let binding_providable = binding + .provide(self) + .change_context(DIContainerError) + .attach_printable(format!( + "Failed to resolve binding for interface {}", + interface_name + ))?; - match interface_box_result { - Ok(interface_box) => Ok(interface_box), - Err(_) => Err(Report::new(DIContainerError).attach_printable(format!( - "Unable to cast binding for {}", - interface_name - ))), + match binding_providable { + Providable::Factory(binding_factory) => { + let factory_box_result = binding_factory.cast::<Interface>(); + + match factory_box_result { + Ok(interface_box) => Ok(interface_box), + Err(_) => Err(Report::new(DIContainerError).attach_printable( + format!("Unable to cast binding for {}", interface_name), + )), + } + } + Providable::Injectable(_) => Err(Report::new(DIContainerError) + .attach_printable(format!( + "Binding for {} is not a factory", + interface_name + ))), } } } diff --git a/syrette/src/interfaces/factory.rs b/syrette/src/interfaces/factory.rs new file mode 100644 index 0000000..ed03cce --- /dev/null +++ b/syrette/src/interfaces/factory.rs @@ -0,0 +1,7 @@ +use crate::libs::intertrait::CastFrom; + +pub trait IFactory<Args, Return>: Fn<Args, Output = Box<Return>> + CastFrom +where + Return: 'static + ?Sized, +{ +} diff --git a/syrette/src/interfaces/mod.rs b/syrette/src/interfaces/mod.rs index 31e53af..921bb9c 100644 --- a/syrette/src/interfaces/mod.rs +++ b/syrette/src/interfaces/mod.rs @@ -1 +1,2 @@ +pub mod factory; pub mod injectable; diff --git a/syrette/src/lib.rs b/syrette/src/lib.rs index 7278c37..945c0c0 100644 --- a/syrette/src/lib.rs +++ b/syrette/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(unboxed_closures, fn_traits)] + //! Syrette //! //! Syrette is a collection of utilities useful for performing dependency injection. @@ -115,6 +117,7 @@ //! //! ``` +pub mod castable_factory; pub mod di_container; pub mod errors; pub mod interfaces; diff --git a/syrette/src/libs/intertrait/cast_rc.rs b/syrette/src/libs/intertrait/cast_rc.rs new file mode 100644 index 0000000..58d212a --- /dev/null +++ b/syrette/src/libs/intertrait/cast_rc.rs @@ -0,0 +1,34 @@ +/** + * Originally from Intertrait by CodeChain + * + * https://github.com/CodeChain-io/intertrait + * https://crates.io/crates/intertrait/0.2.2 + * + * Licensed under either of + * + * Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) + * MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT) + + * at your option. +*/ +use std::rc::Rc; + +use crate::libs::intertrait::{caster, CastFrom}; + +pub trait CastRc +{ + /// Casts an `Rc` for this trait into that for type `T`. + fn cast<T: ?Sized + 'static>(self: Rc<Self>) -> Result<Rc<T>, Rc<Self>>; +} + +/// A blanket implementation of `CastRc` for traits extending `CastFrom`. +impl<S: ?Sized + CastFrom> CastRc for S +{ + fn cast<T: ?Sized + 'static>(self: Rc<Self>) -> Result<Rc<T>, Rc<Self>> + { + match caster::<T>((*self).type_id()) { + Some(caster) => Ok((caster.cast_rc)(self.rc_any())), + None => Err(self), + } + } +} diff --git a/syrette/src/libs/intertrait/mod.rs b/syrette/src/libs/intertrait/mod.rs index b07d91e..e7b3bdd 100644 --- a/syrette/src/libs/intertrait/mod.rs +++ b/syrette/src/libs/intertrait/mod.rs @@ -27,6 +27,7 @@ mod hasher; use hasher::BuildFastHasher; pub mod cast_box; +pub mod cast_rc; pub type BoxedCaster = Box<dyn Any + Send + Sync>; diff --git a/syrette/src/provider.rs b/syrette/src/provider.rs index 0d6a1cc..800315f 100644 --- a/syrette/src/provider.rs +++ b/syrette/src/provider.rs @@ -1,17 +1,25 @@ use std::marker::PhantomData; +use std::rc::Rc; -extern crate error_stack; - +use crate::castable_factory::AnyFactory; use crate::errors::injectable::ResolveError; use crate::interfaces::injectable::Injectable; use crate::DIContainer; -pub trait IInjectableTypeProvider +extern crate error_stack; + +pub enum Providable +{ + Injectable(Box<dyn Injectable>), + Factory(Rc<dyn AnyFactory>), +} + +pub trait IProvider { fn provide( &self, di_container: &DIContainer, - ) -> error_stack::Result<Box<dyn Injectable>, ResolveError>; + ) -> error_stack::Result<Providable, ResolveError>; } pub struct InjectableTypeProvider<InjectableType> @@ -33,15 +41,41 @@ where } } -impl<InjectableType> IInjectableTypeProvider for InjectableTypeProvider<InjectableType> +impl<InjectableType> IProvider for InjectableTypeProvider<InjectableType> where InjectableType: Injectable, { fn provide( &self, di_container: &DIContainer, - ) -> error_stack::Result<Box<dyn Injectable>, ResolveError> + ) -> error_stack::Result<Providable, ResolveError> + { + Ok(Providable::Injectable(InjectableType::resolve( + di_container, + )?)) + } +} + +pub struct FactoryProvider +{ + _factory: Rc<dyn AnyFactory>, +} + +impl FactoryProvider +{ + pub fn new(factory: Rc<dyn AnyFactory>) -> Self + { + Self { _factory: factory } + } +} + +impl IProvider for FactoryProvider +{ + fn provide( + &self, + _di_container: &DIContainer, + ) -> error_stack::Result<Providable, ResolveError> { - Ok(InjectableType::resolve(di_container)?) + Ok(Providable::Factory(self._factory.clone())) } } diff --git a/syrette_macros/src/lib.rs b/syrette_macros/src/lib.rs index 1761534..0302c07 100644 --- a/syrette_macros/src/lib.rs +++ b/syrette_macros/src/lib.rs @@ -2,8 +2,8 @@ use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse, parse_macro_input, parse_str, punctuated::Punctuated, token::Comma, - AttributeArgs, FnArg, GenericArgument, ImplItem, ItemImpl, Meta, NestedMeta, Path, - PathArguments, Type, TypePath, + AttributeArgs, ExprMethodCall, FnArg, GenericArgument, ImplItem, ItemImpl, ItemType, + Meta, NestedMeta, Path, PathArguments, Type, TypeParamBound, TypePath, }; mod libs; @@ -27,8 +27,16 @@ const IMPL_NO_NEW_METHOD_ERR_MESSAGE: &str = const IMPL_NEW_METHOD_SELF_PARAM_ERR_MESSAGE: &str = "The new method of the attached to trait implementation cannot have a self parameter"; -const IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE: &str = - "All parameters of the new method of the attached to trait implementation must be std::boxed::Box"; +const IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE: &str = concat!( + "All parameters of the new method of the attached to trait implementation ", + "must be either std::boxed::Box or std::rc:Rc (for factories)" +); + +const INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE: &str = + "Invalid aliased trait. Must be 'dyn IFactory'"; + +const INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE: &str = + "Invalid arguments for 'dyn IFactory'"; fn path_to_string(path: &Path) -> String { @@ -67,6 +75,10 @@ fn get_fn_arg_type_paths(fn_args: &Punctuated<FnArg, Comma>) -> Vec<TypePath> match arg { FnArg::Typed(typed_fn_arg) => match typed_fn_arg.ty.as_ref() { Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), + Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { + Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), + &_ => {} + }, &_ => {} }, FnArg::Receiver(_receiver_fn_arg) => {} @@ -109,8 +121,11 @@ fn get_dependency_types(item_impl: &ItemImpl) -> Vec<Type> if arg_type_path_string != "Box" && arg_type_path_string != "std::boxed::Box" && arg_type_path_string != "boxed::Box" + && arg_type_path_string != "Rc" + && arg_type_path_string != "std::rc::Rc" + && arg_type_path_string != "rc::Rc" { - panic!("{}", IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE); + panic!("{}", IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE); } // Assume the type path has a last segment. @@ -202,6 +217,49 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt let dependency_types = get_dependency_types(&item_impl); + let get_dependencies = dependency_types.iter().fold( + Vec::<ExprMethodCall>::new(), + |mut acc, dep_type| { + match dep_type { + Type::TraitObject(dep_type_trait) => { + acc.push( + parse_str( + format!( + "di_container.get::<{}>()", + dep_type_trait.to_token_stream() + ) + .as_str(), + ) + .unwrap(), + ); + } + Type::Path(dep_type_path) => { + let dep_type_path_str = path_to_string(&dep_type_path.path); + + let get_method_name = if dep_type_path_str.ends_with("Factory") { + "get_factory" + } else { + "get" + }; + + acc.push( + parse_str( + format!( + "di_container.{}::<{}>()", + get_method_name, dep_type_path_str + ) + .as_str(), + ) + .unwrap(), + ); + } + &_ => {} + } + + acc + }, + ); + quote! { #item_impl @@ -213,13 +271,15 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt use error_stack::ResultExt; return Ok(Box::new(Self::new( - #(di_container.get::<#dependency_types>() + #(#get_dependencies .change_context(syrette::errors::injectable::ResolveError) - .attach_printable(format!( - "Unable to resolve a dependency of {}", - std::any::type_name::<#self_type_path>() - ))?, - )* + .attach_printable( + format!( + "Unable to resolve a dependency of {}", + std::any::type_name::<#self_type_path>() + ) + )? + ),* ))); } } @@ -229,6 +289,73 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt .into() } +#[proc_macro_attribute] +pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream +{ + let type_alias: ItemType = parse(type_alias_stream).unwrap(); + + let aliased_trait = match &type_alias.ty.as_ref() { + Type::TraitObject(alias_type) => alias_type, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), + }; + + if aliased_trait.bounds.len() != 1 { + panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); + } + + let type_bound = aliased_trait.bounds.first().unwrap(); + + let trait_bound = match type_bound { + TypeParamBound::Trait(trait_bound) => trait_bound, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), + }; + + let trait_bound_path = &trait_bound.path; + + if trait_bound_path.segments.is_empty() + || trait_bound_path.segments.last().unwrap().ident != "IFactory" + { + panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); + } + + let factory_path_segment = trait_bound_path.segments.last().unwrap(); + + let factory_path_segment_args = &match &factory_path_segment.arguments { + syn::PathArguments::AngleBracketed(args) => args, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), + } + .args; + + let factory_arg_types_type = match &factory_path_segment_args[0] { + GenericArgument::Type(arg_type) => arg_type, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), + }; + + let factory_return_type = match &factory_path_segment_args[1] { + GenericArgument::Type(arg_type) => arg_type, + &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), + }; + + quote! { + #type_alias + + syrette::castable_to!( + syrette::castable_factory::CastableFactory< + #factory_arg_types_type, + #factory_return_type + > => #trait_bound_path + ); + + syrette::castable_to!( + syrette::castable_factory::CastableFactory< + #factory_arg_types_type, + #factory_return_type + > => syrette::castable_factory::AnyFactory + ); + } + .into() +} + #[doc(hidden)] #[proc_macro] pub fn castable_to(input: TokenStream) -> TokenStream |