From a3ccc2713bb5315123814cadd6c50275eee38e1c Mon Sep 17 00:00:00 2001 From: HampusM Date: Thu, 3 Aug 2023 15:09:46 +0200 Subject: feat: add constructor name flag to injectable macro --- macros/src/declare_interface_args.rs | 14 ++-- macros/src/factory/declare_default_args.rs | 32 +++++--- macros/src/factory/macro_args.rs | 32 +++++--- macros/src/injectable/dependency.rs | 22 +++--- macros/src/injectable/implementation.rs | 120 +++++++++++++++++------------ macros/src/injectable/macro_args.rs | 47 +++++++---- macros/src/lib.rs | 92 ++++++++++++++-------- macros/src/macro_flag.rs | 113 +++++++++++++++++++++++---- macros/src/util/item_impl.rs | 5 +- 9 files changed, 322 insertions(+), 155 deletions(-) (limited to 'macros/src') diff --git a/macros/src/declare_interface_args.rs b/macros/src/declare_interface_args.rs index cf0dbce..79004da 100644 --- a/macros/src/declare_interface_args.rs +++ b/macros/src/declare_interface_args.rs @@ -30,7 +30,7 @@ impl Parse for DeclareInterfaceArgs let flags = Punctuated::::parse_terminated(input)?; for flag in &flags { - let flag_name = flag.flag.to_string(); + let flag_name = flag.name().to_string(); if !DECLARE_INTERFACE_FLAGS.contains(&flag_name.as_str()) { return Err(input.error(format!( @@ -41,7 +41,7 @@ impl Parse for DeclareInterfaceArgs } if let Some((dupe_flag, _)) = flags.iter().find_duplicate() { - return Err(input.error(format!("Duplicate flag '{}'", dupe_flag.flag))); + return Err(input.error(format!("Duplicate flag '{}'", dupe_flag.name()))); } flags @@ -64,9 +64,10 @@ mod tests use proc_macro2::Span; use quote::{format_ident, quote}; - use syn::{parse2, LitBool}; + use syn::{parse2, Lit, LitBool}; use super::*; + use crate::macro_flag::MacroFlagValue; use crate::test_utils; #[test] @@ -139,8 +140,11 @@ mod tests assert_eq!( decl_interface_args.flags, Punctuated::from_iter(vec![MacroFlag { - flag: format_ident!("async"), - is_on: LitBool::new(true, Span::call_site()) + name: format_ident!("async"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + true, + Span::call_site() + ))) }]) ); diff --git a/macros/src/factory/declare_default_args.rs b/macros/src/factory/declare_default_args.rs index 269ef8f..f93d29d 100644 --- a/macros/src/factory/declare_default_args.rs +++ b/macros/src/factory/declare_default_args.rs @@ -31,12 +31,11 @@ impl Parse for DeclareDefaultFactoryMacroArgs let flags = Punctuated::::parse_terminated(input)?; for flag in &flags { - let flag_str = flag.flag.to_string(); + let name = flag.name().to_string(); - if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { + if !FACTORY_MACRO_FLAGS.contains(&name.as_str()) { return Err(input.error(format!( - "Unknown flag '{}'. Expected one of [ {} ]", - flag_str, + "Unknown flag '{name}'. Expected one of [ {} ]", FACTORY_MACRO_FLAGS.join(",") ))); } @@ -44,7 +43,7 @@ impl Parse for DeclareDefaultFactoryMacroArgs let flag_names = flags .iter() - .map(|flag| flag.flag.to_string()) + .map(|flag| flag.name().to_string()) .collect::>(); if let Some((dupe_flag_name, _)) = flag_names.iter().find_duplicate() { @@ -65,6 +64,7 @@ mod tests use syn::token::Dyn; use syn::{ parse2, + Lit, LitBool, Path, PathArguments, @@ -77,6 +77,7 @@ mod tests }; use super::*; + use crate::macro_flag::MacroFlagValue; #[test] fn can_parse_with_interface_only() -> Result<(), Box> @@ -142,8 +143,11 @@ mod tests assert_eq!( dec_def_fac_args.flags, Punctuated::from_iter(vec![MacroFlag { - flag: format_ident!("threadsafe"), - is_on: LitBool::new(true, Span::call_site()) + name: format_ident!("threadsafe"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + true, + Span::call_site() + ))) }]) ); @@ -182,12 +186,18 @@ mod tests dec_def_fac_args.flags, Punctuated::from_iter(vec![ MacroFlag { - flag: format_ident!("threadsafe"), - is_on: LitBool::new(true, Span::call_site()) + name: format_ident!("threadsafe"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + true, + Span::call_site() + ))) }, MacroFlag { - flag: format_ident!("async"), - is_on: LitBool::new(false, Span::call_site()) + name: format_ident!("async"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + false, + Span::call_site() + ))) } ]) ); diff --git a/macros/src/factory/macro_args.rs b/macros/src/factory/macro_args.rs index 8acbdb6..cb2cbc9 100644 --- a/macros/src/factory/macro_args.rs +++ b/macros/src/factory/macro_args.rs @@ -19,12 +19,12 @@ impl Parse for FactoryMacroArgs let flags = Punctuated::::parse_terminated(input)?; for flag in &flags { - let flag_str = flag.flag.to_string(); + let name = flag.name().to_string(); - if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { + if !FACTORY_MACRO_FLAGS.contains(&name.as_str()) { return Err(input.error(format!( "Unknown flag '{}'. Expected one of [ {} ]", - flag_str, + name, FACTORY_MACRO_FLAGS.join(",") ))); } @@ -32,7 +32,7 @@ impl Parse for FactoryMacroArgs let flag_names = flags .iter() - .map(|flag| flag.flag.to_string()) + .map(|flag| flag.name().to_string()) .collect::>(); if let Some((dupe_flag_name, _)) = flag_names.iter().find_duplicate() { @@ -50,9 +50,10 @@ mod tests use proc_macro2::Span; use quote::{format_ident, quote}; - use syn::{parse2, LitBool}; + use syn::{parse2, Lit, LitBool}; use super::*; + use crate::macro_flag::MacroFlagValue; #[test] fn can_parse_with_single_flag() -> Result<(), Box> @@ -66,8 +67,11 @@ mod tests assert_eq!( factory_macro_args.flags, Punctuated::from_iter(vec![MacroFlag { - flag: format_ident!("async"), - is_on: LitBool::new(true, Span::call_site()) + name: format_ident!("async"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + true, + Span::call_site() + ))) }]) ); @@ -87,12 +91,18 @@ mod tests factory_macro_args.flags, Punctuated::from_iter(vec![ MacroFlag { - flag: format_ident!("async"), - is_on: LitBool::new(true, Span::call_site()) + name: format_ident!("async"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + true, + Span::call_site() + ))) }, MacroFlag { - flag: format_ident!("threadsafe"), - is_on: LitBool::new(false, Span::call_site()) + name: format_ident!("threadsafe"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + false, + Span::call_site() + ))) } ]) ); diff --git a/macros/src/injectable/dependency.rs b/macros/src/injectable/dependency.rs index 33c4583..85cad58 100644 --- a/macros/src/injectable/dependency.rs +++ b/macros/src/injectable/dependency.rs @@ -6,14 +6,12 @@ use crate::injectable::named_attr_input::NamedAttrInput; use crate::util::error::diagnostic_error_enum; use crate::util::syn_path::SynPathExt; -/// Interface for a representation of a dependency of a injectable type. -/// -/// Found as a argument in the 'new' method of the type. +/// Interface for a dependency of a `Injectable`. #[cfg_attr(test, mockall::automock)] pub trait IDependency: Sized { - /// Build a new `Dependency` from a argument in a 'new' method. - fn build(new_method_arg: &FnArg) -> Result; + /// Build a new `Dependency` from a argument in a constructor method. + fn build(ctor_method_arg: &FnArg) -> Result; /// Returns the interface type. fn get_interface(&self) -> &Type; @@ -27,7 +25,7 @@ pub trait IDependency: Sized /// Representation of a dependency of a injectable type. /// -/// Found as a argument in the 'new' method of the type. +/// Found as a argument in the constructor method of a `Injectable`. #[derive(Debug, PartialEq, Eq)] pub struct Dependency { @@ -38,16 +36,16 @@ pub struct Dependency impl IDependency for Dependency { - fn build(new_method_arg: &FnArg) -> Result + fn build(ctor_method_arg: &FnArg) -> Result { - let typed_new_method_arg = match new_method_arg { + let typed_ctor_method_arg = match ctor_method_arg { FnArg::Typed(typed_arg) => Ok(typed_arg), FnArg::Receiver(receiver_arg) => Err(DependencyError::UnexpectedSelf { self_token_span: receiver_arg.self_token.span, }), }?; - let dependency_type_path = match typed_new_method_arg.ty.as_ref() { + let dependency_type_path = match typed_ctor_method_arg.ty.as_ref() { Type::Path(arg_type_path) => Ok(arg_type_path), Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() { Type::Path(arg_type_path) => Ok(arg_type_path), @@ -63,7 +61,7 @@ impl IDependency for Dependency let ptr_path_segment = dependency_type_path.path.segments.last().map_or_else( || { Err(DependencyError::MissingType { - arg_span: typed_new_method_arg.span(), + arg_span: typed_ctor_method_arg.span(), }) }, Ok, @@ -88,7 +86,7 @@ impl IDependency for Dependency }) }?; - let arg_attrs = &typed_new_method_arg.attrs; + let arg_attrs = &typed_ctor_method_arg.attrs; let opt_named_attr = arg_attrs.iter().find(|attr| { attr.path.get_ident().map_or_else( @@ -103,7 +101,7 @@ impl IDependency for Dependency if let Some(named_attr_tokens) = opt_named_attr_tokens { Some(parse2::(named_attr_tokens.clone()).map_err( |err| DependencyError::InvalidNamedAttrInput { - arg_span: typed_new_method_arg.span(), + arg_span: typed_ctor_method_arg.span(), err, }, )?) diff --git a/macros/src/injectable/implementation.rs b/macros/src/injectable/implementation.rs index 3d73cd0..0ea623c 100644 --- a/macros/src/injectable/implementation.rs +++ b/macros/src/injectable/implementation.rs @@ -31,13 +31,16 @@ pub struct InjectableImpl pub generics: Generics, pub original_impl: ItemImpl, - new_method: ImplItemMethod, + constructor_method: ImplItemMethod, } impl InjectableImpl { #[cfg(not(tarpaulin_include))] - pub fn parse(input: TokenStream) -> Result + pub fn parse( + input: TokenStream, + constructor: &Ident, + ) -> Result { let mut item_impl = parse2::(input).map_err(|err| { InjectableImplError::NotAImplementation { @@ -53,79 +56,88 @@ impl InjectableImpl let item_impl_span = item_impl.self_ty.span(); - let new_method = find_impl_method_by_name_mut(&mut item_impl, "new").ok_or( - InjectableImplError::MissingNewMethod { - implementation_span: item_impl_span, - }, - )?; + let constructor_method = + find_impl_method_by_name_mut(&mut item_impl, constructor).ok_or( + InjectableImplError::MissingConstructorMethod { + constructor: constructor.clone(), + implementation_span: item_impl_span, + }, + )?; - let dependencies = Self::build_dependencies(new_method).map_err(|err| { - InjectableImplError::ContainsAInvalidDependency { - implementation_span: item_impl_span, - err, - } - })?; + let dependencies = + Self::build_dependencies(constructor_method).map_err(|err| { + InjectableImplError::ContainsAInvalidDependency { + implementation_span: item_impl_span, + err, + } + })?; - Self::remove_method_argument_attrs(new_method); + Self::remove_method_argument_attrs(constructor_method); - let new_method = new_method.clone(); + let constructor_method = constructor_method.clone(); Ok(Self { dependencies, self_type: item_impl.self_ty.as_ref().clone(), generics: item_impl.generics.clone(), original_impl: item_impl, - new_method, + constructor_method, }) } pub fn validate(&self) -> Result<(), InjectableImplError> { - if matches!(self.new_method.sig.output, ReturnType::Default) { - return Err(InjectableImplError::InvalidNewMethodReturnType { - new_method_output_span: self.new_method.sig.output.span(), + if matches!(self.constructor_method.sig.output, ReturnType::Default) { + return Err(InjectableImplError::InvalidConstructorMethodReturnType { + ctor_method_output_span: self.constructor_method.sig.output.span(), expected: "Self".to_string(), found: "()".to_string(), }); } - if let ReturnType::Type(_, ret_type) = &self.new_method.sig.output { + if let ReturnType::Type(_, ret_type) = &self.constructor_method.sig.output { if let Type::Path(path_type) = ret_type.as_ref() { if path_type .path .get_ident() .map_or_else(|| true, |ident| *ident != "Self") { - return Err(InjectableImplError::InvalidNewMethodReturnType { - new_method_output_span: self.new_method.sig.output.span(), - expected: "Self".to_string(), - found: ret_type.to_token_stream().to_string(), - }); + return Err( + InjectableImplError::InvalidConstructorMethodReturnType { + ctor_method_output_span: self + .constructor_method + .sig + .output + .span(), + expected: "Self".to_string(), + found: ret_type.to_token_stream().to_string(), + }, + ); } } else { - return Err(InjectableImplError::InvalidNewMethodReturnType { - new_method_output_span: self.new_method.sig.output.span(), + return Err(InjectableImplError::InvalidConstructorMethodReturnType { + ctor_method_output_span: self.constructor_method.sig.output.span(), expected: "Self".to_string(), found: ret_type.to_token_stream().to_string(), }); } } - if let Some(unsafety) = self.new_method.sig.unsafety { - return Err(InjectableImplError::NewMethodUnsafe { + if let Some(unsafety) = self.constructor_method.sig.unsafety { + return Err(InjectableImplError::ConstructorMethodUnsafe { unsafety_span: unsafety.span, }); } - if let Some(asyncness) = self.new_method.sig.asyncness { - return Err(InjectableImplError::NewMethodAsync { + if let Some(asyncness) = self.constructor_method.sig.asyncness { + return Err(InjectableImplError::ConstructorMethodAsync { asyncness_span: asyncness.span, }); } - if !self.new_method.sig.generics.params.is_empty() { - return Err(InjectableImplError::NewMethodGeneric { - generics_span: self.new_method.sig.generics.span(), + if !self.constructor_method.sig.generics.params.is_empty() { + return Err(InjectableImplError::ConstructorMethodGeneric { + generics_span: self.constructor_method.sig.generics.span(), }); } Ok(()) @@ -209,6 +221,7 @@ impl InjectableImpl { let generics = &self.generics; let self_type = &self.self_type; + let constructor = &self.constructor_method.sig.ident; quote! { #maybe_doc_hidden @@ -243,7 +256,7 @@ impl InjectableImpl #maybe_prevent_circular_deps - Ok(syrette::ptr::TransientPtr::new(Self::new( + Ok(syrette::ptr::TransientPtr::new(Self::#constructor( #(#get_dep_method_calls),* ))) }) @@ -264,6 +277,7 @@ impl InjectableImpl { let generics = &self.generics; let self_type = &self.self_type; + let constructor = &self.constructor_method.sig.ident; quote! { #maybe_doc_hidden @@ -290,7 +304,7 @@ impl InjectableImpl #maybe_prevent_circular_deps - return Ok(syrette::ptr::TransientPtr::new(Self::new( + return Ok(syrette::ptr::TransientPtr::new(Self::#constructor( #(#get_dep_method_calls),* ))); } @@ -444,13 +458,13 @@ impl InjectableImpl } fn build_dependencies( - new_method: &ImplItemMethod, + ctor_method: &ImplItemMethod, ) -> Result, DependencyError> { - let new_method_args = &new_method.sig.inputs; + let ctor_method_args = &ctor_method.sig.inputs; let dependencies_result: Result, _> = - new_method_args.iter().map(Dep::build).collect(); + ctor_method_args.iter().map(Dep::build).collect(); let deps = dependencies_result?; @@ -515,41 +529,45 @@ pub enum InjectableImplError trait_path_span: Span }, - #[error("Missing a 'new' method"), span = implementation_span] + #[ + error("No constructor method '{constructor}' found in impl"), + span = implementation_span + ] #[note("Required by the 'injectable' attribute macro")] - MissingNewMethod { + MissingConstructorMethod { + constructor: Ident, implementation_span: Span }, #[ error(concat!( - "Invalid 'new' method return type. Expected it to be '{}'. ", + "Invalid constructor method return type. Expected it to be '{}'. ", "Found '{}'" ), expected, found), - span = new_method_output_span + span = ctor_method_output_span ] - InvalidNewMethodReturnType + InvalidConstructorMethodReturnType { - new_method_output_span: Span, + ctor_method_output_span: Span, expected: String, found: String }, - #[error("'new' method is not allowed to be unsafe"), span = unsafety_span] + #[error("Constructor method is not allowed to be unsafe"), span = unsafety_span] #[note("Required by the 'injectable' attribute macro")] - NewMethodUnsafe { + ConstructorMethodUnsafe { unsafety_span: Span }, - #[error("'new' method is not allowed to be async"), span = asyncness_span] + #[error("Constructor method is not allowed to be async"), span = asyncness_span] #[note("Required by the 'injectable' attribute macro")] - NewMethodAsync { + ConstructorMethodAsync { asyncness_span: Span }, - #[error("'new' method is not allowed to have generics"), span = generics_span] + #[error("Constructor method is not allowed to have generics"), span = generics_span] #[note("Required by the 'injectable' attribute macro")] - NewMethodGeneric { + ConstructorMethodGeneric { generics_span: Span }, diff --git a/macros/src/injectable/macro_args.rs b/macros/src/injectable/macro_args.rs index 6964352..ee398fc 100644 --- a/macros/src/injectable/macro_args.rs +++ b/macros/src/injectable/macro_args.rs @@ -7,8 +7,12 @@ use crate::macro_flag::MacroFlag; use crate::util::error::diagnostic_error_enum; use crate::util::iterator_ext::IteratorExt; -pub const INJECTABLE_MACRO_FLAGS: &[&str] = - &["no_doc_hidden", "async", "no_declare_concrete_interface"]; +pub const INJECTABLE_MACRO_FLAGS: &[&str] = &[ + "no_doc_hidden", + "async", + "no_declare_concrete_interface", + "constructor", +]; pub struct InjectableMacroArgs { @@ -21,9 +25,9 @@ impl InjectableMacroArgs pub fn check_flags(&self) -> Result<(), InjectableMacroArgsError> { for flag in &self.flags { - if !INJECTABLE_MACRO_FLAGS.contains(&flag.flag.to_string().as_str()) { + if !INJECTABLE_MACRO_FLAGS.contains(&flag.name().to_string().as_str()) { return Err(InjectableMacroArgsError::UnknownFlag { - flag_ident: flag.flag.clone(), + flag_ident: flag.name().clone(), }); } } @@ -32,8 +36,8 @@ impl InjectableMacroArgs self.flags.iter().find_duplicate() { return Err(InjectableMacroArgsError::DuplicateFlag { - first_flag_ident: dupe_flag_first.flag.clone(), - last_flag_span: dupe_flag_second.flag.span(), + first_flag_ident: dupe_flag_first.name().clone(), + last_flag_span: dupe_flag_second.name().span(), }); } @@ -111,9 +115,10 @@ mod tests use proc_macro2::Span; use quote::{format_ident, quote}; - use syn::{parse2, LitBool}; + use syn::{parse2, Lit, LitBool}; use super::*; + use crate::macro_flag::MacroFlagValue; use crate::test_utils; #[test] @@ -174,12 +179,18 @@ mod tests injectable_macro_args.flags, Punctuated::from_iter([ MacroFlag { - flag: format_ident!("no_doc_hidden"), - is_on: LitBool::new(true, Span::call_site()) + name: format_ident!("no_doc_hidden"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + true, + Span::call_site() + ))) }, MacroFlag { - flag: format_ident!("async"), - is_on: LitBool::new(false, Span::call_site()) + name: format_ident!("async"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + false, + Span::call_site() + ))) } ]) ); @@ -202,12 +213,18 @@ mod tests injectable_macro_args.flags, Punctuated::from_iter([ MacroFlag { - flag: format_ident!("async"), - is_on: LitBool::new(false, Span::call_site()) + name: format_ident!("async"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + false, + Span::call_site() + ))) }, MacroFlag { - flag: format_ident!("no_declare_concrete_interface"), - is_on: LitBool::new(false, Span::call_site()) + name: format_ident!("no_declare_concrete_interface"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + false, + Span::call_site() + ))) } ]) ); diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 04720c1..5f5b0b6 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -44,14 +44,29 @@ const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION"); /// /// # Arguments /// * (Optional) A interface trait the struct implements. -/// * (Zero or more) Flags. Like `a = true, b = false` +/// * (Zero or more) Comma separated flags. Each flag being formatted `name=value`. /// /// # Flags -/// - `no_doc_hidden` - Don't hide the impl of the [`Injectable`] trait from -/// documentation. -/// - `no_declare_concrete_interface` - Disable declaring the concrete type as the -/// interface when no interface trait argument is given. -/// - `async` - Mark as async. +/// #### `no_doc_hidden` +/// **Value:** boolean literal
+/// **Default:** `false`
+/// Don't hide the impl of the [`Injectable`] trait from documentation. +/// +/// #### `no_declare_concrete_interface` +/// **Value:** boolean literal
+/// **Default:** `false`
+/// Disable declaring the concrete type as the interface when no interface trait argument +/// is given. +/// +/// #### `async` +/// **Value:** boolean literal
+/// **Default:** `false`
+/// Mark as async. +/// +/// #### `constructor` +/// **Value:** identifier
+/// **Default:** `new`
+/// Constructor method name. /// /// # Panics /// If the attributed item is not a impl. @@ -132,6 +147,8 @@ const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION"); #[proc_macro_attribute] pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenStream { + use quote::format_ident; + let input_stream: proc_macro2::TokenStream = input_stream.into(); set_dummy(input_stream.clone()); @@ -143,28 +160,39 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS let no_doc_hidden = args .flags .iter() - .find(|flag| flag.flag.to_string().as_str() == "no_doc_hidden") - .map_or(false, |flag| flag.is_on.value); + .find(|flag| flag.name() == "no_doc_hidden") + .map_or(Ok(false), MacroFlag::get_bool) + .unwrap_or_abort(); let no_declare_concrete_interface = args .flags .iter() - .find(|flag| flag.flag.to_string().as_str() == "no_declare_concrete_interface") - .map_or(false, |flag| flag.is_on.value); + .find(|flag| flag.name() == "no_declare_concrete_interface") + .map_or(Ok(false), MacroFlag::get_bool) + .unwrap_or_abort(); + + let constructor = args + .flags + .iter() + .find(|flag| flag.name() == "constructor") + .map_or(Ok(format_ident!("new")), MacroFlag::get_ident) + .unwrap_or_abort(); let is_async_flag = args .flags .iter() - .find(|flag| flag.flag.to_string().as_str() == "async") + .find(|flag| flag.name() == "async") .cloned() .unwrap_or_else(|| MacroFlag::new_off("async")); + let is_async = is_async_flag.get_bool().unwrap_or_abort(); + #[cfg(not(feature = "async"))] - if is_async_flag.is_on() { + if is_async { use proc_macro_error::abort; abort!( - is_async_flag.flag.span(), + is_async_flag.name().span(), "The 'async' Cargo feature must be enabled to use this flag"; suggestion = "In your Cargo.toml: syrette = {{ version = \"{}\", features = [\"async\"] }}", PACKAGE_VERSION @@ -172,9 +200,9 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS } let injectable_impl = - InjectableImpl::::parse(input_stream).unwrap_or_abort(); + InjectableImpl::::parse(input_stream, &constructor).unwrap_or_abort(); - set_dummy(if is_async_flag.is_on() { + set_dummy(if is_async { injectable_impl.expand_dummy_async_impl() } else { injectable_impl.expand_dummy_blocking_impl() @@ -182,8 +210,7 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS injectable_impl.validate().unwrap_or_abort(); - let expanded_injectable_impl = - injectable_impl.expand(no_doc_hidden, is_async_flag.is_on()); + let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden, is_async); let self_type = &injectable_impl.self_type; @@ -196,7 +223,7 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS }); let maybe_decl_interface = if let Some(interface) = opt_interface { - let async_flag = if is_async_flag.is_on() { + let async_flag = if is_async { quote! {, async = true} } else { quote! {} @@ -277,13 +304,15 @@ pub fn factory(args_stream: TokenStream, input_stream: TokenStream) -> TokenStre let mut is_threadsafe = flags .iter() - .find(|flag| flag.flag.to_string().as_str() == "threadsafe") - .map_or(false, |flag| flag.is_on.value); + .find(|flag| flag.name() == "threadsafe") + .map_or(Ok(false), MacroFlag::get_bool) + .unwrap_or_abort(); let is_async = flags .iter() - .find(|flag| flag.flag.to_string().as_str() == "async") - .map_or(false, |flag| flag.is_on.value); + .find(|flag| flag.name() == "async") + .map_or(Ok(false), MacroFlag::get_bool) + .unwrap_or_abort(); if is_async { is_threadsafe = true; @@ -377,13 +406,15 @@ pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream let mut is_threadsafe = flags .iter() - .find(|flag| flag.flag.to_string().as_str() == "threadsafe") - .map_or(false, |flag| flag.is_on.value); + .find(|flag| flag.name() == "threadsafe") + .map_or(Ok(false), MacroFlag::get_bool) + .unwrap_or_abort(); let is_async = flags .iter() - .find(|flag| flag.flag.to_string().as_str() == "async") - .map_or(false, |flag| flag.is_on.value); + .find(|flag| flag.name() == "async") + .map_or(Ok(false), MacroFlag::get_bool) + .unwrap_or_abort(); if is_async { is_threadsafe = true; @@ -446,12 +477,11 @@ pub fn declare_interface(input: TokenStream) -> TokenStream flags, } = parse(input).unwrap_or_abort(); - let opt_async_flag = flags - .iter() - .find(|flag| flag.flag.to_string().as_str() == "async"); + let opt_async_flag = flags.iter().find(|flag| flag.name() == "async"); - let is_async = - opt_async_flag.map_or_else(|| false, |async_flag| async_flag.is_on.value); + let is_async = opt_async_flag + .map_or_else(|| Ok(false), MacroFlag::get_bool) + .unwrap_or_abort(); let interface_type = if interface == implementation { Type::Path(interface) diff --git a/macros/src/macro_flag.rs b/macros/src/macro_flag.rs index f0e3a70..ba71cc2 100644 --- a/macros/src/macro_flag.rs +++ b/macros/src/macro_flag.rs @@ -2,13 +2,15 @@ use std::hash::Hash; use proc_macro2::Span; use syn::parse::{Parse, ParseStream}; -use syn::{Ident, LitBool, Token}; +use syn::{Ident, Lit, LitBool, Token}; -#[derive(Debug, Eq, Clone)] +use crate::util::error::diagnostic_error_enum; + +#[derive(Debug, Clone)] pub struct MacroFlag { - pub flag: Ident, - pub is_on: LitBool, + pub name: Ident, + pub value: MacroFlagValue, } impl MacroFlag @@ -16,14 +18,41 @@ impl MacroFlag pub fn new_off(flag: &str) -> Self { Self { - flag: Ident::new(flag, Span::call_site()), - is_on: LitBool::new(false, Span::call_site()), + name: Ident::new(flag, Span::call_site()), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + false, + Span::call_site(), + ))), } } - pub fn is_on(&self) -> bool + pub fn name(&self) -> &Ident { - self.is_on.value + &self.name + } + + pub fn get_bool(&self) -> Result + { + if let MacroFlagValue::Literal(Lit::Bool(lit_bool)) = &self.value { + return Ok(lit_bool.value); + } + + Err(MacroFlagError::UnexpectedValueKind { + expected: "boolean literal", + value_span: self.value.span(), + }) + } + + pub fn get_ident(&self) -> Result + { + if let MacroFlagValue::Identifier(ident) = &self.value { + return Ok(ident.clone()); + } + + Err(MacroFlagError::UnexpectedValueKind { + expected: "identifier", + value_span: self.value.span(), + }) } } @@ -31,13 +60,13 @@ impl Parse for MacroFlag { fn parse(input: ParseStream) -> syn::Result { - let flag = input.parse::()?; + let name = input.parse::()?; input.parse::()?; - let is_on: LitBool = input.parse()?; + let value: MacroFlagValue = input.parse()?; - Ok(Self { flag, is_on }) + Ok(Self { name, value }) } } @@ -45,15 +74,59 @@ impl PartialEq for MacroFlag { fn eq(&self, other: &Self) -> bool { - self.flag == other.flag + self.name == other.name } } +impl Eq for MacroFlag {} + impl Hash for MacroFlag { fn hash(&self, state: &mut H) { - self.flag.hash(state); + self.name.hash(state); + } +} + +diagnostic_error_enum! { +pub enum MacroFlagError { + #[error("Expected a {expected}"), span = value_span] + UnexpectedValueKind { + expected: &'static str, + value_span: Span + }, +} +} + +#[derive(Debug, Clone)] +pub enum MacroFlagValue +{ + Literal(Lit), + Identifier(Ident), +} + +impl MacroFlagValue +{ + fn span(&self) -> Span + { + match self { + Self::Literal(lit) => lit.span(), + Self::Identifier(ident) => ident.span(), + } + } +} + +impl Parse for MacroFlagValue +{ + fn parse(input: ParseStream) -> syn::Result + { + if let Ok(lit) = input.parse::() { + return Ok(Self::Literal(lit)); + }; + + input.parse::().map(Self::Identifier).map_err(|err| { + syn::Error::new(err.span(), "Expected a literal or a identifier") + }) } } @@ -76,8 +149,11 @@ mod tests more = true })?, MacroFlag { - flag: format_ident!("more"), - is_on: LitBool::new(true, Span::call_site()) + name: format_ident!("more"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + true, + Span::call_site() + ))) } ); @@ -86,8 +162,11 @@ mod tests do_something = false })?, MacroFlag { - flag: format_ident!("do_something"), - is_on: LitBool::new(false, Span::call_site()) + name: format_ident!("do_something"), + value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( + false, + Span::call_site() + ))) } ); diff --git a/macros/src/util/item_impl.rs b/macros/src/util/item_impl.rs index 4bd7492..621f6be 100644 --- a/macros/src/util/item_impl.rs +++ b/macros/src/util/item_impl.rs @@ -1,15 +1,16 @@ +use proc_macro2::Ident; use syn::{ImplItem, ImplItemMethod, ItemImpl}; pub fn find_impl_method_by_name_mut<'item_impl>( item_impl: &'item_impl mut ItemImpl, - method_name: &'static str, + method_name: &Ident, ) -> Option<&'item_impl mut ImplItemMethod> { let impl_items = &mut item_impl.items; impl_items.iter_mut().find_map(|impl_item| match impl_item { ImplItem::Method(method_item) => { - if method_item.sig.ident == method_name { + if &method_item.sig.ident == method_name { Some(method_item) } else { None -- cgit v1.2.3-18-g5258