aboutsummaryrefslogtreecommitdiff
path: root/syrette_macros/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'syrette_macros/src/lib.rs')
-rw-r--r--syrette_macros/src/lib.rs149
1 files changed, 138 insertions, 11 deletions
diff --git a/syrette_macros/src/lib.rs b/syrette_macros/src/lib.rs
index 1761534..0302c07 100644
--- a/syrette_macros/src/lib.rs
+++ b/syrette_macros/src/lib.rs
@@ -2,8 +2,8 @@ use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse, parse_macro_input, parse_str, punctuated::Punctuated, token::Comma,
- AttributeArgs, FnArg, GenericArgument, ImplItem, ItemImpl, Meta, NestedMeta, Path,
- PathArguments, Type, TypePath,
+ AttributeArgs, ExprMethodCall, FnArg, GenericArgument, ImplItem, ItemImpl, ItemType,
+ Meta, NestedMeta, Path, PathArguments, Type, TypeParamBound, TypePath,
};
mod libs;
@@ -27,8 +27,16 @@ const IMPL_NO_NEW_METHOD_ERR_MESSAGE: &str =
const IMPL_NEW_METHOD_SELF_PARAM_ERR_MESSAGE: &str =
"The new method of the attached to trait implementation cannot have a self parameter";
-const IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE: &str =
- "All parameters of the new method of the attached to trait implementation must be std::boxed::Box";
+const IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE: &str = concat!(
+ "All parameters of the new method of the attached to trait implementation ",
+ "must be either std::boxed::Box or std::rc:Rc (for factories)"
+);
+
+const INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE: &str =
+ "Invalid aliased trait. Must be 'dyn IFactory'";
+
+const INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE: &str =
+ "Invalid arguments for 'dyn IFactory'";
fn path_to_string(path: &Path) -> String
{
@@ -67,6 +75,10 @@ fn get_fn_arg_type_paths(fn_args: &Punctuated<FnArg, Comma>) -> Vec<TypePath>
match arg {
FnArg::Typed(typed_fn_arg) => match typed_fn_arg.ty.as_ref() {
Type::Path(arg_type_path) => acc.push(arg_type_path.clone()),
+ Type::Reference(ref_type_path) => match ref_type_path.elem.as_ref() {
+ Type::Path(arg_type_path) => acc.push(arg_type_path.clone()),
+ &_ => {}
+ },
&_ => {}
},
FnArg::Receiver(_receiver_fn_arg) => {}
@@ -109,8 +121,11 @@ fn get_dependency_types(item_impl: &ItemImpl) -> Vec<Type>
if arg_type_path_string != "Box"
&& arg_type_path_string != "std::boxed::Box"
&& arg_type_path_string != "boxed::Box"
+ && arg_type_path_string != "Rc"
+ && arg_type_path_string != "std::rc::Rc"
+ && arg_type_path_string != "rc::Rc"
{
- panic!("{}", IMPL_NEW_METHOD_BOX_PARAMS_ERR_MESSAGE);
+ panic!("{}", IMPL_NEW_METHOD_PARAM_TYPES_ERR_MESSAGE);
}
// Assume the type path has a last segment.
@@ -202,6 +217,49 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
let dependency_types = get_dependency_types(&item_impl);
+ let get_dependencies = dependency_types.iter().fold(
+ Vec::<ExprMethodCall>::new(),
+ |mut acc, dep_type| {
+ match dep_type {
+ Type::TraitObject(dep_type_trait) => {
+ acc.push(
+ parse_str(
+ format!(
+ "di_container.get::<{}>()",
+ dep_type_trait.to_token_stream()
+ )
+ .as_str(),
+ )
+ .unwrap(),
+ );
+ }
+ Type::Path(dep_type_path) => {
+ let dep_type_path_str = path_to_string(&dep_type_path.path);
+
+ let get_method_name = if dep_type_path_str.ends_with("Factory") {
+ "get_factory"
+ } else {
+ "get"
+ };
+
+ acc.push(
+ parse_str(
+ format!(
+ "di_container.{}::<{}>()",
+ get_method_name, dep_type_path_str
+ )
+ .as_str(),
+ )
+ .unwrap(),
+ );
+ }
+ &_ => {}
+ }
+
+ acc
+ },
+ );
+
quote! {
#item_impl
@@ -213,13 +271,15 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
use error_stack::ResultExt;
return Ok(Box::new(Self::new(
- #(di_container.get::<#dependency_types>()
+ #(#get_dependencies
.change_context(syrette::errors::injectable::ResolveError)
- .attach_printable(format!(
- "Unable to resolve a dependency of {}",
- std::any::type_name::<#self_type_path>()
- ))?,
- )*
+ .attach_printable(
+ format!(
+ "Unable to resolve a dependency of {}",
+ std::any::type_name::<#self_type_path>()
+ )
+ )?
+ ),*
)));
}
}
@@ -229,6 +289,73 @@ pub fn injectable(args_stream: TokenStream, impl_stream: TokenStream) -> TokenSt
.into()
}
+#[proc_macro_attribute]
+pub fn factory(_: TokenStream, type_alias_stream: TokenStream) -> TokenStream
+{
+ let type_alias: ItemType = parse(type_alias_stream).unwrap();
+
+ let aliased_trait = match &type_alias.ty.as_ref() {
+ Type::TraitObject(alias_type) => alias_type,
+ &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE),
+ };
+
+ if aliased_trait.bounds.len() != 1 {
+ panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE);
+ }
+
+ let type_bound = aliased_trait.bounds.first().unwrap();
+
+ let trait_bound = match type_bound {
+ TypeParamBound::Trait(trait_bound) => trait_bound,
+ &_ => panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE),
+ };
+
+ let trait_bound_path = &trait_bound.path;
+
+ if trait_bound_path.segments.is_empty()
+ || trait_bound_path.segments.last().unwrap().ident != "IFactory"
+ {
+ panic!("{}", INVALID_ALIASED_FACTORY_TRAIT_ERR_MESSAGE);
+ }
+
+ let factory_path_segment = trait_bound_path.segments.last().unwrap();
+
+ let factory_path_segment_args = &match &factory_path_segment.arguments {
+ syn::PathArguments::AngleBracketed(args) => args,
+ &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE),
+ }
+ .args;
+
+ let factory_arg_types_type = match &factory_path_segment_args[0] {
+ GenericArgument::Type(arg_type) => arg_type,
+ &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE),
+ };
+
+ let factory_return_type = match &factory_path_segment_args[1] {
+ GenericArgument::Type(arg_type) => arg_type,
+ &_ => panic!("{}", INVALID_ALIASED_FACTORY_ARGS_ERR_MESSAGE),
+ };
+
+ quote! {
+ #type_alias
+
+ syrette::castable_to!(
+ syrette::castable_factory::CastableFactory<
+ #factory_arg_types_type,
+ #factory_return_type
+ > => #trait_bound_path
+ );
+
+ syrette::castable_to!(
+ syrette::castable_factory::CastableFactory<
+ #factory_arg_types_type,
+ #factory_return_type
+ > => syrette::castable_factory::AnyFactory
+ );
+ }
+ .into()
+}
+
#[doc(hidden)]
#[proc_macro]
pub fn castable_to(input: TokenStream) -> TokenStream