aboutsummaryrefslogtreecommitdiff
path: root/macros
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-08-29 20:52:56 +0200
committerHampusM <hampus@hampusmat.com>2022-08-29 21:01:32 +0200
commit080cc42bb1da09059dbc35049a7ded0649961e0c (patch)
tree307ee564124373616022c1ba2b4d5af80845cd92 /macros
parent6e31d8f9e46fece348f329763b39b9c6f2741c07 (diff)
feat: implement async functionality
Diffstat (limited to 'macros')
-rw-r--r--macros/Cargo.toml2
-rw-r--r--macros/src/declare_interface_args.rs43
-rw-r--r--macros/src/factory_macro_args.rs44
-rw-r--r--macros/src/injectable_impl.rs102
-rw-r--r--macros/src/injectable_macro_args.rs55
-rw-r--r--macros/src/lib.rs108
-rw-r--r--macros/src/libs/intertrait_macros/gen_caster.rs26
-rw-r--r--macros/src/macro_flag.rs27
-rw-r--r--macros/src/util/mod.rs1
-rw-r--r--macros/src/util/string.rs12
10 files changed, 330 insertions, 90 deletions
diff --git a/macros/Cargo.toml b/macros/Cargo.toml
index a929b08..28cb4c0 100644
--- a/macros/Cargo.toml
+++ b/macros/Cargo.toml
@@ -23,6 +23,8 @@ syn = { version = "1.0.96", features = ["full"] }
quote = "1.0.18"
proc-macro2 = "1.0.40"
uuid = { version = "0.8", features = ["v4"] }
+regex = "1.6.0"
+once_cell = "1.13.1"
[dev_dependencies]
syrette = { version = "0.3.0", path = "..", features = ["factory"] }
diff --git a/macros/src/declare_interface_args.rs b/macros/src/declare_interface_args.rs
index b54f458..bd2f24e 100644
--- a/macros/src/declare_interface_args.rs
+++ b/macros/src/declare_interface_args.rs
@@ -1,10 +1,17 @@
use syn::parse::{Parse, ParseStream, Result};
+use syn::punctuated::Punctuated;
use syn::{Path, Token, Type};
+use crate::macro_flag::MacroFlag;
+use crate::util::iterator_ext::IteratorExt;
+
+pub const DECLARE_INTERFACE_FLAGS: &[&str] = &["async"];
+
pub struct DeclareInterfaceArgs
{
pub implementation: Type,
pub interface: Path,
+ pub flags: Punctuated<MacroFlag, Token![,]>,
}
impl Parse for DeclareInterfaceArgs
@@ -15,9 +22,43 @@ impl Parse for DeclareInterfaceArgs
input.parse::<Token![->]>()?;
+ let interface: Path = input.parse()?;
+
+ let flags = if input.peek(Token![,]) {
+ input.parse::<Token![,]>()?;
+
+ let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;
+
+ for flag in &flags {
+ let flag_str = flag.flag.to_string();
+
+ if !DECLARE_INTERFACE_FLAGS.contains(&flag_str.as_str()) {
+ return Err(input.error(format!(
+ "Unknown flag '{}'. Expected one of [ {} ]",
+ flag_str,
+ DECLARE_INTERFACE_FLAGS.join(",")
+ )));
+ }
+ }
+
+ let flag_names = flags
+ .iter()
+ .map(|flag| flag.flag.to_string())
+ .collect::<Vec<_>>();
+
+ if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() {
+ return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name)));
+ }
+
+ flags
+ } else {
+ Punctuated::new()
+ };
+
Ok(Self {
implementation,
- interface: input.parse()?,
+ interface,
+ flags,
})
}
}
diff --git a/macros/src/factory_macro_args.rs b/macros/src/factory_macro_args.rs
new file mode 100644
index 0000000..57517d6
--- /dev/null
+++ b/macros/src/factory_macro_args.rs
@@ -0,0 +1,44 @@
+use syn::parse::Parse;
+use syn::punctuated::Punctuated;
+use syn::Token;
+
+use crate::macro_flag::MacroFlag;
+use crate::util::iterator_ext::IteratorExt;
+
+pub const FACTORY_MACRO_FLAGS: &[&str] = &["async"];
+
+pub struct FactoryMacroArgs
+{
+ pub flags: Punctuated<MacroFlag, Token![,]>,
+}
+
+impl Parse for FactoryMacroArgs
+{
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self>
+ {
+ let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;
+
+ for flag in &flags {
+ let flag_str = flag.flag.to_string();
+
+ if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) {
+ return Err(input.error(format!(
+ "Unknown flag '{}'. Expected one of [ {} ]",
+ flag_str,
+ FACTORY_MACRO_FLAGS.join(",")
+ )));
+ }
+ }
+
+ let flag_names = flags
+ .iter()
+ .map(|flag| flag.flag.to_string())
+ .collect::<Vec<_>>();
+
+ if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() {
+ return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name)));
+ }
+
+ Ok(Self { flags })
+ }
+}
diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs
index 990b148..3565ef9 100644
--- a/macros/src/injectable_impl.rs
+++ b/macros/src/injectable_impl.rs
@@ -6,6 +6,7 @@ use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type};
use crate::dependency::Dependency;
use crate::util::item_impl::find_impl_method_by_name_mut;
+use crate::util::string::camelcase_to_snakecase;
use crate::util::syn_path::syn_path_to_string;
const DI_CONTAINER_VAR_NAME: &str = "di_container";
@@ -39,7 +40,8 @@ impl Parse for InjectableImpl
impl InjectableImpl
{
- pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream
+ pub fn expand(&self, no_doc_hidden: bool, is_async: bool)
+ -> proc_macro2::TokenStream
{
let Self {
dependencies,
@@ -51,8 +53,6 @@ impl InjectableImpl
let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME);
let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME);
- let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies);
-
let maybe_doc_hidden = if no_doc_hidden {
quote! {}
} else {
@@ -81,36 +81,78 @@ impl InjectableImpl
quote! {}
};
- quote! {
- #original_impl
+ let injectable_impl = if is_async {
+ let async_get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, true);
+
+ quote! {
+ #maybe_doc_hidden
+ #[syrette::libs::async_trait::async_trait]
+ impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type
+ {
+ async fn resolve(
+ #di_container_var: &syrette::async_di_container::AsyncDIContainer,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
+
+ use syrette::errors::injectable::InjectableError;
+
+ let self_type_name = type_name::<#self_type>();
+
+ #maybe_prevent_circular_deps
+
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#async_get_dep_method_calls),*
+ )));
+ }
+ }
- #maybe_doc_hidden
- impl #generics syrette::interfaces::injectable::Injectable for #self_type {
- fn resolve(
- #di_container_var: &syrette::DIContainer,
- mut #dependency_history_var: Vec<&'static str>,
- ) -> Result<
- syrette::ptr::TransientPtr<Self>,
- syrette::errors::injectable::InjectableError>
+ }
+ } else {
+ let get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, false);
+
+ quote! {
+ #maybe_doc_hidden
+ impl #generics syrette::interfaces::injectable::Injectable for #self_type
{
- use std::any::type_name;
+ fn resolve(
+ #di_container_var: &syrette::DIContainer,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
- use syrette::errors::injectable::InjectableError;
+ use syrette::errors::injectable::InjectableError;
- let self_type_name = type_name::<#self_type>();
+ let self_type_name = type_name::<#self_type>();
- #maybe_prevent_circular_deps
+ #maybe_prevent_circular_deps
- return Ok(syrette::ptr::TransientPtr::new(Self::new(
- #(#get_dep_method_calls),*
- )));
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#get_dep_method_calls),*
+ )));
+ }
}
}
+ };
+
+ quote! {
+ #original_impl
+
+ #injectable_impl
}
}
fn create_get_dep_method_calls(
dependencies: &[Dependency],
+ is_async: bool,
) -> Vec<proc_macro2::TokenStream>
{
dependencies
@@ -146,11 +188,25 @@ impl InjectableImpl
.map(|(method_call, dep_type)| {
let ptr_name = dep_type.ptr.to_string();
- let to_ptr =
- format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase());
+ let to_ptr = format_ident!(
+ "{}",
+ camelcase_to_snakecase(&ptr_name.replace("Ptr", ""))
+ );
+
+ let do_method_call = if is_async {
+ quote! { #method_call.await }
+ } else {
+ quote! { #method_call }
+ };
+
+ let resolve_failed_error = if is_async {
+ quote! { InjectableError::AsyncResolveFailed }
+ } else {
+ quote! { InjectableError::ResolveFailed }
+ };
quote! {
- #method_call.map_err(|err| InjectableError::ResolveFailed {
+ #do_method_call.map_err(|err| #resolve_failed_error {
reason: Box::new(err),
affected: self_type_name
})?.#to_ptr().unwrap()
diff --git a/macros/src/injectable_macro_args.rs b/macros/src/injectable_macro_args.rs
index 43f8e11..6cc1d7e 100644
--- a/macros/src/injectable_macro_args.rs
+++ b/macros/src/injectable_macro_args.rs
@@ -1,49 +1,16 @@
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
-use syn::{braced, Ident, LitBool, Token, TypePath};
+use syn::{braced, Token, TypePath};
+use crate::macro_flag::MacroFlag;
use crate::util::iterator_ext::IteratorExt;
-pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden"];
-
-pub struct InjectableMacroFlag
-{
- pub flag: Ident,
- pub is_on: LitBool,
-}
-
-impl Parse for InjectableMacroFlag
-{
- fn parse(input: ParseStream) -> syn::Result<Self>
- {
- let input_forked = input.fork();
-
- let flag: Ident = input_forked.parse()?;
-
- let flag_str = flag.to_string();
-
- if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) {
- return Err(input.error(format!(
- "Unknown flag '{}'. Expected one of [ {} ]",
- flag_str,
- INJECTABLE_MACRO_FLAGS.join(",")
- )));
- }
-
- input.parse::<Ident>()?;
-
- input.parse::<Token![=]>()?;
-
- let is_on: LitBool = input.parse()?;
-
- Ok(Self { flag, is_on })
- }
-}
+pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"];
pub struct InjectableMacroArgs
{
pub interface: Option<TypePath>,
- pub flags: Punctuated<InjectableMacroFlag, Token![,]>,
+ pub flags: Punctuated<MacroFlag, Token![,]>,
}
impl Parse for InjectableMacroArgs
@@ -76,7 +43,19 @@ impl Parse for InjectableMacroArgs
braced!(braced_content in input);
- let flags = braced_content.parse_terminated(InjectableMacroFlag::parse)?;
+ let flags = braced_content.parse_terminated(MacroFlag::parse)?;
+
+ for flag in &flags {
+ let flag_str = flag.flag.to_string();
+
+ if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) {
+ return Err(input.error(format!(
+ "Unknown flag '{}'. Expected one of [ {} ]",
+ flag_str,
+ INJECTABLE_MACRO_FLAGS.join(",")
+ )));
+ }
+ }
let flag_names = flags
.iter()
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index eb3a2be..40fbb53 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -2,7 +2,7 @@
#![deny(clippy::pedantic)]
#![deny(missing_docs)]
-//! Macros for the [Syrette](https://crates.io/crates/syrette) crate.
+//! Macros for the [Sy&rette](https://crates.io/crates/syrette) crate.
use proc_macro::TokenStream;
use quote::quote;
@@ -10,10 +10,12 @@ use syn::{parse, parse_macro_input};
mod declare_interface_args;
mod dependency;
+mod factory_macro_args;
mod factory_type_alias;
mod injectable_impl;
mod injectable_macro_args;
mod libs;
+mod macro_flag;
mod named_attr_input;
mod util;
@@ -31,6 +33,7 @@ use libs::intertrait_macros::gen_caster::generate_caster;
/// # Flags
/// - `no_doc_hidden` - Don't hide the impl of the [`Injectable`] trait from
/// documentation.
+/// - `async` - Mark as async.
///
/// # Panics
/// If the attributed item is not a impl.
@@ -107,21 +110,31 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
{
let InjectableMacroArgs { interface, flags } = parse_macro_input!(args_stream);
- let mut flags_iter = flags.iter();
-
- let no_doc_hidden = flags_iter
+ let no_doc_hidden = flags
+ .iter()
.find(|flag| flag.flag.to_string().as_str() == "no_doc_hidden")
.map_or(false, |flag| flag.is_on.value);
+ let is_async = flags
+ .iter()
+ .find(|flag| flag.flag.to_string().as_str() == "async")
+ .map_or(false, |flag| flag.is_on.value);
+
let injectable_impl: InjectableImpl = parse(impl_stream).unwrap();
- let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden);
+ let expanded_injectable_impl = injectable_impl.expand(no_doc_hidden, is_async);
let maybe_decl_interface = if interface.is_some() {
let self_type = &injectable_impl.self_type;
- quote! {
- syrette::declare_interface!(#self_type -> #interface);
+ if is_async {
+ quote! {
+ syrette::declare_interface!(#self_type -> #interface, async = true);
+ }
+ } else {
+ quote! {
+ syrette::declare_interface!(#self_type -> #interface);
+ }
}
} else {
quote! {}
@@ -139,6 +152,12 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
///
/// *This macro is only available if Syrette is built with the "factory" feature.*
///
+/// # Arguments
+/// * (Zero or more) Flags. Like `a = true, b = false`
+///
+/// # Flags
+/// - `async` - Mark as async.
+///
/// # Panics
/// If the attributed item is not a type alias.
///
@@ -166,8 +185,17 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
/// ```
#[proc_macro_attribute]
#[cfg(feature = "factory")]
-pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream
+pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> TokenStream
{
+ use crate::factory_macro_args::FactoryMacroArgs;
+
+ let FactoryMacroArgs { flags } = parse(args_stream).unwrap();
+
+ let is_async = flags
+ .iter()
+ .find(|flag| flag.flag.to_string().as_str() == "async")
+ .map_or(false, |flag| flag.is_on.value);
+
let factory_type_alias::FactoryTypeAlias {
type_alias,
factory_interface,
@@ -175,22 +203,46 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream
return_type,
} = parse(type_alias_stream).unwrap();
+ let decl_interfaces = if is_async {
+ quote! {
+ syrette::declare_interface!(
+ syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
+ #arg_types,
+ #return_type
+ > -> #factory_interface,
+ async = true
+ );
+
+ syrette::declare_interface!(
+ syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
+ #arg_types,
+ #return_type
+ > -> syrette::interfaces::any_factory::AnyThreadsafeFactory,
+ async = true
+ )
+ }
+ } else {
+ quote! {
+ syrette::declare_interface!(
+ syrette::castable_factory::blocking::CastableFactory<
+ #arg_types,
+ #return_type
+ > -> #factory_interface
+ );
+
+ syrette::declare_interface!(
+ syrette::castable_factory::blocking::CastableFactory<
+ #arg_types,
+ #return_type
+ > -> syrette::interfaces::any_factory::AnyFactory
+ );
+ }
+ };
+
quote! {
#type_alias
- syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
- #arg_types,
- #return_type
- > -> #factory_interface
- );
-
- syrette::declare_interface!(
- syrette::castable_factory::CastableFactory<
- #arg_types,
- #return_type
- > -> syrette::interfaces::any_factory::AnyFactory
- );
+ #decl_interfaces
}
.into()
}
@@ -199,6 +251,10 @@ pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream
///
/// # Arguments
/// {Implementation} -> {Interface}
+/// * (Zero or more) Flags. Like `a = true, b = false`
+///
+/// # Flags
+/// - `async` - Mark as async.
///
/// # Examples
/// ```
@@ -218,9 +274,17 @@ pub fn declare_interface(input: TokenStream) -> TokenStream
let DeclareInterfaceArgs {
implementation,
interface,
+ flags,
} = parse_macro_input!(input);
- generate_caster(&implementation, &interface).into()
+ let opt_async_flag = flags
+ .iter()
+ .find(|flag| flag.flag.to_string().as_str() == "async");
+
+ let is_async =
+ opt_async_flag.map_or_else(|| false, |async_flag| async_flag.is_on.value);
+
+ generate_caster(&implementation, &interface, is_async).into()
}
/// Declares the name of a dependency.
diff --git a/macros/src/libs/intertrait_macros/gen_caster.rs b/macros/src/libs/intertrait_macros/gen_caster.rs
index 9bac09e..df743e2 100644
--- a/macros/src/libs/intertrait_macros/gen_caster.rs
+++ b/macros/src/libs/intertrait_macros/gen_caster.rs
@@ -22,15 +22,29 @@ const CASTER_FN_NAME_PREFIX: &[u8] = b"__";
const FN_BUF_LEN: usize = CASTER_FN_NAME_PREFIX.len() + Simple::LENGTH;
-pub fn generate_caster(ty: &impl ToTokens, dst_trait: &impl ToTokens) -> TokenStream
+pub fn generate_caster(
+ ty: &impl ToTokens,
+ dst_trait: &impl ToTokens,
+ sync: bool,
+) -> TokenStream
{
let fn_ident = create_caster_fn_ident();
- let new_caster = quote! {
- syrette::libs::intertrait::Caster::<dyn #dst_trait>::new(
- |from| from.downcast::<#ty>().unwrap(),
- |from| from.downcast::<#ty>().unwrap(),
- )
+ let new_caster = if sync {
+ quote! {
+ syrette::libs::intertrait::Caster::<dyn #dst_trait>::new_sync(
+ |from| from.downcast::<#ty>().unwrap(),
+ |from| from.downcast::<#ty>().unwrap(),
+ |from| from.downcast::<#ty>().unwrap()
+ )
+ }
+ } else {
+ quote! {
+ syrette::libs::intertrait::Caster::<dyn #dst_trait>::new(
+ |from| from.downcast::<#ty>().unwrap(),
+ |from| from.downcast::<#ty>().unwrap(),
+ )
+ }
};
quote! {
diff --git a/macros/src/macro_flag.rs b/macros/src/macro_flag.rs
new file mode 100644
index 0000000..257a059
--- /dev/null
+++ b/macros/src/macro_flag.rs
@@ -0,0 +1,27 @@
+use syn::parse::{Parse, ParseStream};
+use syn::{Ident, LitBool, Token};
+
+#[derive(Debug)]
+pub struct MacroFlag
+{
+ pub flag: Ident,
+ pub is_on: LitBool,
+}
+
+impl Parse for MacroFlag
+{
+ fn parse(input: ParseStream) -> syn::Result<Self>
+ {
+ let input_forked = input.fork();
+
+ let flag: Ident = input_forked.parse()?;
+
+ input.parse::<Ident>()?;
+
+ input.parse::<Token![=]>()?;
+
+ let is_on: LitBool = input.parse()?;
+
+ Ok(Self { flag, is_on })
+ }
+}
diff --git a/macros/src/util/mod.rs b/macros/src/util/mod.rs
index 4f2a594..0705853 100644
--- a/macros/src/util/mod.rs
+++ b/macros/src/util/mod.rs
@@ -1,3 +1,4 @@
pub mod item_impl;
pub mod iterator_ext;
+pub mod string;
pub mod syn_path;
diff --git a/macros/src/util/string.rs b/macros/src/util/string.rs
new file mode 100644
index 0000000..90cccee
--- /dev/null
+++ b/macros/src/util/string.rs
@@ -0,0 +1,12 @@
+use once_cell::sync::Lazy;
+use regex::Regex;
+
+static CAMELCASE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"([a-z])([A-Z])").unwrap());
+
+pub fn camelcase_to_snakecase(camelcased: &str) -> String
+{
+ CAMELCASE_RE
+ .replace(camelcased, "${1}_$2")
+ .to_string()
+ .to_lowercase()
+}