aboutsummaryrefslogtreecommitdiff
path: root/macros/src/injectable
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-09-24 13:13:20 +0200
committerHampusM <hampus@hampusmat.com>2022-09-24 13:13:20 +0200
commit695f90bf900015df1e2728445f833dabced838a9 (patch)
treec68f2b483e3d20f400d27d4df159b2aec94d072f /macros/src/injectable
parent3ed020425bfd1fc5fedfa89a7ce20207bedcf5bc (diff)
refactor: reorganize modules in the macros crate
Diffstat (limited to 'macros/src/injectable')
-rw-r--r--macros/src/injectable/dependency.rs81
-rw-r--r--macros/src/injectable/implementation.rs261
-rw-r--r--macros/src/injectable/macro_args.rs67
-rw-r--r--macros/src/injectable/mod.rs4
-rw-r--r--macros/src/injectable/named_attr_input.rs21
5 files changed, 434 insertions, 0 deletions
diff --git a/macros/src/injectable/dependency.rs b/macros/src/injectable/dependency.rs
new file mode 100644
index 0000000..2c5e0fd
--- /dev/null
+++ b/macros/src/injectable/dependency.rs
@@ -0,0 +1,81 @@
+use std::error::Error;
+
+use proc_macro2::Ident;
+use syn::{parse2, FnArg, GenericArgument, LitStr, PathArguments, Type};
+
+use crate::injectable::named_attr_input::NamedAttrInput;
+use crate::util::syn_path::syn_path_to_string;
+
+pub struct Dependency
+{
+ pub interface: Type,
+ pub ptr: Ident,
+ pub name: Option<LitStr>,
+}
+
+impl Dependency
+{
+ pub fn build(new_method_arg: &FnArg) -> Result<Self, Box<dyn Error>>
+ {
+ let typed_new_method_arg = match new_method_arg {
+ FnArg::Typed(typed_arg) => Ok(typed_arg),
+ FnArg::Receiver(_) => Err("Unexpected self argument in 'new' method"),
+ }?;
+
+ let ptr_type_path = match typed_new_method_arg.ty.as_ref() {
+ Type::Path(arg_type_path) => Ok(arg_type_path),
+ Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() {
+ Type::Path(arg_type_path) => Ok(arg_type_path),
+ &_ => Err("Unexpected reference to non-path type"),
+ },
+ &_ => Err("Expected a path or a reference type"),
+ }?;
+
+ let ptr_path_segment = ptr_type_path.path.segments.last().map_or_else(
+ || Err("Expected pointer type path to have a last segment"),
+ Ok,
+ )?;
+
+ let ptr = ptr_path_segment.ident.clone();
+
+ let ptr_path_generic_args = &match &ptr_path_segment.arguments {
+ PathArguments::AngleBracketed(generic_args) => Ok(generic_args),
+ &_ => Err("Expected pointer type to have a generic type argument"),
+ }?
+ .args;
+
+ let interface = if let Some(GenericArgument::Type(interface)) =
+ ptr_path_generic_args.first()
+ {
+ Ok(interface.clone())
+ } else {
+ Err("Expected pointer type to have a generic type argument")
+ }?;
+
+ let arg_attrs = &typed_new_method_arg.attrs;
+
+ let opt_named_attr = arg_attrs.iter().find(|attr| {
+ attr.path.get_ident().map_or_else(
+ || false,
+ |attr_ident| attr_ident.to_string().as_str() == "named",
+ ) || syn_path_to_string(&attr.path) == "syrette::named"
+ });
+
+ let opt_named_attr_tokens = opt_named_attr.map(|attr| &attr.tokens);
+
+ let opt_named_attr_input =
+ if let Some(named_attr_tokens) = opt_named_attr_tokens {
+ Some(parse2::<NamedAttrInput>(named_attr_tokens.clone()).map_err(
+ |err| format!("Invalid input for 'named' attribute. {}", err),
+ )?)
+ } else {
+ None
+ };
+
+ Ok(Self {
+ interface,
+ ptr,
+ name: opt_named_attr_input.map(|named_attr_input| named_attr_input.name),
+ })
+ }
+}
diff --git a/macros/src/injectable/implementation.rs b/macros/src/injectable/implementation.rs
new file mode 100644
index 0000000..a84e798
--- /dev/null
+++ b/macros/src/injectable/implementation.rs
@@ -0,0 +1,261 @@
+use std::error::Error;
+
+use quote::{format_ident, quote, ToTokens};
+use syn::parse::{Parse, ParseStream};
+use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type};
+
+use crate::injectable::dependency::Dependency;
+use crate::util::item_impl::find_impl_method_by_name_mut;
+use crate::util::string::camelcase_to_snakecase;
+use crate::util::syn_path::syn_path_to_string;
+
+const DI_CONTAINER_VAR_NAME: &str = "di_container";
+const DEPENDENCY_HISTORY_VAR_NAME: &str = "dependency_history";
+
+pub struct InjectableImpl
+{
+ pub dependencies: Vec<Dependency>,
+ pub self_type: Type,
+ pub generics: Generics,
+ pub original_impl: ItemImpl,
+}
+
+impl Parse for InjectableImpl
+{
+ fn parse(input: ParseStream) -> syn::Result<Self>
+ {
+ let mut impl_parsed_input = input.parse::<ItemImpl>()?;
+
+ let dependencies = Self::build_dependencies(&mut impl_parsed_input)
+ .map_err(|err| input.error(err))?;
+
+ Ok(Self {
+ dependencies,
+ self_type: impl_parsed_input.self_ty.as_ref().clone(),
+ generics: impl_parsed_input.generics.clone(),
+ original_impl: impl_parsed_input,
+ })
+ }
+}
+
+impl InjectableImpl
+{
+ pub fn expand(&self, no_doc_hidden: bool, is_async: bool)
+ -> proc_macro2::TokenStream
+ {
+ let Self {
+ dependencies,
+ self_type,
+ generics,
+ original_impl,
+ } = self;
+
+ let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME);
+ let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME);
+
+ let maybe_doc_hidden = if no_doc_hidden {
+ quote! {}
+ } else {
+ quote! {
+ #[doc(hidden)]
+ }
+ };
+
+ let maybe_prevent_circular_deps = if cfg!(feature = "prevent-circular") {
+ quote! {
+ if #dependency_history_var.contains(&self_type_name) {
+ #dependency_history_var.push(self_type_name);
+
+ let dependency_trace =
+ syrette::dependency_trace::create_dependency_trace(
+ #dependency_history_var.as_slice(),
+ self_type_name
+ );
+
+ return Err(InjectableError::DetectedCircular {dependency_trace });
+ }
+
+ #dependency_history_var.push(self_type_name);
+ }
+ } else {
+ quote! {}
+ };
+
+ let injectable_impl = if is_async {
+ let async_get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, true);
+
+ quote! {
+ #maybe_doc_hidden
+ #[syrette::libs::async_trait::async_trait]
+ impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type
+ {
+ async fn resolve(
+ #di_container_var: &std::sync::Arc<syrette::async_di_container::AsyncDIContainer>,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
+
+ use syrette::errors::injectable::InjectableError;
+
+ let self_type_name = type_name::<#self_type>();
+
+ #maybe_prevent_circular_deps
+
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#async_get_dep_method_calls),*
+ )));
+ }
+ }
+
+ }
+ } else {
+ let get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, false);
+
+ quote! {
+ #maybe_doc_hidden
+ impl #generics syrette::interfaces::injectable::Injectable for #self_type
+ {
+ fn resolve(
+ #di_container_var: &std::rc::Rc<syrette::DIContainer>,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
+
+ use syrette::errors::injectable::InjectableError;
+
+ let self_type_name = type_name::<#self_type>();
+
+ #maybe_prevent_circular_deps
+
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#get_dep_method_calls),*
+ )));
+ }
+ }
+ }
+ };
+
+ quote! {
+ #original_impl
+
+ #injectable_impl
+ }
+ }
+
+ fn create_get_dep_method_calls(
+ dependencies: &[Dependency],
+ is_async: bool,
+ ) -> Vec<proc_macro2::TokenStream>
+ {
+ dependencies
+ .iter()
+ .filter_map(|dependency| {
+ let dep_interface_str = match &dependency.interface {
+ Type::TraitObject(interface_trait) => {
+ Some(interface_trait.to_token_stream().to_string())
+ }
+ Type::Path(path_interface) => {
+ Some(syn_path_to_string(&path_interface.path))
+ }
+ &_ => None,
+ }?;
+
+ let method_call = parse_str::<ExprMethodCall>(
+ format!(
+ "{}.get_bound::<{}>({}.clone(), {})",
+ DI_CONTAINER_VAR_NAME,
+ dep_interface_str,
+ DEPENDENCY_HISTORY_VAR_NAME,
+ dependency.name.as_ref().map_or_else(
+ || "None".to_string(),
+ |name| format!("Some(\"{}\")", name.value())
+ )
+ )
+ .as_str(),
+ )
+ .ok()?;
+
+ Some((method_call, dependency))
+ })
+ .map(|(method_call, dep_type)| {
+ let ptr_name = dep_type.ptr.to_string();
+
+ let to_ptr = format_ident!(
+ "{}",
+ camelcase_to_snakecase(&ptr_name.replace("Ptr", ""))
+ );
+
+ let do_method_call = if is_async {
+ quote! { #method_call.await }
+ } else {
+ quote! { #method_call }
+ };
+
+ let resolve_failed_error = if is_async {
+ quote! { InjectableError::AsyncResolveFailed }
+ } else {
+ quote! { InjectableError::ResolveFailed }
+ };
+
+ quote! {
+ #do_method_call.map_err(|err| #resolve_failed_error {
+ reason: Box::new(err),
+ affected: self_type_name
+ })?.#to_ptr().unwrap()
+ }
+ })
+ .collect()
+ }
+
+ fn build_dependencies(
+ item_impl: &mut ItemImpl,
+ ) -> Result<Vec<Dependency>, Box<dyn Error>>
+ {
+ let new_method_impl_item = find_impl_method_by_name_mut(item_impl, "new")
+ .map_or_else(|| Err("Missing a 'new' method"), Ok)?;
+
+ let new_method_args = &mut new_method_impl_item.sig.inputs;
+
+ let dependencies: Result<Vec<_>, _> =
+ new_method_args.iter().map(Dependency::build).collect();
+
+ for arg in new_method_args {
+ let typed_arg = if let FnArg::Typed(typed_arg) = arg {
+ typed_arg
+ } else {
+ continue;
+ };
+
+ let attrs_to_remove: Vec<_> = typed_arg
+ .attrs
+ .iter()
+ .enumerate()
+ .filter_map(|(index, attr)| {
+ if syn_path_to_string(&attr.path).as_str() == "syrette::named" {
+ return Some(index);
+ }
+
+ if attr.path.get_ident()?.to_string().as_str() == "named" {
+ return Some(index);
+ }
+
+ None
+ })
+ .collect();
+
+ for attr_index in attrs_to_remove {
+ typed_arg.attrs.remove(attr_index);
+ }
+ }
+
+ dependencies
+ }
+}
diff --git a/macros/src/injectable/macro_args.rs b/macros/src/injectable/macro_args.rs
new file mode 100644
index 0000000..50d4087
--- /dev/null
+++ b/macros/src/injectable/macro_args.rs
@@ -0,0 +1,67 @@
+use syn::parse::{Parse, ParseStream};
+use syn::punctuated::Punctuated;
+use syn::{Token, TypePath};
+
+use crate::macro_flag::MacroFlag;
+use crate::util::iterator_ext::IteratorExt;
+
+pub const INJECTABLE_MACRO_FLAGS: &[&str] = &["no_doc_hidden", "async"];
+
+pub struct InjectableMacroArgs
+{
+ pub interface: Option<TypePath>,
+ pub flags: Punctuated<MacroFlag, Token![,]>,
+}
+
+impl Parse for InjectableMacroArgs
+{
+ fn parse(input: ParseStream) -> syn::Result<Self>
+ {
+ let interface = input.parse::<TypePath>().ok();
+
+ if interface.is_some() {
+ let comma_input_lookahead = input.lookahead1();
+
+ if !comma_input_lookahead.peek(Token![,]) {
+ return Ok(Self {
+ interface,
+ flags: Punctuated::new(),
+ });
+ }
+
+ input.parse::<Token![,]>()?;
+ }
+
+ if input.is_empty() {
+ return Ok(Self {
+ interface,
+ flags: Punctuated::new(),
+ });
+ }
+
+ let flags = Punctuated::<MacroFlag, Token![,]>::parse_terminated(input)?;
+
+ for flag in &flags {
+ let flag_str = flag.flag.to_string();
+
+ if !INJECTABLE_MACRO_FLAGS.contains(&flag_str.as_str()) {
+ return Err(input.error(format!(
+ "Unknown flag '{}'. Expected one of [ {} ]",
+ flag_str,
+ INJECTABLE_MACRO_FLAGS.join(",")
+ )));
+ }
+ }
+
+ let flag_names = flags
+ .iter()
+ .map(|flag| flag.flag.to_string())
+ .collect::<Vec<_>>();
+
+ if let Some(dupe_flag_name) = flag_names.iter().find_duplicate() {
+ return Err(input.error(format!("Duplicate flag '{}'", dupe_flag_name)));
+ }
+
+ Ok(Self { interface, flags })
+ }
+}
diff --git a/macros/src/injectable/mod.rs b/macros/src/injectable/mod.rs
new file mode 100644
index 0000000..b713aeb
--- /dev/null
+++ b/macros/src/injectable/mod.rs
@@ -0,0 +1,4 @@
+pub mod dependency;
+pub mod implementation;
+pub mod macro_args;
+pub mod named_attr_input;
diff --git a/macros/src/injectable/named_attr_input.rs b/macros/src/injectable/named_attr_input.rs
new file mode 100644
index 0000000..5f7123c
--- /dev/null
+++ b/macros/src/injectable/named_attr_input.rs
@@ -0,0 +1,21 @@
+use syn::parse::Parse;
+use syn::{parenthesized, LitStr};
+
+pub struct NamedAttrInput
+{
+ pub name: LitStr,
+}
+
+impl Parse for NamedAttrInput
+{
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self>
+ {
+ let content;
+
+ parenthesized!(content in input);
+
+ Ok(Self {
+ name: content.parse()?,
+ })
+ }
+}