From 2d964b39da09ad82eccf09abdea73967bbff76f2 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 18 Mar 2023 21:26:54 +0100 Subject: feat: add support for generic traits --- macros/src/expectation.rs | 61 ++++++++++++++++++++++++++++--------- macros/src/lib.rs | 17 ++++++++--- macros/src/mock.rs | 77 ++++++++++++++++++++++++++++++++--------------- 3 files changed, 112 insertions(+), 43 deletions(-) (limited to 'macros/src') diff --git a/macros/src/expectation.rs b/macros/src/expectation.rs index 4fc1451..d35ef97 100644 --- a/macros/src/expectation.rs +++ b/macros/src/expectation.rs @@ -50,7 +50,8 @@ use crate::util::{create_path, create_unit_type_tuple}; pub struct Expectation { ident: Ident, - generics: Generics, + method_generics: Generics, + generic_params: Punctuated, receiver: Option, mock: Ident, arg_types: Vec, @@ -60,12 +61,24 @@ pub struct Expectation impl Expectation { - pub fn new(mock: &Ident, item_method: &ImplItemMethod) -> Self + pub fn new( + mock: &Ident, + item_method: &ImplItemMethod, + generic_params: Punctuated, + ) -> Self { let ident = create_expectation_ident(mock, &item_method.sig.ident); - let phantom_fields = - Self::create_phantom_fields(&item_method.sig.generics.params); + let phantom_fields = Self::create_phantom_fields( + &item_method + .sig + .generics + .params + .clone() + .into_iter() + .chain(generic_params.clone()) + .collect(), + ); let receiver = item_method @@ -91,7 +104,8 @@ impl Expectation Self { ident, - generics: item_method.sig.generics.clone(), + method_generics: item_method.sig.generics.clone(), + generic_params, receiver, mock: mock.clone(), arg_types, @@ -158,9 +172,17 @@ impl ToTokens for Expectation { fn to_tokens(&self, tokens: &mut TokenStream) { - let generic_params = &self.generics.params; + let generics = { + let mut generics = self.method_generics.clone(); - let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + generics.params.extend(self.generic_params.clone()); + + generics + }; + + let generic_params = &generics.params; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let bogus_generics = create_bogus_generics(generic_params); @@ -186,7 +208,7 @@ impl ToTokens for Expectation vis: Visibility::new_pub_crate(), struct_token: ::default(), ident: self.ident.clone(), - generics: self.generics.clone().without_where_clause(), + generics: generics.clone().without_where_clause(), fields: Fields::Named(FieldsNamed { brace_token: Brace::default(), named: [Field { @@ -206,13 +228,7 @@ impl ToTokens for Expectation ))), }] .into_iter() - .chain(phantom_fields.iter().map(|phantom_field| Field { - attrs: vec![], - vis: Visibility::Inherited, - ident: Some(phantom_field.field.clone()), - colon_token: Some(::default()), - ty: Type::Path(phantom_field.type_path.clone()), - })) + .chain(phantom_fields.iter().cloned().map(Field::from)) .collect(), }), semi_token: None, @@ -285,6 +301,7 @@ pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident format_ident!("{mock}Expectation_{method}") } +#[derive(Clone)] struct PhantomField { field: Ident, @@ -303,6 +320,20 @@ impl ToTokens for PhantomField } } +impl From for Field +{ + fn from(phantom_field: PhantomField) -> Self + { + Self { + attrs: vec![], + vis: Visibility::Inherited, + ident: Some(phantom_field.field.clone()), + colon_token: Some(::default()), + ty: Type::Path(phantom_field.type_path), + } + } +} + fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident { match kind { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 36c6ad7..8106a8c 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -34,11 +34,20 @@ pub fn mock(input_stream: TokenStream) -> TokenStream }) .collect::>(); - let mock = Mock::new(mock_ident.clone(), input.mocked_trait, &method_items); + let mock = Mock::new( + mock_ident.clone(), + input.mocked_trait, + &method_items, + input.item_impl.generics.clone(), + ); - let expectations = method_items - .iter() - .map(|item_method| Expectation::new(&mock_ident, item_method)); + let expectations = method_items.iter().map(|item_method| { + Expectation::new( + &mock_ident, + item_method, + input.item_impl.generics.params.clone(), + ) + }); quote! { mod #mock_mod_ident { diff --git a/macros/src/mock.rs b/macros/src/mock.rs index 1b6dd67..8828b17 100644 --- a/macros/src/mock.rs +++ b/macros/src/mock.rs @@ -1,12 +1,14 @@ use proc_macro2::{Ident, Span, TokenStream}; use proc_macro_error::{abort_call_site, ResultExt}; use quote::{format_ident, quote, ToTokens}; +use syn::punctuated::Punctuated; use syn::{ parse2, AngleBracketedGenericArguments, FnArg, GenericArgument, GenericParam, + Generics, ImplItemMethod, Lifetime, Pat, @@ -45,12 +47,17 @@ pub struct Mock mocked_trait: Path, expectations_fields: Vec, item_methods: Vec, + generics: Generics, } impl Mock { - pub fn new(ident: Ident, mocked_trait: Path, item_methods: &[ImplItemMethod]) - -> Self + pub fn new( + ident: Ident, + mocked_trait: Path, + item_methods: &[ImplItemMethod], + generics: Generics, + ) -> Self { let expectations_fields = item_methods .iter() @@ -60,6 +67,7 @@ impl Mock .generics .params .iter() + .chain(generics.params.iter()) .filter_map(|generic_param| match generic_param { GenericParam::Type(_) => Some(GenericArgument::Type( Type::Tuple(create_unit_type_tuple()), @@ -90,6 +98,7 @@ impl Mock mocked_trait, expectations_fields, item_methods: item_methods.to_vec(), + generics, } } } @@ -103,21 +112,26 @@ impl ToTokens for Mock mocked_trait, expectations_fields, item_methods, + generics, } = self; let expectations_field_idents = expectations_fields .iter() .map(|expectations_field| expectations_field.field_ident.clone()); - let mock_functions = item_methods - .iter() - .map(|item_method| create_mock_function(item_method.clone())); + let mock_functions = item_methods.iter().map(|item_method| { + create_mock_function(item_method.clone(), &generics.params) + }); let expect_functions = item_methods .iter() - .map(|item_method| create_expect_function(&self.ident, &item_method.clone())) + .map(|item_method| { + create_expect_function(&self.ident, &item_method.clone(), generics) + }) .collect::>(); + let (impl_generics, _, _) = generics.split_for_impl(); + quote! { pub struct #ident { @@ -138,7 +152,7 @@ impl ToTokens for Mock #(#expect_functions)* } - impl #mocked_trait for #ident { + impl #impl_generics #mocked_trait for #ident { #( #mock_functions )* @@ -175,15 +189,24 @@ impl ToTokens for ExpectationsField } } -fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod +fn create_mock_function( + item_method: ImplItemMethod, + generic_params: &Punctuated, +) -> ImplItemMethod { let func_ident = &item_method.sig.ident; let type_param_idents = item_method .sig .generics - .type_params() - .map(|type_param| type_param.ident.clone()) + .params + .clone() + .into_iter() + .chain(generic_params.clone()) + .filter_map(|generic_param| match generic_param { + GenericParam::Type(type_param) => Some(type_param.ident), + _ => None, + }) .collect::>(); let args = item_method @@ -204,12 +227,6 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod let expectations_field = format_ident!("{func_ident}_expectations"); - let ids = quote! { - let ids = vec![ - #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),* - ]; - }; - ImplItemMethod { attrs: item_method.attrs, vis: Visibility::Inherited, @@ -217,7 +234,9 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod sig: item_method.sig.clone(), block: parse2(quote! { { - #ids + let ids = vec![ + #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),* + ]; let expectation = self .#expectations_field @@ -226,7 +245,7 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod "No expectation found for function ", stringify!(#func_ident) )) - .with_generic_params::<#(#type_param_idents),*>(); + .with_generic_params::<#(#type_param_idents,)*>(); let Some(returning) = &expectation.returning else { panic!(concat!( @@ -243,11 +262,23 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod } } -fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplItemMethod +fn create_expect_function( + mock: &Ident, + item_method: &ImplItemMethod, + generics: &Generics, +) -> ImplItemMethod { + let signature_generics = { + let mut sig_generics = item_method.sig.generics.clone(); + + sig_generics.params.extend(generics.params.clone()); + + sig_generics + }; + let signature = Signature::new( format_ident!("expect_{}", item_method.sig.ident), - item_method.sig.generics.clone(), + signature_generics.clone(), [FnArg::Receiver(Receiver { attrs: vec![], reference: Some((::default(), None)), @@ -263,7 +294,7 @@ fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplIte create_expectation_ident(mock, &item_method.sig.ident), Some(AngleBracketedGenericArguments::new( WithColons::No, - item_method.sig.generics.params.iter().map(|generic_param| { + signature_generics.params.iter().map(|generic_param| { GenericArgument::from_generic_param(generic_param.clone()) }), )), @@ -272,9 +303,7 @@ fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplIte ))), ); - let type_param_idents = item_method - .sig - .generics + let type_param_idents = signature_generics .type_params() .map(|type_param| type_param.ident.clone()) .collect::>(); -- cgit v1.2.3-18-g5258