From a9ff2f16812b56107604400a64a7f482d017eca1 Mon Sep 17 00:00:00 2001
From: HampusM <hampus@hampusmat.com>
Date: Wed, 31 Aug 2022 21:41:27 +0200
Subject: feat: add a threadsafe flag to the declare_default_factory macro

---
 macros/src/decl_def_factory_args.rs | 56 ++++++++++++++++++++++++
 macros/src/lib.rs                   | 85 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 140 insertions(+), 1 deletion(-)
 create mode 100644 macros/src/decl_def_factory_args.rs

(limited to 'macros/src')

diff --git a/macros/src/decl_def_factory_args.rs b/macros/src/decl_def_factory_args.rs
new file mode 100644
index 0000000..6450583
--- /dev/null
+++ b/macros/src/decl_def_factory_args.rs
@@ -0,0 +1,56 @@
+use syn::parse::Parse;
+use syn::punctuated::Punctuated;
+use syn::{Token, Type};
+
+use crate::macro_flag::MacroFlag;
+use crate::util::iterator_ext::IteratorExt;
+
+pub const FACTORY_MACRO_FLAGS: &[&str] = &["threadsafe"];
+
+pub struct DeclareDefaultFactoryMacroArgs
+{
+    pub interface: Type,
+    pub flags: Punctuated<MacroFlag, Token![,]>,
+}
+
+impl Parse for DeclareDefaultFactoryMacroArgs
+{
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self>
+    {
+        let interface = input.parse().unwrap();
+
+        if !input.peek(Token![,]) {
+            return Ok(Self {
+                interface,
+                flags: Punctuated::new(),
+            });
+        }
+
+        input.parse::<Token![,]>().unwrap();
+
+        let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;
+
+        for flag in &flags {
+            let flag_str = flag.flag.to_string();
+
+            if !FACTORY_MACRO_FLAGS.contains(&flag_str.as_str()) {
+                return Err(input.error(format!(
+                    "Unknown flag '{}'. Expected one of [ {} ]",
+                    flag_str,
+                    FACTORY_MACRO_FLAGS.join(",")
+                )));
+            }
+        }
+
+        let flag_names = flags
+            .iter()
+            .map(|flag| flag.flag.to_string())
+            .collect::<Vec<_>>();
+
+        if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() {
+            return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name)));
+        }
+
+        Ok(Self { interface, flags })
+    }
+}
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 7d466aa..7083b44 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -6,8 +6,9 @@
 
 use proc_macro::TokenStream;
 use quote::quote;
-use syn::{parse, parse_macro_input};
+use syn::{parse, parse_macro_input, ItemTrait};
 
+mod decl_def_factory_args;
 mod declare_interface_args;
 mod dependency;
 mod factory_macro_args;
@@ -247,6 +248,88 @@ pub fn factory(args_stream: TokenStream, type_alias_stream: TokenStream) -> Toke
     .into()
 }
 
+/// Shortcut for declaring a default factory.
+///
+/// A default factory is a factory that doesn't take any arguments.
+///
+/// The more tedious way to accomplish what this macro does would be by using
+/// the [`factory`] macro.
+///
+/// *This macro is only available if Syrette is built with the "factory" feature.*
+///
+/// # Arguments
+/// - Interface trait
+/// * (Zero or more) Flags. Like `a = true, b = false`
+///
+/// # Flags
+/// - `threadsafe` - Mark as threadsafe.
+///
+/// # Panics
+/// If the provided arguments are invalid.
+///
+/// # Examples
+/// ```
+/// # use syrette::declare_default_factory;
+/// #
+/// trait IParser
+/// {
+///     // Methods and etc here...
+/// }
+///
+/// declare_default_factory!(dyn IParser);
+/// ```
+#[proc_macro]
+#[cfg(feature = "factory")]
+pub fn declare_default_factory(args_stream: TokenStream) -> TokenStream
+{
+    use crate::decl_def_factory_args::DeclareDefaultFactoryMacroArgs;
+
+    let DeclareDefaultFactoryMacroArgs { interface, flags } = parse(args_stream).unwrap();
+
+    let is_threadsafe = flags
+        .iter()
+        .find(|flag| flag.flag.to_string().as_str() == "threadsafe")
+        .map_or(false, |flag| flag.is_on.value);
+
+    if is_threadsafe {
+        return quote! {
+            syrette::declare_interface!(
+                syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
+                    (),
+                    #interface,
+                > -> syrette::interfaces::factory::IFactory<(), #interface>,
+                async = true
+            );
+
+            syrette::declare_interface!(
+                syrette::castable_factory::threadsafe::ThreadsafeCastableFactory<
+                    (),
+                    #interface,
+                > -> syrette::interfaces::any_factory::AnyFactory,
+                async = true
+            );
+        }
+        .into();
+    }
+
+    quote! {
+        syrette::declare_interface!(
+            syrette::castable_factory::blocking::CastableFactory<
+                (),
+                #interface,
+            > -> syrette::interfaces::factory::IFactory<(), #interface>
+        );
+
+        syrette::declare_interface!(
+            syrette::castable_factory::blocking::CastableFactory<
+                (),
+                #interface,
+            > -> syrette::interfaces::any_factory::AnyFactory
+        );
+    }
+    .into()
+}
+
 /// Declares the interface trait of a implementation.
 ///
 /// # Arguments
-- 
cgit v1.2.3-18-g5258