From 51e8d04c2299e6468213d8ee4f9e15d783094379 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sun, 17 Jul 2022 14:55:41 +0200 Subject: refactor: reorganize and improve macros --- syrette_macros/src/factory_type_alias.rs | 83 ++++++ syrette_macros/src/injectable_impl.rs | 244 ++++++++++++++++++ syrette_macros/src/injectable_macro_args.rs | 17 ++ syrette_macros/src/lib.rs | 385 +++++++--------------------- 4 files changed, 430 insertions(+), 299 deletions(-) create mode 100644 syrette_macros/src/factory_type_alias.rs create mode 100644 syrette_macros/src/injectable_impl.rs create mode 100644 syrette_macros/src/injectable_macro_args.rs diff --git a/syrette_macros/src/factory_type_alias.rs b/syrette_macros/src/factory_type_alias.rs new file mode 100644 index 0000000..82e2315 --- /dev/null +++ b/syrette_macros/src/factory_type_alias.rs @@ -0,0 +1,83 @@ +use syn::parse::{Parse, ParseStream}; +use syn::{GenericArgument, ItemType, Path, Type, TypeParamBound, TypeTuple}; + +pub struct FactoryTypeAlias +{ + pub type_alias: ItemType, + pub factory_interface: Path, + pub arg_types: TypeTuple, + pub return_type: Type, +} + +impl Parse for FactoryTypeAlias +{ + fn parse(input: ParseStream) -> syn::Result + { + let type_alias = match input.parse::() { + Ok(type_alias) => Ok(type_alias), + Err(_) => Err(input.error("Expected a type alias")), + }?; + + let aliased_trait = match &type_alias.ty.as_ref() { + Type::TraitObject(alias_type) => Ok(alias_type), + &_ => Err(input.error("Expected the aliased type to be a trait")), + }?; + + if aliased_trait.bounds.len() != 1 { + return Err(input.error("Expected the aliased trait to have a single bound.")); + } + + let bound_path = &match aliased_trait.bounds.first().unwrap() { + TypeParamBound::Trait(trait_bound) => Ok(trait_bound), + &_ => { + Err(input.error("Expected the bound of the aliased trait to be a trait")) + } + }? + .path; + + if bound_path.segments.is_empty() + || bound_path.segments.last().unwrap().ident != "IFactory" + { + return Err(input + .error("Expected the bound of the aliased trait to be 'dyn IFactory'")); + } + + let angle_bracketed_args = match &bound_path.segments.last().unwrap().arguments { + syn::PathArguments::AngleBracketed(angle_bracketed_args) => { + Ok(angle_bracketed_args) + } + &_ => { + Err(input.error("Expected angle bracketed arguments for 'dyn IFactory'")) + } + }?; + + let arg_types = match &angle_bracketed_args.args[0] { + GenericArgument::Type(arg_types_type) => match arg_types_type { + Type::Tuple(arg_types) => Ok(arg_types), + &_ => Err(input.error(concat!( + "Expected the first angle bracketed argument ", + "of 'dyn IFactory' to be a type tuple" + ))), + }, + &_ => Err(input.error(concat!( + "Expected the first angle bracketed argument ", + "of 'dyn IFactory' to be a type" + ))), + }?; + + let return_type = match &angle_bracketed_args.args[1] { + GenericArgument::Type(arg_type) => Ok(arg_type), + &_ => Err(input.error(concat!( + "Expected the second angle bracketed argument ", + "of 'dyn IFactory' to be a type" + ))), + }?; + + Ok(Self { + type_alias: type_alias.clone(), + factory_interface: bound_path.clone(), + arg_types: arg_types.clone(), + return_type: return_type.clone(), + }) + } +} diff --git a/syrette_macros/src/injectable_impl.rs b/syrette_macros/src/injectable_impl.rs new file mode 100644 index 0000000..e7d1b54 --- /dev/null +++ b/syrette_macros/src/injectable_impl.rs @@ -0,0 +1,244 @@ +use quote::{quote, ToTokens}; +use syn::parse::{Parse, ParseStream}; +use syn::{ + parse_str, punctuated::Punctuated, token::Comma, ExprMethodCall, FnArg, + GenericArgument, Ident, ImplItem, ImplItemMethod, ItemImpl, Path, PathArguments, + Type, TypePath, +}; + +const DI_CONTAINER_VAR_NAME: &str = "di_container"; + +pub struct InjectableImpl +{ + pub dependency_types: Vec, + pub self_type: Type, + pub original_impl: ItemImpl, +} + +impl Parse for InjectableImpl +{ + fn parse(input: ParseStream) -> syn::Result + { + match input.parse::() { + Ok(impl_parsed_input) => { + match Self::_get_dependency_types(&impl_parsed_input) { + Ok(dependency_types) => Ok(Self { + dependency_types, + self_type: impl_parsed_input.self_ty.as_ref().clone(), + original_impl: impl_parsed_input, + }), + Err(error_msg) => Err(input.error(error_msg)), + } + } + Err(_) => Err(input.error("Expected an impl")), + } + } +} + +impl InjectableImpl +{ + pub fn expand(&self) -> proc_macro2::TokenStream + { + let original_impl = &self.original_impl; + let self_type = &self.self_type; + + let di_container_var: Ident = parse_str(DI_CONTAINER_VAR_NAME).unwrap(); + + let get_dependencies = Self::_create_get_dependencies(&self.dependency_types); + + quote! { + #original_impl + + impl syrette::interfaces::injectable::Injectable for #self_type { + fn resolve( + #di_container_var: &syrette::DIContainer + ) -> error_stack::Result< + syrette::ptr::InterfacePtr, + syrette::errors::injectable::ResolveError> + { + use error_stack::ResultExt; + + return Ok(syrette::ptr::InterfacePtr::new(Self::new( + #(#get_dependencies + .change_context(syrette::errors::injectable::ResolveError) + .attach_printable( + format!( + "Unable to resolve a dependency of {}", + std::any::type_name::<#self_type>() + ) + )? + ),* + ))); + } + } + } + } + + fn _create_get_dependencies(dependency_types: &[Type]) -> Vec + { + dependency_types + .iter() + .filter_map(|dep_type| match dep_type { + Type::TraitObject(dep_type_trait) => Some( + parse_str( + format!( + "{}.get::<{}>()", + DI_CONTAINER_VAR_NAME, + dep_type_trait.to_token_stream() + ) + .as_str(), + ) + .unwrap(), + ), + Type::Path(dep_type_path) => { + let dep_type_path_str = Self::_path_to_string(&dep_type_path.path); + + let get_method_name = if dep_type_path_str.ends_with("Factory") { + "get_factory" + } else { + "get" + }; + + Some( + parse_str( + format!( + "{}.{}::<{}>()", + DI_CONTAINER_VAR_NAME, get_method_name, dep_type_path_str + ) + .as_str(), + ) + .unwrap(), + ) + } + &_ => None, + }) + .collect() + } + + fn _find_method_by_name<'impl_lt>( + item_impl: &'impl_lt ItemImpl, + method_name: &'static str, + ) -> Option<&'impl_lt ImplItemMethod> + { + let impl_items = &item_impl.items; + + impl_items + .iter() + .filter_map(|impl_item| match impl_item { + ImplItem::Method(method_item) => Some(method_item), + &_ => None, + }) + .find(|method_item| method_item.sig.ident == method_name) + } + + fn get_has_fn_args_self(fn_args: &Punctuated) -> bool + { + fn_args.iter().any(|arg| match arg { + FnArg::Receiver(_) => true, + &_ => false, + }) + } + + fn _get_fn_arg_type_paths(fn_args: &Punctuated) -> Vec<&TypePath> + { + fn_args + .iter() + .filter_map(|arg| match arg { + FnArg::Typed(typed_fn_arg) => match typed_fn_arg.ty.as_ref() { + Type::Path(arg_type_path) => Some(arg_type_path), + Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { + Type::Path(arg_type_path) => Some(arg_type_path), + &_ => None, + }, + &_ => None, + }, + FnArg::Receiver(_receiver_fn_arg) => None, + }) + .collect() + } + + fn _path_to_string(path: &Path) -> String + { + path.segments + .pairs() + .fold(String::new(), |mut acc, segment_pair| { + let segment_ident = &segment_pair.value().ident; + + acc.push_str(segment_ident.to_string().as_str()); + + let opt_colon_two = segment_pair.punct(); + + match opt_colon_two { + Some(colon_two) => { + acc.push_str(colon_two.to_token_stream().to_string().as_str()) + } + None => {} + } + + acc + }) + } + + fn _is_type_path_ptr(type_path: &TypePath) -> bool + { + let arg_type_path_string = Self::_path_to_string(&type_path.path); + + arg_type_path_string == "InterfacePtr" + || arg_type_path_string == "ptr::InterfacePtr" + || arg_type_path_string == "syrrete::ptr::InterfacePtr" + || arg_type_path_string == "FactoryPtr" + || arg_type_path_string == "ptr::FactoryPtr" + || arg_type_path_string == "syrrete::ptr::FactoryPtr" + } + + fn _get_dependency_types(item_impl: &ItemImpl) -> Result, &'static str> + { + let new_method_impl_item = match Self::_find_method_by_name(item_impl, "new") { + Some(method_item) => Ok(method_item), + None => Err("Missing a 'new' method"), + }?; + + let new_method_args = &new_method_impl_item.sig.inputs; + + if Self::get_has_fn_args_self(new_method_args) { + return Err("Unexpected self argument in 'new' method"); + } + + let new_method_arg_type_paths = Self::_get_fn_arg_type_paths(new_method_args); + + if new_method_arg_type_paths + .iter() + .any(|arg_type_path| !Self::_is_type_path_ptr(arg_type_path)) + { + return Err("All argument types in 'new' method must ptr types"); + } + + Ok(new_method_arg_type_paths + .iter() + .filter_map(|arg_type_path| { + // Assume the type path has a last segment. + let last_path_segment = arg_type_path.path.segments.last().unwrap(); + + match &last_path_segment.arguments { + PathArguments::AngleBracketed(angle_bracketed_generic_args) => { + let generic_args = &angle_bracketed_generic_args.args; + + let opt_first_generic_arg = generic_args.first(); + + // Assume a first generic argument exists because InterfacePtr and + // FactoryPtr requires one + let first_generic_arg = opt_first_generic_arg.as_ref().unwrap(); + + match first_generic_arg { + GenericArgument::Type(first_generic_arg_type) => { + Some(first_generic_arg_type.clone()) + } + &_ => None, + } + } + &_ => None, + } + }) + .collect()) + } +} diff --git a/syrette_macros/src/injectable_macro_args.rs b/syrette_macros/src/injectable_macro_args.rs new file mode 100644 index 0000000..4ef4389 --- /dev/null +++ b/syrette_macros/src/injectable_macro_args.rs @@ -0,0 +1,17 @@ +use syn::parse::{Parse, ParseStream}; +use syn::TypePath; + +pub struct InjectableMacroArgs +{ + pub interface: TypePath, +} + +impl Parse for InjectableMacroArgs +{ + fn parse(input: ParseStream) -> syn::Result + { + Ok(Self { + interface: input.parse()?, + }) + } +} diff --git a/syrette_macros/src/lib.rs b/syrette_macros/src/lib.rs index 91a0562..3145b5f 100644 --- a/syrette_macros/src/lib.rs +++ b/syrette_macros/src/lib.rs @@ -1,162 +1,21 @@ use proc_macro::TokenStream; -use quote::{quote, ToTokens}; -use syn::{ - parse, parse_macro_input, parse_str, punctuated::Punctuated, token::Comma, - AttributeArgs, ExprMethodCall, FnArg, GenericArgument, ImplItem, ItemImpl, ItemType, - Meta, NestedMeta, Path, PathArguments, Type, TypeParamBound, TypePath, -}; +use quote::quote; +use syn::{parse, parse_macro_input}; +mod factory_type_alias; +mod injectable_impl; +mod injectable_macro_args; mod libs; +use factory_type_alias::FactoryTypeAlias; +use injectable_impl::InjectableImpl; +use injectable_macro_args::InjectableMacroArgs; use libs::intertrait_macros::{ args::{Casts, Flag, Targets}, gen_caster::generate_caster, }; -const NO_INTERFACE_ARG_ERR_MESSAGE: &str = - "Expected a argument specifying a interface trait"; - -const INVALID_ARG_ERR_MESSAGE: &str = "Invalid argument passed"; - -const INVALID_ITEM_TYPE_ERR_MESSAGE: &str = - "The attached to item is not a trait implementation"; - -const IMPL_NO_NEW_METHOD_ERR_MESSAGE: &str = - "The attached to trait implementation is missing a new method"; - -const IMPL_NEW_METHOD_SELF_PARAM_ERR_MESSAGE: &str = - "The new method of the attached to trait implementation cannot have a self parameter"; - -const IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE: &str = concat!( - "All parameters of the new method of the attached to trait implementation ", - "must be either syrette::ptr::InterfacePtr or syrrete::ptr::FactoryPtr (for factories)" -); - -const INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE: &str = - "Invalid aliased trait. Must be 'dyn IFactory'"; - -const INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE: &str = - "Invalid arguments for 'dyn IFactory'"; - -fn path_to_string(path: &Path) -> String -{ - return path - .segments - .pairs() - .fold(String::new(), |mut acc, segment_pair| { - let segment_ident = &segment_pair.value().ident; - - acc.push_str(segment_ident.to_string().as_str()); - - let opt_colon_two = segment_pair.punct(); - - match opt_colon_two { - Some(colon_two) => { - acc.push_str(colon_two.to_token_stream().to_string().as_str()) - } - None => {} - } - - acc - }); -} - -fn get_fn_args_has_self(fn_args: &Punctuated) -> bool -{ - return fn_args.iter().any(|arg| match arg { - FnArg::Receiver(_) => true, - &_ => false, - }); -} - -fn get_fn_arg_type_paths(fn_args: &Punctuated) -> Vec -{ - return fn_args.iter().fold(Vec::::new(), |mut acc, arg| { - match arg { - FnArg::Typed(typed_fn_arg) => match typed_fn_arg.ty.as_ref() { - Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), - Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { - Type::Path(arg_type_path) => acc.push(arg_type_path.clone()), - &_ => {} - }, - &_ => {} - }, - FnArg::Receiver(_receiver_fn_arg) => {} - } - - acc - }); -} - -fn get_dependency_types(item_impl: &ItemImpl) -> Vec -{ - let impl_items = &item_impl.items; - - let opt_new_method_impl_item = impl_items.iter().find(|item| match item { - ImplItem::Method(method_item) => method_item.sig.ident == "new", - &_ => false, - }); - - let new_method_impl_item = match opt_new_method_impl_item { - Some(item) => match item { - ImplItem::Method(method_item) => method_item, - &_ => panic!("{}", IMPL_NO_NEW_METHOD_ERR_MESSAGE), - }, - None => panic!("{}", IMPL_NO_NEW_METHOD_ERR_MESSAGE), - }; - - let new_method_inputs = &new_method_impl_item.sig.inputs; - - if get_fn_args_has_self(new_method_inputs) { - panic!("{}", IMPL_NEW_METHOD_SELF_PARAM_ERR_MESSAGE) - } - - let new_method_arg_type_paths = get_fn_arg_type_paths(new_method_inputs); - - return new_method_arg_type_paths.iter().fold( - Vec::::new(), - |mut acc, arg_type_path| { - let arg_type_path_string = path_to_string(&arg_type_path.path); - - if arg_type_path_string != "InterfacePtr" - && arg_type_path_string != "ptr::InterfacePtr" - && arg_type_path_string != "syrrete::ptr::InterfacePtr" - && arg_type_path_string != "FactoryPtr" - && arg_type_path_string != "ptr::FactoryPtr" - && arg_type_path_string != "syrrete::ptr::FactoryPtr" - { - panic!("{}", IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE); - } - - // Assume the type path has a last segment. - let last_path_segment = arg_type_path.path.segments.last().unwrap(); - - match &last_path_segment.arguments { - PathArguments::AngleBracketed(angle_bracketed_generic_args) => { - let generic_args = &angle_bracketed_generic_args.args; - - let opt_first_generic_arg = generic_args.first(); - - // Assume a first generic argument exists because InterfacePtr and - // FactoryPtr requires one - let first_generic_arg = opt_first_generic_arg.as_ref().unwrap(); - - match first_generic_arg { - GenericArgument::Type(first_generic_arg_type) => { - acc.push(first_generic_arg_type.clone()); - } - &_ => {} - } - } - &_ => {} - } - - acc - }, - ); -} - -/// Makes a struct injectable. Therefore usable with `DIContainer`. +/// Makes a struct injectable. Thereby usable with `DIContainer`. /// /// # Arguments /// @@ -166,15 +25,28 @@ fn get_dependency_types(item_impl: &ItemImpl) -> Vec /// ``` /// trait IConfigReader /// { -/// fn read_config() -> Config; +/// fn read_config(&self) -> Config; +/// } +/// +/// struct ConfigReader +/// { +/// _file_reader: InterfacePtr, /// } /// -/// struct ConfigReader {} +/// impl ConfigReader +/// { +/// fn new(file_reader: InterfacePtr) -> Self +/// { +/// Self { +/// _file_reader: file_reader +/// } +/// } +/// } /// /// #[injectable(IConfigReader)] /// impl IConfigReader for ConfigReader /// { -/// fn read_config() -> Config +/// fn read_config(&self) -> Config /// { /// // Stuff here /// } @@ -183,175 +55,90 @@ fn get_dependency_types(item_impl: &ItemImpl) -> Vec #[proc_macro_attribute] pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenStream { - let args = parse_macro_input!(args_stream as AttributeArgs); + let InjectableMacroArgs { + interface: interface_type_path, + } = parse_macro_input!(args_stream); - if args.is_empty() { - panic!("{}", NO_INTERFACE_ARG_ERR_MESSAGE); - } + let injectable_impl: InjectableImpl = parse(impl_stream).unwrap(); - if args.len() > 1 { - panic!("Only a single argument is expected"); - } - - let interface_path = match &args[0] { - NestedMeta::Meta(arg_meta) => match arg_meta { - Meta::Path(path_arg) => path_arg, - &_ => panic!("{}", INVALID_ARG_ERR_MESSAGE), - }, - &_ => panic!("{}", INVALID_ARG_ERR_MESSAGE), - }; - - let item_impl: ItemImpl = match parse(impl_stream) { - Ok(impl_parsed) => impl_parsed, - Err(_) => { - panic!("{}", INVALID_ITEM_TYPE_ERR_MESSAGE) - } - }; - - let self_type = item_impl.self_ty.as_ref(); - - let self_type_path = match self_type { - Type::Path(path_self_type) => path_self_type.path.clone(), - &_ => parse_str("invalid_type").unwrap(), - }; - - let dependency_types = get_dependency_types(&item_impl); + let expanded_injectable_impl = injectable_impl.expand(); - let get_dependencies = dependency_types.iter().fold( - Vec::::new(), - |mut acc, dep_type| { - match dep_type { - Type::TraitObject(dep_type_trait) => { - acc.push( - parse_str( - format!( - "di_container.get::<{}>()", - dep_type_trait.to_token_stream() - ) - .as_str(), - ) - .unwrap(), - ); - } - Type::Path(dep_type_path) => { - let dep_type_path_str = path_to_string(&dep_type_path.path); - - let get_method_name = if dep_type_path_str.ends_with("Factory") { - "get_factory" - } else { - "get" - }; - - acc.push( - parse_str( - format!( - "di_container.{}::<{}>()", - get_method_name, dep_type_path_str - ) - .as_str(), - ) - .unwrap(), - ); - } - &_ => {} - } - - acc - }, - ); + let self_type = &injectable_impl.self_type; quote! { - #item_impl + #expanded_injectable_impl - impl syrette::interfaces::injectable::Injectable for #self_type_path { - fn resolve( - di_container: &syrette::DIContainer - ) -> error_stack::Result< - syrette::ptr::InterfacePtr, - syrette::errors::injectable::ResolveError> - { - use error_stack::ResultExt; - - return Ok(syrette::ptr::InterfacePtr::new(Self::new( - #(#get_dependencies - .change_context(syrette::errors::injectable::ResolveError) - .attach_printable( - format!( - "Unable to resolve a dependency of {}", - std::any::type_name::<#self_type_path>() - ) - )? - ),* - ))); - } - } - - syrette::castable_to!(#self_type_path => #interface_path); + syrette::castable_to!(#self_type => #interface_type_path); } .into() } +/// Makes a type alias usable as a factory interface. +/// +/// # Examples +/// ``` +/// trait IUser +/// { +/// fn name(&self) -> String; +/// fn age(&self) -> i32; +/// } +/// +/// struct User +/// { +/// _name: String, +/// _age: i32, +/// } +/// +/// impl User +/// { +/// fn new(name: String, age: i32) -> Self +/// { +/// Self { +/// _name: name, +/// _age: age, +/// } +/// } +/// } +/// +/// impl IUser for User +/// { +/// fn name(&self) -> String +/// { +/// self._name +/// } +/// +/// fn age(&self) -> i32 +/// { +/// self._age +/// } +/// } +/// +/// type UserFactory = dyn IFactory<(String, i32), dyn IUser>; +/// ``` #[proc_macro_attribute] pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream { - let type_alias: ItemType = parse(type_alias_stream).unwrap(); - - let aliased_trait = match &type_alias.ty.as_ref() { - Type::TraitObject(alias_type) => alias_type, - &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), - }; - - if aliased_trait.bounds.len() != 1 { - panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); - } - - let type_bound = aliased_trait.bounds.first().unwrap(); - - let trait_bound = match type_bound { - TypeParamBound::Trait(trait_bound) => trait_bound, - &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE), - }; - - let trait_bound_path = &trait_bound.path; - - if trait_bound_path.segments.is_empty() - || trait_bound_path.segments.last().unwrap().ident != "IFactory" - { - panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE); - } - - let factory_path_segment = trait_bound_path.segments.last().unwrap(); - - let factory_path_segment_args = &match &factory_path_segment.arguments { - syn::PathArguments::AngleBracketed(args) => args, - &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), - } - .args; - - let factory_arg_types_type = match &factory_path_segment_args[0] { - GenericArgument::Type(arg_type) => arg_type, - &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), - }; - - let factory_return_type = match &factory_path_segment_args[1] { - GenericArgument::Type(arg_type) => arg_type, - &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE), - }; + let FactoryTypeAlias { + type_alias, + factory_interface, + arg_types, + return_type, + } = parse(type_alias_stream).unwrap(); quote! { #type_alias syrette::castable_to!( syrette::castable_factory::CastableFactory< - #factory_arg_types_type, - #factory_return_type - > => #trait_bound_path + #arg_types, + #return_type + > => #factory_interface ); syrette::castable_to!( syrette::castable_factory::CastableFactory< - #factory_arg_types_type, - #factory_return_type + #arg_types, + #return_type > => syrette::castable_factory::AnyFactory ); } -- cgit v1.2.3-18-g5258