summaryrefslogtreecommitdiff
path: root/macros/src/mock.rs
diff options
context:
space:
mode:
Diffstat (limited to 'macros/src/mock.rs')
-rw-r--r--macros/src/mock.rs77
1 files changed, 53 insertions, 24 deletions
diff --git a/macros/src/mock.rs b/macros/src/mock.rs
index 1b6dd67..8828b17 100644
--- a/macros/src/mock.rs
+++ b/macros/src/mock.rs
@@ -1,12 +1,14 @@
use proc_macro2::{Ident, Span, TokenStream};
use proc_macro_error::{abort_call_site, ResultExt};
use quote::{format_ident, quote, ToTokens};
+use syn::punctuated::Punctuated;
use syn::{
parse2,
AngleBracketedGenericArguments,
FnArg,
GenericArgument,
GenericParam,
+ Generics,
ImplItemMethod,
Lifetime,
Pat,
@@ -45,12 +47,17 @@ pub struct Mock
mocked_trait: Path,
expectations_fields: Vec<ExpectationsField>,
item_methods: Vec<ImplItemMethod>,
+ generics: Generics,
}
impl Mock
{
- pub fn new(ident: Ident, mocked_trait: Path, item_methods: &[ImplItemMethod])
- -> Self
+ pub fn new(
+ ident: Ident,
+ mocked_trait: Path,
+ item_methods: &[ImplItemMethod],
+ generics: Generics,
+ ) -> Self
{
let expectations_fields = item_methods
.iter()
@@ -60,6 +67,7 @@ impl Mock
.generics
.params
.iter()
+ .chain(generics.params.iter())
.filter_map(|generic_param| match generic_param {
GenericParam::Type(_) => Some(GenericArgument::Type(
Type::Tuple(create_unit_type_tuple()),
@@ -90,6 +98,7 @@ impl Mock
mocked_trait,
expectations_fields,
item_methods: item_methods.to_vec(),
+ generics,
}
}
}
@@ -103,21 +112,26 @@ impl ToTokens for Mock
mocked_trait,
expectations_fields,
item_methods,
+ generics,
} = self;
let expectations_field_idents = expectations_fields
.iter()
.map(|expectations_field| expectations_field.field_ident.clone());
- let mock_functions = item_methods
- .iter()
- .map(|item_method| create_mock_function(item_method.clone()));
+ let mock_functions = item_methods.iter().map(|item_method| {
+ create_mock_function(item_method.clone(), &generics.params)
+ });
let expect_functions = item_methods
.iter()
- .map(|item_method| create_expect_function(&self.ident, &item_method.clone()))
+ .map(|item_method| {
+ create_expect_function(&self.ident, &item_method.clone(), generics)
+ })
.collect::<Vec<_>>();
+ let (impl_generics, _, _) = generics.split_for_impl();
+
quote! {
pub struct #ident
{
@@ -138,7 +152,7 @@ impl ToTokens for Mock
#(#expect_functions)*
}
- impl #mocked_trait for #ident {
+ impl #impl_generics #mocked_trait for #ident {
#(
#mock_functions
)*
@@ -175,15 +189,24 @@ impl ToTokens for ExpectationsField
}
}
-fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
+fn create_mock_function(
+ item_method: ImplItemMethod,
+ generic_params: &Punctuated<GenericParam, Token![,]>,
+) -> ImplItemMethod
{
let func_ident = &item_method.sig.ident;
let type_param_idents = item_method
.sig
.generics
- .type_params()
- .map(|type_param| type_param.ident.clone())
+ .params
+ .clone()
+ .into_iter()
+ .chain(generic_params.clone())
+ .filter_map(|generic_param| match generic_param {
+ GenericParam::Type(type_param) => Some(type_param.ident),
+ _ => None,
+ })
.collect::<Vec<_>>();
let args = item_method
@@ -204,12 +227,6 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
let expectations_field = format_ident!("{func_ident}_expectations");
- let ids = quote! {
- let ids = vec![
- #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),*
- ];
- };
-
ImplItemMethod {
attrs: item_method.attrs,
vis: Visibility::Inherited,
@@ -217,7 +234,9 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
sig: item_method.sig.clone(),
block: parse2(quote! {
{
- #ids
+ let ids = vec![
+ #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),*
+ ];
let expectation = self
.#expectations_field
@@ -226,7 +245,7 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
"No expectation found for function ",
stringify!(#func_ident)
))
- .with_generic_params::<#(#type_param_idents),*>();
+ .with_generic_params::<#(#type_param_idents,)*>();
let Some(returning) = &expectation.returning else {
panic!(concat!(
@@ -243,11 +262,23 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
}
}
-fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplItemMethod
+fn create_expect_function(
+ mock: &Ident,
+ item_method: &ImplItemMethod,
+ generics: &Generics,
+) -> ImplItemMethod
{
+ let signature_generics = {
+ let mut sig_generics = item_method.sig.generics.clone();
+
+ sig_generics.params.extend(generics.params.clone());
+
+ sig_generics
+ };
+
let signature = Signature::new(
format_ident!("expect_{}", item_method.sig.ident),
- item_method.sig.generics.clone(),
+ signature_generics.clone(),
[FnArg::Receiver(Receiver {
attrs: vec![],
reference: Some((<Token![&]>::default(), None)),
@@ -263,7 +294,7 @@ fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplIte
create_expectation_ident(mock, &item_method.sig.ident),
Some(AngleBracketedGenericArguments::new(
WithColons::No,
- item_method.sig.generics.params.iter().map(|generic_param| {
+ signature_generics.params.iter().map(|generic_param| {
GenericArgument::from_generic_param(generic_param.clone())
}),
)),
@@ -272,9 +303,7 @@ fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplIte
))),
);
- let type_param_idents = item_method
- .sig
- .generics
+ let type_param_idents = signature_generics
.type_params()
.map(|type_param| type_param.ident.clone())
.collect::<Vec<_>>();