aboutsummaryrefslogtreecommitdiff
path: root/macros/src/injectable_impl.rs
diff options
context:
space:
mode:
Diffstat (limited to 'macros/src/injectable_impl.rs')
-rw-r--r--macros/src/injectable_impl.rs102
1 files changed, 79 insertions, 23 deletions
diff --git a/macros/src/injectable_impl.rs b/macros/src/injectable_impl.rs
index 990b148..3565ef9 100644
--- a/macros/src/injectable_impl.rs
+++ b/macros/src/injectable_impl.rs
@@ -6,6 +6,7 @@ use syn::{parse_str, ExprMethodCall, FnArg, Generics, ItemImpl, Type};
use crate::dependency::Dependency;
use crate::util::item_impl::find_impl_method_by_name_mut;
+use crate::util::string::camelcase_to_snakecase;
use crate::util::syn_path::syn_path_to_string;
const DI_CONTAINER_VAR_NAME: &str = "di_container";
@@ -39,7 +40,8 @@ impl Parse for InjectableImpl
impl InjectableImpl
{
- pub fn expand(&self, no_doc_hidden: bool) -> proc_macro2::TokenStream
+ pub fn expand(&self, no_doc_hidden: bool, is_async: bool)
+ -> proc_macro2::TokenStream
{
let Self {
dependencies,
@@ -51,8 +53,6 @@ impl InjectableImpl
let di_container_var = format_ident!("{}", DI_CONTAINER_VAR_NAME);
let dependency_history_var = format_ident!("{}", DEPENDENCY_HISTORY_VAR_NAME);
- let get_dep_method_calls = Self::create_get_dep_method_calls(dependencies);
-
let maybe_doc_hidden = if no_doc_hidden {
quote! {}
} else {
@@ -81,36 +81,78 @@ impl InjectableImpl
quote! {}
};
- quote! {
- #original_impl
+ let injectable_impl = if is_async {
+ let async_get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, true);
+
+ quote! {
+ #maybe_doc_hidden
+ #[syrette::libs::async_trait::async_trait]
+ impl #generics syrette::interfaces::async_injectable::AsyncInjectable for #self_type
+ {
+ async fn resolve(
+ #di_container_var: &syrette::async_di_container::AsyncDIContainer,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
+
+ use syrette::errors::injectable::InjectableError;
+
+ let self_type_name = type_name::<#self_type>();
+
+ #maybe_prevent_circular_deps
+
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#async_get_dep_method_calls),*
+ )));
+ }
+ }
- #maybe_doc_hidden
- impl #generics syrette::interfaces::injectable::Injectable for #self_type {
- fn resolve(
- #di_container_var: &syrette::DIContainer,
- mut #dependency_history_var: Vec<&'static str>,
- ) -> Result<
- syrette::ptr::TransientPtr<Self>,
- syrette::errors::injectable::InjectableError>
+ }
+ } else {
+ let get_dep_method_calls =
+ Self::create_get_dep_method_calls(dependencies, false);
+
+ quote! {
+ #maybe_doc_hidden
+ impl #generics syrette::interfaces::injectable::Injectable for #self_type
{
- use std::any::type_name;
+ fn resolve(
+ #di_container_var: &syrette::DIContainer,
+ mut #dependency_history_var: Vec<&'static str>,
+ ) -> Result<
+ syrette::ptr::TransientPtr<Self>,
+ syrette::errors::injectable::InjectableError>
+ {
+ use std::any::type_name;
- use syrette::errors::injectable::InjectableError;
+ use syrette::errors::injectable::InjectableError;
- let self_type_name = type_name::<#self_type>();
+ let self_type_name = type_name::<#self_type>();
- #maybe_prevent_circular_deps
+ #maybe_prevent_circular_deps
- return Ok(syrette::ptr::TransientPtr::new(Self::new(
- #(#get_dep_method_calls),*
- )));
+ return Ok(syrette::ptr::TransientPtr::new(Self::new(
+ #(#get_dep_method_calls),*
+ )));
+ }
}
}
+ };
+
+ quote! {
+ #original_impl
+
+ #injectable_impl
}
}
fn create_get_dep_method_calls(
dependencies: &[Dependency],
+ is_async: bool,
) -> Vec<proc_macro2::TokenStream>
{
dependencies
@@ -146,11 +188,25 @@ impl InjectableImpl
.map(|(method_call, dep_type)| {
let ptr_name = dep_type.ptr.to_string();
- let to_ptr =
- format_ident!("{}", ptr_name.replace("Ptr", "").to_lowercase());
+ let to_ptr = format_ident!(
+ "{}",
+ camelcase_to_snakecase(&ptr_name.replace("Ptr", ""))
+ );
+
+ let do_method_call = if is_async {
+ quote! { #method_call.await }
+ } else {
+ quote! { #method_call }
+ };
+
+ let resolve_failed_error = if is_async {
+ quote! { InjectableError::AsyncResolveFailed }
+ } else {
+ quote! { InjectableError::ResolveFailed }
+ };
quote! {
- #method_call.map_err(|err| InjectableError::ResolveFailed {
+ #do_method_call.map_err(|err| #resolve_failed_error {
reason: Box::new(err),
affected: self_type_name
})?.#to_ptr().unwrap()