diff options
| author | HampusM <hampus@hampusmat.com> | 2023-08-03 15:09:46 +0200 | 
|---|---|---|
| committer | HampusM <hampus@hampusmat.com> | 2023-08-03 15:15:29 +0200 | 
| commit | a3ccc2713bb5315123814cadd6c50275eee38e1c (patch) | |
| tree | 22d50e174f3181bbddfd50840408e85e5a4adaac /macros | |
| parent | 14f1fc1837675e1771e220f848b46213462ae804 (diff) | |
feat: add constructor name flag to injectable macro
Diffstat (limited to 'macros')
| -rw-r--r-- | macros/src/declare_interface_args.rs | 14 | ||||
| -rw-r--r-- | macros/src/factory/declare_default_args.rs | 32 | ||||
| -rw-r--r-- | macros/src/factory/macro_args.rs | 32 | ||||
| -rw-r--r-- | macros/src/injectable/dependency.rs | 22 | ||||
| -rw-r--r-- | macros/src/injectable/implementation.rs | 120 | ||||
| -rw-r--r-- | macros/src/injectable/macro_args.rs | 47 | ||||
| -rw-r--r-- | macros/src/lib.rs | 92 | ||||
| -rw-r--r-- | macros/src/macro_flag.rs | 113 | ||||
| -rw-r--r-- | macros/src/util/item_impl.rs | 5 | 
9 files changed, 322 insertions, 155 deletions
| diff --git a/macros/src/declare_interface_args.rs b/macros/src/declare_interface_args.rs index cf0dbce..79004da 100644 --- a/macros/src/declare_interface_args.rs +++ b/macros/src/declare_interface_args.rs @@ -30,7 +30,7 @@ impl Parse for DeclareInterfaceArgs              let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;              for flag in &flags { -                let flag_name = flag.flag.to_string(); +                let flag_name = flag.name().to_string();                  if !DECLARE_INTERFACE_FLAGS.contains(&flag_name.as_str()) {                      return Err(input.error(format!( @@ -41,7 +41,7 @@ impl Parse for DeclareInterfaceArgs              }              if let Some((dupe_flag, _)) = flags.iter().find_duplicate() { -                return Err(input.error(format!("Duplicate flag '{}'", dupe_flag.flag))); +                return Err(input.error(format!("Duplicate flag '{}'", dupe_flag.name())));              }              flags @@ -64,9 +64,10 @@ mod tests      use proc_macro2::Span;      use quote::{format_ident, quote}; -    use syn::{parse2, LitBool}; +    use syn::{parse2, Lit, LitBool};      use super::*; +    use crate::macro_flag::MacroFlagValue;      use crate::test_utils;      #[test] @@ -139,8 +140,11 @@ mod tests          assert_eq!(              decl_interface_args.flags,              Punctuated::from_iter(vec![MacroFlag { -                flag: format_ident!("async"), -                is_on: LitBool::new(true, Span::call_site()) +                name: format_ident!("async"), +                value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                    true, +                    Span::call_site() +                )))              }])          ); diff --git a/macros/src/factory/declare_default_args.rs b/macros/src/factory/declare_default_args.rs index 269ef8f..f93d29d 100644 --- a/macros/src/factory/declare_default_args.rs +++ b/macros/src/factory/declare_default_args.rs @@ -31,12 +31,11 @@ impl Parse for DeclareDefaultFactoryMacroArgs          let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;          for flag in &flags { -            let flag_str = flag.flag.to_string(); +            let name = flag.name().to_string(); -            if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { +            if !FACTORY_MACRO_FLAGS.contains(&name.as_str()) {                  return Err(input.error(format!( -                    "Unknown flag '{}'. Expected one of [ {} ]", -                    flag_str, +                    "Unknown flag '{name}'. Expected one of [ {} ]",                      FACTORY_MACRO_FLAGS.join(",")                  )));              } @@ -44,7 +43,7 @@ impl Parse for DeclareDefaultFactoryMacroArgs          let flag_names = flags              .iter() -            .map(|flag| flag.flag.to_string()) +            .map(|flag| flag.name().to_string())              .collect::<Vec<_>>();          if let Some((dupe_flag_name, _)) = flag_names.iter().find_duplicate() { @@ -65,6 +64,7 @@ mod tests      use syn::token::Dyn;      use syn::{          parse2, +        Lit,          LitBool,          Path,          PathArguments, @@ -77,6 +77,7 @@ mod tests      };      use super::*; +    use crate::macro_flag::MacroFlagValue;      #[test]      fn can_parse_with_interface_only() -> Result<(), Box<dyn Error>> @@ -142,8 +143,11 @@ mod tests          assert_eq!(              dec_def_fac_args.flags,              Punctuated::from_iter(vec![MacroFlag { -                flag: format_ident!("threadsafe"), -                is_on: LitBool::new(true, Span::call_site()) +                name: format_ident!("threadsafe"), +                value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                    true, +                    Span::call_site() +                )))              }])          ); @@ -182,12 +186,18 @@ mod tests              dec_def_fac_args.flags,              Punctuated::from_iter(vec![                  MacroFlag { -                    flag: format_ident!("threadsafe"), -                    is_on: LitBool::new(true, Span::call_site()) +                    name: format_ident!("threadsafe"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        true, +                        Span::call_site() +                    )))                  },                  MacroFlag { -                    flag: format_ident!("async"), -                    is_on: LitBool::new(false, Span::call_site()) +                    name: format_ident!("async"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        false, +                        Span::call_site() +                    )))                  }              ])          ); diff --git a/macros/src/factory/macro_args.rs b/macros/src/factory/macro_args.rs index 8acbdb6..cb2cbc9 100644 --- a/macros/src/factory/macro_args.rs +++ b/macros/src/factory/macro_args.rs @@ -19,12 +19,12 @@ impl Parse for FactoryMacroArgs          let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;          for flag in &flags { -            let flag_str = flag.flag.to_string(); +            let name = flag.name().to_string(); -            if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) { +            if !FACTORY_MACRO_FLAGS.contains(&name.as_str()) {                  return Err(input.error(format!(                      "Unknown flag '{}'. Expected one of [ {} ]", -                    flag_str, +                    name,                      FACTORY_MACRO_FLAGS.join(",")                  )));              } @@ -32,7 +32,7 @@ impl Parse for FactoryMacroArgs          let flag_names = flags              .iter() -            .map(|flag| flag.flag.to_string()) +            .map(|flag| flag.name().to_string())              .collect::<Vec<_>>();          if let Some((dupe_flag_name, _)) = flag_names.iter().find_duplicate() { @@ -50,9 +50,10 @@ mod tests      use proc_macro2::Span;      use quote::{format_ident, quote}; -    use syn::{parse2, LitBool}; +    use syn::{parse2, Lit, LitBool};      use super::*; +    use crate::macro_flag::MacroFlagValue;      #[test]      fn can_parse_with_single_flag() -> Result<(), Box<dyn Error>> @@ -66,8 +67,11 @@ mod tests          assert_eq!(              factory_macro_args.flags,              Punctuated::from_iter(vec![MacroFlag { -                flag: format_ident!("async"), -                is_on: LitBool::new(true, Span::call_site()) +                name: format_ident!("async"), +                value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                    true, +                    Span::call_site() +                )))              }])          ); @@ -87,12 +91,18 @@ mod tests              factory_macro_args.flags,              Punctuated::from_iter(vec![                  MacroFlag { -                    flag: format_ident!("async"), -                    is_on: LitBool::new(true, Span::call_site()) +                    name: format_ident!("async"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        true, +                        Span::call_site() +                    )))                  },                  MacroFlag { -                    flag: format_ident!("threadsafe"), -                    is_on: LitBool::new(false, Span::call_site()) +                    name: format_ident!("threadsafe"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        false, +                        Span::call_site() +                    )))                  }              ])          ); diff --git a/macros/src/injectable/dependency.rs b/macros/src/injectable/dependency.rs index 33c4583..85cad58 100644 --- a/macros/src/injectable/dependency.rs +++ b/macros/src/injectable/dependency.rs @@ -6,14 +6,12 @@ use crate::injectable::named_attr_input::NamedAttrInput;  use crate::util::error::diagnostic_error_enum;  use crate::util::syn_path::SynPathExt; -/// Interface for a representation of a dependency of a injectable type. -/// -/// Found as a argument in the 'new' method of the type. +/// Interface for a dependency of a `Injectable`.  #[cfg_attr(test, mockall::automock)]  pub trait IDependency: Sized  { -    /// Build a new `Dependency` from a argument in a 'new' method. -    fn build(new_method_arg: &FnArg) -> Result<Self, DependencyError>; +    /// Build a new `Dependency` from a argument in a constructor method. +    fn build(ctor_method_arg: &FnArg) -> Result<Self, DependencyError>;      /// Returns the interface type.      fn get_interface(&self) -> &Type; @@ -27,7 +25,7 @@ pub trait IDependency: Sized  /// Representation of a dependency of a injectable type.  /// -/// Found as a argument in the 'new' method of the type. +/// Found as a argument in the constructor method of a `Injectable`.  #[derive(Debug, PartialEq, Eq)]  pub struct Dependency  { @@ -38,16 +36,16 @@ pub struct Dependency  impl IDependency for Dependency  { -    fn build(new_method_arg: &FnArg) -> Result<Self, DependencyError> +    fn build(ctor_method_arg: &FnArg) -> Result<Self, DependencyError>      { -        let typed_new_method_arg = match new_method_arg { +        let typed_ctor_method_arg = match ctor_method_arg {              FnArg::Typed(typed_arg) => Ok(typed_arg),              FnArg::Receiver(receiver_arg) => Err(DependencyError::UnexpectedSelf {                  self_token_span: receiver_arg.self_token.span,              }),          }?; -        let dependency_type_path = match typed_new_method_arg.ty.as_ref() { +        let dependency_type_path = match typed_ctor_method_arg.ty.as_ref() {              Type::Path(arg_type_path) => Ok(arg_type_path),              Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() {                  Type::Path(arg_type_path) => Ok(arg_type_path), @@ -63,7 +61,7 @@ impl IDependency for Dependency          let ptr_path_segment = dependency_type_path.path.segments.last().map_or_else(              || {                  Err(DependencyError::MissingType { -                    arg_span: typed_new_method_arg.span(), +                    arg_span: typed_ctor_method_arg.span(),                  })              },              Ok, @@ -88,7 +86,7 @@ impl IDependency for Dependency                  })              }?; -        let arg_attrs = &typed_new_method_arg.attrs; +        let arg_attrs = &typed_ctor_method_arg.attrs;          let opt_named_attr = arg_attrs.iter().find(|attr| {              attr.path.get_ident().map_or_else( @@ -103,7 +101,7 @@ impl IDependency for Dependency              if let Some(named_attr_tokens) = opt_named_attr_tokens {                  Some(parse2::<NamedAttrInput>(named_attr_tokens.clone()).map_err(                      |err| DependencyError::InvalidNamedAttrInput { -                        arg_span: typed_new_method_arg.span(), +                        arg_span: typed_ctor_method_arg.span(),                          err,                      },                  )?) diff --git a/macros/src/injectable/implementation.rs b/macros/src/injectable/implementation.rs index 3d73cd0..0ea623c 100644 --- a/macros/src/injectable/implementation.rs +++ b/macros/src/injectable/implementation.rs @@ -31,13 +31,16 @@ pub struct InjectableImpl<Dep: IDependency>      pub generics: Generics,      pub original_impl: ItemImpl, -    new_method: ImplItemMethod, +    constructor_method: ImplItemMethod,  }  impl<Dep: IDependency> InjectableImpl<Dep>  {      #[cfg(not(tarpaulin_include))] -    pub fn parse(input: TokenStream) -> Result<Self, InjectableImplError> +    pub fn parse( +        input: TokenStream, +        constructor: &Ident, +    ) -> Result<Self, InjectableImplError>      {          let mut item_impl = parse2::<ItemImpl>(input).map_err(|err| {              InjectableImplError::NotAImplementation { @@ -53,79 +56,88 @@ impl<Dep: IDependency> InjectableImpl<Dep>          let item_impl_span = item_impl.self_ty.span(); -        let new_method = find_impl_method_by_name_mut(&mut item_impl, "new").ok_or( -            InjectableImplError::MissingNewMethod { -                implementation_span: item_impl_span, -            }, -        )?; +        let constructor_method = +            find_impl_method_by_name_mut(&mut item_impl, constructor).ok_or( +                InjectableImplError::MissingConstructorMethod { +                    constructor: constructor.clone(), +                    implementation_span: item_impl_span, +                }, +            )?; -        let dependencies = Self::build_dependencies(new_method).map_err(|err| { -            InjectableImplError::ContainsAInvalidDependency { -                implementation_span: item_impl_span, -                err, -            } -        })?; +        let dependencies = +            Self::build_dependencies(constructor_method).map_err(|err| { +                InjectableImplError::ContainsAInvalidDependency { +                    implementation_span: item_impl_span, +                    err, +                } +            })?; -        Self::remove_method_argument_attrs(new_method); +        Self::remove_method_argument_attrs(constructor_method); -        let new_method = new_method.clone(); +        let constructor_method = constructor_method.clone();          Ok(Self {              dependencies,              self_type: item_impl.self_ty.as_ref().clone(),              generics: item_impl.generics.clone(),              original_impl: item_impl, -            new_method, +            constructor_method,          })      }      pub fn validate(&self) -> Result<(), InjectableImplError>      { -        if matches!(self.new_method.sig.output, ReturnType::Default) { -            return Err(InjectableImplError::InvalidNewMethodReturnType { -                new_method_output_span: self.new_method.sig.output.span(), +        if matches!(self.constructor_method.sig.output, ReturnType::Default) { +            return Err(InjectableImplError::InvalidConstructorMethodReturnType { +                ctor_method_output_span: self.constructor_method.sig.output.span(),                  expected: "Self".to_string(),                  found: "()".to_string(),              });          } -        if let ReturnType::Type(_, ret_type) = &self.new_method.sig.output { +        if let ReturnType::Type(_, ret_type) = &self.constructor_method.sig.output {              if let Type::Path(path_type) = ret_type.as_ref() {                  if path_type                      .path                      .get_ident()                      .map_or_else(|| true, |ident| *ident != "Self")                  { -                    return Err(InjectableImplError::InvalidNewMethodReturnType { -                        new_method_output_span: self.new_method.sig.output.span(), -                        expected: "Self".to_string(), -                        found: ret_type.to_token_stream().to_string(), -                    }); +                    return Err( +                        InjectableImplError::InvalidConstructorMethodReturnType { +                            ctor_method_output_span: self +                                .constructor_method +                                .sig +                                .output +                                .span(), +                            expected: "Self".to_string(), +                            found: ret_type.to_token_stream().to_string(), +                        }, +                    );                  }              } else { -                return Err(InjectableImplError::InvalidNewMethodReturnType { -                    new_method_output_span: self.new_method.sig.output.span(), +                return Err(InjectableImplError::InvalidConstructorMethodReturnType { +                    ctor_method_output_span: self.constructor_method.sig.output.span(),                      expected: "Self".to_string(),                      found: ret_type.to_token_stream().to_string(),                  });              }          } -        if let Some(unsafety) = self.new_method.sig.unsafety { -            return Err(InjectableImplError::NewMethodUnsafe { +        if let Some(unsafety) = self.constructor_method.sig.unsafety { +            return Err(InjectableImplError::ConstructorMethodUnsafe {                  unsafety_span: unsafety.span,              });          } -        if let Some(asyncness) = self.new_method.sig.asyncness { -            return Err(InjectableImplError::NewMethodAsync { +        if let Some(asyncness) = self.constructor_method.sig.asyncness { +            return Err(InjectableImplError::ConstructorMethodAsync {                  asyncness_span: asyncness.span,              });          } -        if !self.new_method.sig.generics.params.is_empty() { -            return Err(InjectableImplError::NewMethodGeneric { -                generics_span: self.new_method.sig.generics.span(), +        if !self.constructor_method.sig.generics.params.is_empty() { +            return Err(InjectableImplError::ConstructorMethodGeneric { +                generics_span: self.constructor_method.sig.generics.span(),              });          }          Ok(()) @@ -209,6 +221,7 @@ impl<Dep: IDependency> InjectableImpl<Dep>      {          let generics = &self.generics;          let self_type = &self.self_type; +        let constructor = &self.constructor_method.sig.ident;          quote! {              #maybe_doc_hidden @@ -243,7 +256,7 @@ impl<Dep: IDependency> InjectableImpl<Dep>                          #maybe_prevent_circular_deps -                        Ok(syrette::ptr::TransientPtr::new(Self::new( +                        Ok(syrette::ptr::TransientPtr::new(Self::#constructor(                              #(#get_dep_method_calls),*                          )))                      }) @@ -264,6 +277,7 @@ impl<Dep: IDependency> InjectableImpl<Dep>      {          let generics = &self.generics;          let self_type = &self.self_type; +        let constructor = &self.constructor_method.sig.ident;          quote! {              #maybe_doc_hidden @@ -290,7 +304,7 @@ impl<Dep: IDependency> InjectableImpl<Dep>                      #maybe_prevent_circular_deps -                    return Ok(syrette::ptr::TransientPtr::new(Self::new( +                    return Ok(syrette::ptr::TransientPtr::new(Self::#constructor(                          #(#get_dep_method_calls),*                      )));                  } @@ -444,13 +458,13 @@ impl<Dep: IDependency> InjectableImpl<Dep>      }      fn build_dependencies( -        new_method: &ImplItemMethod, +        ctor_method: &ImplItemMethod,      ) -> Result<Vec<Dep>, DependencyError>      { -        let new_method_args = &new_method.sig.inputs; +        let ctor_method_args = &ctor_method.sig.inputs;          let dependencies_result: Result<Vec<_>, _> = -            new_method_args.iter().map(Dep::build).collect(); +            ctor_method_args.iter().map(Dep::build).collect();          let deps = dependencies_result?; @@ -515,41 +529,45 @@ pub enum InjectableImplError          trait_path_span: Span      }, -    #[error("Missing a 'new' method"), span = implementation_span] +    #[ +        error("No constructor method '{constructor}' found in impl"), +        span = implementation_span +    ]      #[note("Required by the 'injectable' attribute macro")] -    MissingNewMethod { +    MissingConstructorMethod { +        constructor: Ident,          implementation_span: Span      },      #[          error(concat!( -            "Invalid 'new' method return type. Expected it to be '{}'. ", +            "Invalid constructor method return type. Expected it to be '{}'. ",              "Found '{}'"          ), expected, found), -        span = new_method_output_span +        span = ctor_method_output_span      ] -    InvalidNewMethodReturnType +    InvalidConstructorMethodReturnType      { -        new_method_output_span: Span, +        ctor_method_output_span: Span,          expected: String,          found: String      }, -    #[error("'new' method is not allowed to be unsafe"), span = unsafety_span] +    #[error("Constructor method is not allowed to be unsafe"), span = unsafety_span]      #[note("Required by the 'injectable' attribute macro")] -    NewMethodUnsafe { +    ConstructorMethodUnsafe {          unsafety_span: Span      }, -    #[error("'new' method is not allowed to be async"), span = asyncness_span] +    #[error("Constructor method is not allowed to be async"), span = asyncness_span]      #[note("Required by the 'injectable' attribute macro")] -    NewMethodAsync { +    ConstructorMethodAsync {          asyncness_span: Span      }, -    #[error("'new' method is not allowed to have generics"), span = generics_span] +    #[error("Constructor method is not allowed to have generics"), span = generics_span]      #[note("Required by the 'injectable' attribute macro")] -    NewMethodGeneric { +    ConstructorMethodGeneric {          generics_span: Span      }, diff --git a/macros/src/injectable/macro_args.rs b/macros/src/injectable/macro_args.rs index 6964352..ee398fc 100644 --- a/macros/src/injectable/macro_args.rs +++ b/macros/src/injectable/macro_args.rs @@ -7,8 +7,12 @@ use crate::macro_flag::MacroFlag;  use crate::util::error::diagnostic_error_enum;  use crate::util::iterator_ext::IteratorExt; -pub const INJECTABLE_MACRO_FLAGS: &[&str] = -    &["no_doc_hidden", "async", "no_declare_concrete_interface"]; +pub const INJECTABLE_MACRO_FLAGS: &[&str] = &[ +    "no_doc_hidden", +    "async", +    "no_declare_concrete_interface", +    "constructor", +];  pub struct InjectableMacroArgs  { @@ -21,9 +25,9 @@ impl InjectableMacroArgs      pub fn check_flags(&self) -> Result<(), InjectableMacroArgsError>      {          for flag in &self.flags { -            if !INJECTABLE_MACRO_FLAGS.contains(&flag.flag.to_string().as_str()) { +            if !INJECTABLE_MACRO_FLAGS.contains(&flag.name().to_string().as_str()) {                  return Err(InjectableMacroArgsError::UnknownFlag { -                    flag_ident: flag.flag.clone(), +                    flag_ident: flag.name().clone(),                  });              }          } @@ -32,8 +36,8 @@ impl InjectableMacroArgs              self.flags.iter().find_duplicate()          {              return Err(InjectableMacroArgsError::DuplicateFlag { -                first_flag_ident: dupe_flag_first.flag.clone(), -                last_flag_span: dupe_flag_second.flag.span(), +                first_flag_ident: dupe_flag_first.name().clone(), +                last_flag_span: dupe_flag_second.name().span(),              });          } @@ -111,9 +115,10 @@ mod tests      use proc_macro2::Span;      use quote::{format_ident, quote}; -    use syn::{parse2, LitBool}; +    use syn::{parse2, Lit, LitBool};      use super::*; +    use crate::macro_flag::MacroFlagValue;      use crate::test_utils;      #[test] @@ -174,12 +179,18 @@ mod tests              injectable_macro_args.flags,              Punctuated::from_iter([                  MacroFlag { -                    flag: format_ident!("no_doc_hidden"), -                    is_on: LitBool::new(true, Span::call_site()) +                    name: format_ident!("no_doc_hidden"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        true, +                        Span::call_site() +                    )))                  },                  MacroFlag { -                    flag: format_ident!("async"), -                    is_on: LitBool::new(false, Span::call_site()) +                    name: format_ident!("async"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        false, +                        Span::call_site() +                    )))                  }              ])          ); @@ -202,12 +213,18 @@ mod tests              injectable_macro_args.flags,              Punctuated::from_iter([                  MacroFlag { -                    flag: format_ident!("async"), -                    is_on: LitBool::new(false, Span::call_site()) +                    name: format_ident!("async"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        false, +                        Span::call_site() +                    )))                  },                  MacroFlag { -                    flag: format_ident!("no_declare_concrete_interface"), -                    is_on: LitBool::new(false, Span::call_site()) +                    name: format_ident!("no_declare_concrete_interface"), +                    value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                        false, +                        Span::call_site() +                    )))                  }              ])          ); diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 04720c1..5f5b0b6 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -44,14 +44,29 @@ const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION");  ///  /// # Arguments  /// * (Optional) A interface trait the struct implements. -/// * (Zero or more) Flags. Like `a = true, b = false` +/// * (Zero or more) Comma separated flags. Each flag being formatted `name=value`.  ///  /// # Flags -/// - `no_doc_hidden` - Don't hide the impl of the [`Injectable`] trait from -///   documentation. -/// - `no_declare_concrete_interface` - Disable declaring the concrete type as the -///   interface when no interface trait argument is given. -/// - `async` - Mark as async. +/// #### `no_doc_hidden` +/// **Value:** boolean literal<br> +/// **Default:** `false`<br> +/// Don't hide the impl of the [`Injectable`] trait from documentation. +/// +/// #### `no_declare_concrete_interface` +/// **Value:** boolean literal<br> +/// **Default:** `false`<br> +/// Disable declaring the concrete type as the interface when no interface trait argument +/// is given. +/// +/// #### `async` +/// **Value:** boolean literal<br> +/// **Default:** `false`<br> +/// Mark as async. +/// +/// #### `constructor` +/// **Value:** identifier<br> +/// **Default:** `new`<br> +/// Constructor method name.  ///  /// # Panics  /// If the attributed item is not a impl. @@ -132,6 +147,8 @@ const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION");  #[proc_macro_attribute]  pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenStream  { +    use quote::format_ident; +      let input_stream: proc_macro2::TokenStream = input_stream.into();      set_dummy(input_stream.clone()); @@ -143,28 +160,39 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS      let no_doc_hidden = args          .flags          .iter() -        .find(|flag| flag.flag.to_string().as_str() == "no_doc_hidden") -        .map_or(false, |flag| flag.is_on.value); +        .find(|flag| flag.name() == "no_doc_hidden") +        .map_or(Ok(false), MacroFlag::get_bool) +        .unwrap_or_abort();      let no_declare_concrete_interface = args          .flags          .iter() -        .find(|flag| flag.flag.to_string().as_str() == "no_declare_concrete_interface") -        .map_or(false, |flag| flag.is_on.value); +        .find(|flag| flag.name() == "no_declare_concrete_interface") +        .map_or(Ok(false), MacroFlag::get_bool) +        .unwrap_or_abort(); + +    let constructor = args +        .flags +        .iter() +        .find(|flag| flag.name() == "constructor") +        .map_or(Ok(format_ident!("new")), MacroFlag::get_ident) +        .unwrap_or_abort();      let is_async_flag = args          .flags          .iter() -        .find(|flag| flag.flag.to_string().as_str() == "async") +        .find(|flag| flag.name() == "async")          .cloned()          .unwrap_or_else(|| MacroFlag::new_off("async")); +    let is_async = is_async_flag.get_bool().unwrap_or_abort(); +      #[cfg(not(feature = "async"))] -    if is_async_flag.is_on() { +    if is_async {          use proc_macro_error::abort;          abort!( -            is_async_flag.flag.span(), +            is_async_flag.name().span(),              "The 'async' Cargo feature must be enabled to use this flag";              suggestion = "In your Cargo.toml: syrette = {{ version = \"{}\", features = [\"async\"] }}",              PACKAGE_VERSION @@ -172,9 +200,9 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS      }      let injectable_impl = -        InjectableImpl::<Dependency>::parse(input_stream).unwrap_or_abort(); +        InjectableImpl::<Dependency>::parse(input_stream, &constructor).unwrap_or_abort(); -    set_dummy(if is_async_flag.is_on() { +    set_dummy(if is_async {          injectable_impl.expand_dummy_async_impl()      } else {          injectable_impl.expand_dummy_blocking_impl() @@ -182,8 +210,7 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS      injectable_impl.validate().unwrap_or_abort(); -    let expanded_injectable_impl = -        injectable_impl.expand(no_doc_hidden, is_async_flag.is_on()); +    let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden, is_async);      let self_type = &injectable_impl.self_type; @@ -196,7 +223,7 @@ pub fn injectable(args_stream: TokenStream, input_stream: TokenStream) -> TokenS      });      let maybe_decl_interface = if let Some(interface) = opt_interface { -        let async_flag = if is_async_flag.is_on() { +        let async_flag = if is_async {              quote! {, async = true}          } else {              quote! {} @@ -277,13 +304,15 @@ pub fn factory(args_stream: TokenStream, input_stream: TokenStream) -> TokenStre      let mut is_threadsafe = flags          .iter() -        .find(|flag| flag.flag.to_string().as_str() == "threadsafe") -        .map_or(false, |flag| flag.is_on.value); +        .find(|flag| flag.name() == "threadsafe") +        .map_or(Ok(false), MacroFlag::get_bool) +        .unwrap_or_abort();      let is_async = flags          .iter() -        .find(|flag| flag.flag.to_string().as_str() == "async") -        .map_or(false, |flag| flag.is_on.value); +        .find(|flag| flag.name() == "async") +        .map_or(Ok(false), MacroFlag::get_bool) +        .unwrap_or_abort();      if is_async {          is_threadsafe = true; @@ -377,13 +406,15 @@ pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream      let mut is_threadsafe = flags          .iter() -        .find(|flag| flag.flag.to_string().as_str() == "threadsafe") -        .map_or(false, |flag| flag.is_on.value); +        .find(|flag| flag.name() == "threadsafe") +        .map_or(Ok(false), MacroFlag::get_bool) +        .unwrap_or_abort();      let is_async = flags          .iter() -        .find(|flag| flag.flag.to_string().as_str() == "async") -        .map_or(false, |flag| flag.is_on.value); +        .find(|flag| flag.name() == "async") +        .map_or(Ok(false), MacroFlag::get_bool) +        .unwrap_or_abort();      if is_async {          is_threadsafe = true; @@ -446,12 +477,11 @@ pub fn declare_interface(input: TokenStream) -> TokenStream          flags,      } = parse(input).unwrap_or_abort(); -    let opt_async_flag = flags -        .iter() -        .find(|flag| flag.flag.to_string().as_str() == "async"); +    let opt_async_flag = flags.iter().find(|flag| flag.name() == "async"); -    let is_async = -        opt_async_flag.map_or_else(|| false, |async_flag| async_flag.is_on.value); +    let is_async = opt_async_flag +        .map_or_else(|| Ok(false), MacroFlag::get_bool) +        .unwrap_or_abort();      let interface_type = if interface == implementation {          Type::Path(interface) diff --git a/macros/src/macro_flag.rs b/macros/src/macro_flag.rs index f0e3a70..ba71cc2 100644 --- a/macros/src/macro_flag.rs +++ b/macros/src/macro_flag.rs @@ -2,13 +2,15 @@ use std::hash::Hash;  use proc_macro2::Span;  use syn::parse::{Parse, ParseStream}; -use syn::{Ident, LitBool, Token}; +use syn::{Ident, Lit, LitBool, Token}; -#[derive(Debug, Eq, Clone)] +use crate::util::error::diagnostic_error_enum; + +#[derive(Debug, Clone)]  pub struct MacroFlag  { -    pub flag: Ident, -    pub is_on: LitBool, +    pub name: Ident, +    pub value: MacroFlagValue,  }  impl MacroFlag @@ -16,14 +18,41 @@ impl MacroFlag      pub fn new_off(flag: &str) -> Self      {          Self { -            flag: Ident::new(flag, Span::call_site()), -            is_on: LitBool::new(false, Span::call_site()), +            name: Ident::new(flag, Span::call_site()), +            value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                false, +                Span::call_site(), +            ))),          }      } -    pub fn is_on(&self) -> bool +    pub fn name(&self) -> &Ident      { -        self.is_on.value +        &self.name +    } + +    pub fn get_bool(&self) -> Result<bool, MacroFlagError> +    { +        if let MacroFlagValue::Literal(Lit::Bool(lit_bool)) = &self.value { +            return Ok(lit_bool.value); +        } + +        Err(MacroFlagError::UnexpectedValueKind { +            expected: "boolean literal", +            value_span: self.value.span(), +        }) +    } + +    pub fn get_ident(&self) -> Result<Ident, MacroFlagError> +    { +        if let MacroFlagValue::Identifier(ident) = &self.value { +            return Ok(ident.clone()); +        } + +        Err(MacroFlagError::UnexpectedValueKind { +            expected: "identifier", +            value_span: self.value.span(), +        })      }  } @@ -31,13 +60,13 @@ impl Parse for MacroFlag  {      fn parse(input: ParseStream) -> syn::Result<Self>      { -        let flag = input.parse::<Ident>()?; +        let name = input.parse::<Ident>()?;          input.parse::<Token![=]>()?; -        let is_on: LitBool = input.parse()?; +        let value: MacroFlagValue = input.parse()?; -        Ok(Self { flag, is_on }) +        Ok(Self { name, value })      }  } @@ -45,15 +74,59 @@ impl PartialEq for MacroFlag  {      fn eq(&self, other: &Self) -> bool      { -        self.flag == other.flag +        self.name == other.name      }  } +impl Eq for MacroFlag {} +  impl Hash for MacroFlag  {      fn hash<H: std::hash::Hasher>(&self, state: &mut H)      { -        self.flag.hash(state); +        self.name.hash(state); +    } +} + +diagnostic_error_enum! { +pub enum MacroFlagError { +    #[error("Expected a {expected}"), span = value_span] +    UnexpectedValueKind { +        expected: &'static str, +        value_span: Span +    }, +} +} + +#[derive(Debug, Clone)] +pub enum MacroFlagValue +{ +    Literal(Lit), +    Identifier(Ident), +} + +impl MacroFlagValue +{ +    fn span(&self) -> Span +    { +        match self { +            Self::Literal(lit) => lit.span(), +            Self::Identifier(ident) => ident.span(), +        } +    } +} + +impl Parse for MacroFlagValue +{ +    fn parse(input: ParseStream) -> syn::Result<Self> +    { +        if let Ok(lit) = input.parse::<Lit>() { +            return Ok(Self::Literal(lit)); +        }; + +        input.parse::<Ident>().map(Self::Identifier).map_err(|err| { +            syn::Error::new(err.span(), "Expected a literal or a identifier") +        })      }  } @@ -76,8 +149,11 @@ mod tests                  more = true              })?,              MacroFlag { -                flag: format_ident!("more"), -                is_on: LitBool::new(true, Span::call_site()) +                name: format_ident!("more"), +                value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                    true, +                    Span::call_site() +                )))              }          ); @@ -86,8 +162,11 @@ mod tests                  do_something = false              })?,              MacroFlag { -                flag: format_ident!("do_something"), -                is_on: LitBool::new(false, Span::call_site()) +                name: format_ident!("do_something"), +                value: MacroFlagValue::Literal(Lit::Bool(LitBool::new( +                    false, +                    Span::call_site() +                )))              }          ); diff --git a/macros/src/util/item_impl.rs b/macros/src/util/item_impl.rs index 4bd7492..621f6be 100644 --- a/macros/src/util/item_impl.rs +++ b/macros/src/util/item_impl.rs @@ -1,15 +1,16 @@ +use proc_macro2::Ident;  use syn::{ImplItem, ImplItemMethod, ItemImpl};  pub fn find_impl_method_by_name_mut<'item_impl>(      item_impl: &'item_impl mut ItemImpl, -    method_name: &'static str, +    method_name: &Ident,  ) -> Option<&'item_impl mut ImplItemMethod>  {      let impl_items = &mut item_impl.items;      impl_items.iter_mut().find_map(|impl_item| match impl_item {          ImplItem::Method(method_item) => { -            if method_item.sig.ident == method_name { +            if &method_item.sig.ident == method_name {                  Some(method_item)              } else {                  None | 
