From 0b914f415cb04c45d8655cae3828af264887d203 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sun, 18 Sep 2022 16:39:27 +0200 Subject: feat: add factory macro async flag --- macros/src/factory_macro_args.rs | 2 +- macros/src/lib.rs | 25 ++++++++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) (limited to 'macros') diff --git a/macros/src/factory_macro_args.rs b/macros/src/factory_macro_args.rs index 0cf1d66..dd80c1c 100644 --- a/macros/src/factory_macro_args.rs +++ b/macros/src/factory_macro_args.rs @@ -5,7 +5,7 @@ use syn::Token; use crate::macro_flag::MacroFlag; use crate::util::iterator_ext::IteratorExt; -pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe"]; +pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe", "async"]; pub struct FactoryMacroArgs { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 27577c7..2715f3d 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -161,6 +161,8 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt /// /// # Flags /// - `threadsafe` - Mark as threadsafe. +/// - `async` - Mark as async. Infers the `threadsafe` flag. The return type is +/// automatically put inside of a pinned boxed future. /// /// # Panics /// If the attributed item is not a type alias. @@ -197,11 +199,20 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke let FactoryMacroArgs { flags } = parse(args_stream).unwrap(); - let is_threadsafe = flags + let mut is_threadsafe = flags .iter() .find(|flag| flag.flag.to_string().as_str() == "threadsafe") .map_or(false, |flag| flag.is_on.value); + let is_async = flags + .iter() + .find(|flag| flag.flag.to_string().as_str() == "async") + .map_or(false, |flag| flag.is_on.value); + + if is_async { + is_threadsafe = true; + } + let factory_type_alias::FactoryTypeAlias { mut type_alias, mut factory_interface, @@ -212,8 +223,16 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke let output = factory_interface.output.clone(); factory_interface.output = parse( - quote! { - syrette::ptr::TransientPtr<#output> + if is_async { + quote! { + std::pin::Pin> + >> + } + } else { + quote! { + syrette::ptr::TransientPtr<#output> + } } .into(), ) -- cgit v1.2.3-18-g5258