use proc_macro2::{Ident, Span, TokenStream}; use proc_macro_error::{abort_call_site, ResultExt}; use quote::{format_ident, quote, ToTokens}; use syn::punctuated::Punctuated; use syn::{ parse2, AngleBracketedGenericArguments, FnArg, GenericArgument, GenericParam, Generics, ImplItemMethod, Lifetime, Pat, Path, PathSegment, Receiver, ReturnType, Signature, Token, Type, TypePath, TypeReference, Visibility, }; use crate::expectation::create_expectation_ident; use crate::syn_ext::{ AngleBracketedGenericArgumentsExt, GenericArgumentExt, IsMut, PathExt, PathSegmentExt, ReturnTypeExt, SignatureExt, TypePathExt, TypeReferenceExt, VisibilityExt, WithColons, WithLeadingColons, }; use crate::util::create_unit_type_tuple; pub struct Mock { ident: Ident, mocked_trait: Path, expectations_fields: Vec, item_methods: Vec, generics: Generics, } impl Mock { pub fn new( ident: Ident, mocked_trait: Path, item_methods: &[ImplItemMethod], generics: Generics, ) -> Self { let expectations_fields = item_methods .iter() .map(|method_item| { let generic_args = method_item .sig .generics .params .iter() .chain(generics.params.iter()) .filter_map(|generic_param| match generic_param { GenericParam::Type(_) => Some(GenericArgument::Type( Type::Tuple(create_unit_type_tuple()), )), GenericParam::Lifetime(_) => { Some(GenericArgument::Lifetime(Lifetime { apostrophe: Span::call_site(), ident: format_ident!("static"), })) } GenericParam::Const(_) => None, }) .collect::>(); ExpectationsField { field_ident: format_ident!("{}_expectations", method_item.sig.ident), expectation_ident: create_expectation_ident( &ident, &method_item.sig.ident, ), generic_args, } }) .collect::>(); Self { ident, mocked_trait, expectations_fields, item_methods: item_methods.to_vec(), generics, } } } impl ToTokens for Mock { fn to_tokens(&self, tokens: &mut TokenStream) { let Self { ident, mocked_trait, expectations_fields, item_methods, generics, } = self; let expectations_field_idents = expectations_fields .iter() .map(|expectations_field| expectations_field.field_ident.clone()); let mock_functions = item_methods.iter().map(|item_method| { create_mock_function(item_method.clone(), &generics.params) }); let expect_functions = item_methods .iter() .map(|item_method| { create_expect_function(&self.ident, &item_method.clone(), generics) }) .collect::>(); let (impl_generics, _, _) = generics.split_for_impl(); quote! { pub struct #ident { #(#expectations_fields),* } impl #ident { pub fn new() -> Self { Self { #( #expectations_field_idents: ::std::collections::HashMap::new() ),* } } #(#expect_functions)* } impl #impl_generics #mocked_trait for #ident { #( #mock_functions )* } } .to_tokens(tokens); } } struct ExpectationsField { field_ident: Ident, expectation_ident: Ident, generic_args: Vec, } impl ToTokens for ExpectationsField { fn to_tokens(&self, tokens: &mut TokenStream) { let Self { field_ident, expectation_ident, generic_args, } = self; quote! { #field_ident: ::std::collections::HashMap< Vec<::ridicule::__private::type_id::TypeID>, ::std::collections::VecDeque<#expectation_ident<#(#generic_args),*>> > } .to_tokens(tokens); } } fn create_mock_function( item_method: ImplItemMethod, generic_params: &Punctuated, ) -> ImplItemMethod { let func_ident = &item_method.sig.ident; let type_param_idents = item_method .sig .generics .params .clone() .into_iter() .chain(generic_params.clone()) .filter_map(|generic_param| match generic_param { GenericParam::Type(type_param) => Some(type_param.ident), _ => None, }) .collect::>(); let args = item_method .sig .inputs .iter() .map(|fn_arg| match fn_arg { FnArg::Receiver(_) => format_ident!("self"), FnArg::Typed(pat_type) => { let Pat::Ident(pat_ident) = pat_type.pat.as_ref() else { abort_call_site!("Unsupport argument pattern"); }; pat_ident.ident.clone() } }) .collect::>(); let expectations_field = format_ident!("{func_ident}_expectations"); ImplItemMethod { attrs: item_method.attrs, vis: Visibility::Inherited, defaultness: None, sig: item_method.sig.clone(), block: parse2(quote! { { let ids = vec![ #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),* ]; let func_expectations = self .#expectations_field .get(&ids) .expect(concat!( "No expectation found for function ", stringify!(#func_ident) )); let expectation = func_expectations .iter() .skip_while(|expectation| expectation.is_exhausted()) .next() .expect(concat!( "No expectation found for function ", stringify!(#func_ident) )) .with_generic_params::<#(#type_param_idents,)*>(); (expectation.get_returning())(#(#args),*) } }) .unwrap_or_abort(), } } fn create_expect_function( mock: &Ident, item_method: &ImplItemMethod, generics: &Generics, ) -> ImplItemMethod { let signature_generics = { let mut sig_generics = item_method.sig.generics.clone(); sig_generics.params.extend(generics.params.clone()); sig_generics }; let signature = Signature::new( format_ident!("expect_{}", item_method.sig.ident), signature_generics.clone(), [FnArg::Receiver(Receiver { attrs: vec![], reference: Some((::default(), None)), mutability: Some(::default()), self_token: ::default(), })], ReturnType::new(Type::Reference(TypeReference::new( None, IsMut::Yes, Type::Path(TypePath::new(Path::new( WithLeadingColons::No, [PathSegment::new( create_expectation_ident(mock, &item_method.sig.ident), Some(AngleBracketedGenericArguments::new( WithColons::No, signature_generics.params.iter().map(|generic_param| { GenericArgument::from_generic_param(generic_param.clone()) }), )), )], ))), ))), ); let type_param_idents = signature_generics .type_params() .map(|type_param| type_param.ident.clone()) .collect::>(); let expectation = create_expectation_ident(mock, &item_method.sig.ident); let expectations_field = format_ident!("{}_expectations", item_method.sig.ident); ImplItemMethod { attrs: item_method.attrs.clone(), vis: Visibility::new_pub_crate(), defaultness: None, sig: signature, block: parse2(quote! {{ let ids = vec![ #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),* ]; let expectation = #expectation::<#(#type_param_idents),*>::new() .strip_generic_params(); let func_expectations = self .#expectations_field .entry(ids.clone()) .or_insert_with(::std::collections::VecDeque::new); func_expectations.push_back(expectation); func_expectations .back_mut() .unwrap() .with_generic_params_mut() }}) .unwrap_or_abort(), } }