summaryrefslogtreecommitdiff
path: root/macros/src/lib.rs
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2023-03-26 16:30:19 +0200
committerHampusM <hampus@hampusmat.com>2023-03-26 16:37:54 +0200
commit7f9294869afd07e096e73a45e6a101b8970a0e6e (patch)
tree90705756cbd50fb81c964812717738109379fbbb /macros/src/lib.rs
parent9233c481d61271ee24b97fcb1820b459810e074c (diff)
feat: add automock attribute
Diffstat (limited to 'macros/src/lib.rs')
-rw-r--r--macros/src/lib.rs138
1 files changed, 106 insertions, 32 deletions
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 7fc062a..bcb4449 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -1,20 +1,27 @@
//! Macros for Ridicule, a mocking library supporting non-static generics.
#![deny(clippy::all, clippy::pedantic, missing_docs)]
use proc_macro::TokenStream;
+use proc_macro2::Ident;
use proc_macro_error::{proc_macro_error, ResultExt};
use quote::{format_ident, quote};
+use syn::token::Brace;
use syn::{
parse,
+ Block,
FnArg,
GenericArgument,
ImplItem,
+ ImplItemMethod,
+ ItemTrait,
Path,
PathArguments,
PathSegment,
ReturnType,
+ TraitItem,
Type,
TypeBareFn,
TypeParamBound,
+ Visibility,
};
use crate::expectation::Expectation;
@@ -70,9 +77,104 @@ pub fn mock(input_stream: TokenStream) -> TokenStream
let mock_mod_ident = format_ident!("__{mock_ident}");
- let method_items = input
- .item_impl
- .items
+ let method_items =
+ get_type_replaced_impl_item_methods(input.item_impl.items, &mock_ident);
+
+ let mock = Mock::new(
+ mock_ident.clone(),
+ input.mocked_trait,
+ &method_items,
+ input.item_impl.generics.clone(),
+ );
+
+ let expectations = method_items.iter().map(|item_method| {
+ Expectation::new(
+ &mock_ident,
+ item_method,
+ input.item_impl.generics.params.clone(),
+ )
+ });
+
+ quote! {
+ mod #mock_mod_ident {
+ use super::*;
+
+ #mock
+
+ #(#expectations)*
+ }
+
+ use #mock_mod_ident::#mock_ident;
+ }
+ .into()
+}
+
+/// Creates a mock automatically.
+#[proc_macro_attribute]
+#[proc_macro_error]
+pub fn automock(_: TokenStream, input_stream: TokenStream) -> TokenStream
+{
+ let item_trait = parse::<ItemTrait>(input_stream).unwrap_or_abort();
+
+ let mock_ident = format_ident!("Mock{}", item_trait.ident);
+
+ let mock_mod_ident = format_ident!("__{mock_ident}");
+
+ let method_items = get_type_replaced_impl_item_methods(
+ item_trait.items.iter().filter_map(|item| match item {
+ TraitItem::Method(item_method) => Some(ImplItem::Method(ImplItemMethod {
+ attrs: item_method.attrs.clone(),
+ vis: Visibility::Inherited,
+ defaultness: None,
+ sig: item_method.sig.clone(),
+ block: Block {
+ brace_token: Brace::default(),
+ stmts: vec![],
+ },
+ })),
+ _ => None,
+ }),
+ &mock_ident,
+ );
+
+ let mock = Mock::new(
+ mock_ident.clone(),
+ Path::new(
+ WithLeadingColons::No,
+ [PathSegment::new(item_trait.ident.clone(), None)],
+ ),
+ &method_items,
+ item_trait.generics.clone(),
+ );
+
+ let expectations = method_items.iter().map(|item_method| {
+ Expectation::new(&mock_ident, item_method, item_trait.generics.params.clone())
+ });
+
+ let visibility = &item_trait.vis;
+
+ quote! {
+ #item_trait
+
+ mod #mock_mod_ident {
+ use super::*;
+
+ #mock
+
+ #(#expectations)*
+ }
+
+ #visibility use #mock_mod_ident::#mock_ident;
+ }
+ .into()
+}
+
+fn get_type_replaced_impl_item_methods(
+ impl_items: impl IntoIterator<Item = ImplItem>,
+ mock_ident: &Ident,
+) -> Vec<ImplItemMethod>
+{
+ impl_items
.into_iter()
.filter_map(|item| match item {
ImplItem::Method(mut item_method) => {
@@ -117,35 +219,7 @@ pub fn mock(input_stream: TokenStream) -> TokenStream
}
_ => None,
})
- .collect::<Vec<_>>();
-
- let mock = Mock::new(
- mock_ident.clone(),
- input.mocked_trait,
- &method_items,
- input.item_impl.generics.clone(),
- );
-
- let expectations = method_items.iter().map(|item_method| {
- Expectation::new(
- &mock_ident,
- item_method,
- input.item_impl.generics.params.clone(),
- )
- });
-
- quote! {
- mod #mock_mod_ident {
- use super::*;
-
- #mock
-
- #(#expectations)*
- }
-
- use #mock_mod_ident::#mock_ident;
- }
- .into()
+ .collect()
}
fn replace_path_in_type(ty: Type, target_path: &Path, replacement_path: &Path) -> Type