use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, ToTokens}; use syn::punctuated::Punctuated; use syn::token::Brace; use syn::{ AngleBracketedGenericArguments, Attribute, BareFnArg, Field, Fields, FieldsNamed, FnArg, GenericArgument, GenericParam, Generics, ImplItemMethod, ItemStruct, Lifetime, Path, PathSegment, Receiver, ReturnType, Token, Type, TypeBareFn, TypePath, TypeReference, Visibility, }; use crate::syn_ext::{ AngleBracketedGenericArgumentsExt, AttributeExt, AttributeStyle, BareFnArgExt, GenericsExt, IsMut, LifetimeExt, PathExt, PathSegmentExt, TypeBareFnExt, TypePathExt, TypeReferenceExt, VisibilityExt, WithColons, WithLeadingColons, }; use crate::util::{create_path, create_unit_type_tuple}; pub struct Expectation { ident: Ident, method_ident: Ident, method_generics: Generics, generic_params: Punctuated, receiver: Option, mock: Ident, arg_types: Vec, return_type: ReturnType, phantom_fields: Vec, } impl Expectation { pub fn new( mock: &Ident, item_method: &ImplItemMethod, generic_params: Punctuated, ) -> Self { let ident = create_expectation_ident(mock, &item_method.sig.ident); let phantom_fields = Self::create_phantom_fields( &item_method .sig .generics .params .clone() .into_iter() .chain(generic_params.clone()) .collect(), ); let receiver = item_method .sig .inputs .first() .and_then(|first_arg| match first_arg { FnArg::Receiver(receiver) => Some(receiver.clone()), FnArg::Typed(_) => None, }); let arg_types = item_method .sig .inputs .iter() .filter_map(|arg| match arg { FnArg::Typed(typed_arg) => Some(*typed_arg.ty.clone()), FnArg::Receiver(_) => None, }) .collect::>(); let return_type = item_method.sig.output.clone(); Self { ident, method_ident: item_method.sig.ident.clone(), method_generics: item_method.sig.generics.clone(), generic_params, receiver, mock: mock.clone(), arg_types, return_type, phantom_fields, } } fn create_phantom_fields( generic_params: &Punctuated, ) -> Vec { generic_params .iter() .filter_map(|generic_param| match generic_param { GenericParam::Type(type_param) => { let type_param_ident = &type_param.ident; let field_ident = create_phantom_field_ident( type_param_ident, &PhantomFieldKind::Type, ); let ty = create_phantom_data_type_path([GenericArgument::Type( Type::Path(TypePath::new(Path::new( WithLeadingColons::No, [PathSegment::new(type_param_ident.clone(), None)], ))), )]); Some(PhantomField { field: field_ident, type_path: ty, }) } GenericParam::Lifetime(lifetime_param) => { let lifetime = &lifetime_param.lifetime; let field_ident = create_phantom_field_ident( &lifetime.ident, &PhantomFieldKind::Lifetime, ); let ty = create_phantom_data_type_path([GenericArgument::Type( Type::Reference(TypeReference::new( Some(lifetime.clone()), IsMut::No, Type::Tuple(create_unit_type_tuple()), )), )]); Some(PhantomField { field: field_ident, type_path: ty, }) } GenericParam::Const(_) => None, }) .collect() } fn create_struct( ident: Ident, generics: Generics, phantom_fields: &[PhantomField], returning_fn: &Type, ) -> ItemStruct { ItemStruct { attrs: vec![Attribute::new( AttributeStyle::Outer, create_path!(allow), quote! { (non_camel_case_types, non_snake_case) }, )], vis: Visibility::new_pub_crate(), struct_token: ::default(), ident, generics: generics.strip_where_clause_and_bounds(), fields: Fields::Named(FieldsNamed { brace_token: Brace::default(), named: [ Field { attrs: vec![], vis: Visibility::Inherited, ident: Some(format_ident!("returning")), colon_token: Some(::default()), ty: Type::Path(TypePath::new(Path::new( WithLeadingColons::No, [PathSegment::new( format_ident!("Option"), Some(AngleBracketedGenericArguments::new( WithColons::No, [GenericArgument::Type(returning_fn.clone())], )), )], ))), }, Field { attrs: vec![], vis: Visibility::Inherited, ident: Some(format_ident!("call_cnt")), colon_token: Some(::default()), ty: Type::Path(TypePath::new(create_path!( ::std::sync::atomic::AtomicU32 ))), }, Field { attrs: vec![], vis: Visibility::Inherited, ident: Some(format_ident!("call_cnt_expectation")), colon_token: Some(::default()), ty: Type::Path(TypePath::new(create_path!( ::ridicule::__private::CallCountExpectation ))), }, ] .into_iter() .chain(phantom_fields.iter().cloned().map(Field::from)) .collect(), }), semi_token: None, } } } impl ToTokens for Expectation { #[allow(clippy::too_many_lines)] fn to_tokens(&self, tokens: &mut TokenStream) { let generics = { let mut generics = self.method_generics.clone(); generics.params.extend(self.generic_params.clone()); generics }; let generic_params = &generics.params; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let bogus_generics = create_bogus_generics(generic_params); let opt_self_type = receiver_to_mock_self_type(&self.receiver, self.mock.clone()); let ident = &self.ident; let phantom_fields = &self.phantom_fields; let returning_fn = Type::BareFn(TypeBareFn::new( opt_self_type .iter() .chain(self.arg_types.iter()) .map(|ty| BareFnArg::new(ty.clone())), self.return_type.clone(), )); let method_ident = &self.method_ident; let expectation_struct = Self::create_struct( self.ident.clone(), generics.clone(), phantom_fields, &returning_fn, ); let boundless_generics = generics.clone().strip_where_clause_and_bounds(); let (boundless_impl_generics, _, _) = boundless_generics.split_for_impl(); let do_strip_generic_params = if generic_params.is_empty() { quote! { self } } else { quote! { unsafe { std::mem::transmute(self) } } }; quote! { #expectation_struct impl #impl_generics #ident #ty_generics #where_clause { fn new() -> Self { Self { returning: None, call_cnt: ::std::sync::atomic::AtomicU32::new(0), call_cnt_expectation: ::ridicule::__private::CallCountExpectation::Unlimited, #(#phantom_fields),* } } #[allow(unused)] pub fn returning( &mut self, func: #returning_fn ) -> &mut Self { self.returning = Some(func); self } pub fn times(&mut self, cnt: u32) -> &mut Self { self.call_cnt_expectation = ::ridicule::__private::CallCountExpectation::Times(cnt); self } pub fn never(&mut self) -> &mut Self { self.call_cnt_expectation = ::ridicule::__private::CallCountExpectation::Never; self } #[allow(unused)] fn strip_generic_params( self, ) -> #ident<#(#bogus_generics),*> { #do_strip_generic_params } fn get_returning(&self) -> &#returning_fn { let Some(returning) = &self.returning else { panic!(concat!( "Expectation for function", stringify!(#method_ident), " is missing a function to call") ); }; if matches!( self.call_cnt_expectation, ::ridicule::__private::CallCountExpectation::Never ) { panic!( "Expected function {} to never be called", stringify!(#method_ident) ); } if let ::ridicule::__private::CallCountExpectation::Times( times ) = self.call_cnt_expectation { if times == self.call_cnt.load( ::std::sync::atomic::Ordering::Relaxed ) { panic!( concat!( "Expected function {} to be called {} times. Was ", "called {} times" ), stringify!(#method_ident), times, times + 1 ); } } self.call_cnt.fetch_add(1, std::sync::atomic::Ordering::Relaxed); returning } } impl #ident<#(#bogus_generics),*> { #[allow(unused)] fn with_generic_params<#generic_params>( &self, ) -> &#ident #ty_generics { // SAFETY: self is a pointer to a sane place, Rustc guarantees that // by it being a reference. The generic parameters doesn't affect // the size of self in any way, as they are only used in the function // pointer field "returning" unsafe { &*(self as *const Self).cast() } } #[allow(unused)] fn with_generic_params_mut<#generic_params>( &mut self, ) -> &mut #ident #ty_generics { // SAFETY: self is a pointer to a sane place, Rustc guarantees that // by it being a reference. The generic parameters doesn't affect // the size of self in any way, as they are only used in the function // pointer field "returning" unsafe { &mut *(self as *mut Self).cast() } } } impl #boundless_impl_generics #ident #ty_generics { fn is_exhausted(&self) -> bool { if let ::ridicule::__private::CallCountExpectation::Times(times) = self.call_cnt_expectation { if times == self.call_cnt.load( ::std::sync::atomic::Ordering::Relaxed ) { return true; } } false } } impl #boundless_impl_generics Drop for #ident #ty_generics { fn drop(&mut self) { let call_cnt = self.call_cnt.load(::std::sync::atomic::Ordering::Relaxed); if let ::ridicule::__private::CallCountExpectation::Times( times ) = self.call_cnt_expectation { if call_cnt != times { panic!( concat!( "Expected function {} to be called {} times. Was ", "called {} times" ), stringify!(#method_ident), times, call_cnt ); } } } } } .to_tokens(tokens); } } pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident { format_ident!("{mock}Expectation_{method}") } #[derive(Clone)] struct PhantomField { field: Ident, type_path: TypePath, } impl ToTokens for PhantomField { fn to_tokens(&self, tokens: &mut TokenStream) { self.field.to_tokens(tokens); ::default().to_tokens(tokens); self.type_path.to_tokens(tokens); } } impl From for Field { fn from(phantom_field: PhantomField) -> Self { Self { attrs: vec![], vis: Visibility::Inherited, ident: Some(phantom_field.field.clone()), colon_token: Some(::default()), ty: Type::Path(phantom_field.type_path), } } } fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident { match kind { PhantomFieldKind::Type => format_ident!("{ident}_phantom"), PhantomFieldKind::Lifetime => format_ident!("{ident}_lt_phantom"), } } enum PhantomFieldKind { Type, Lifetime, } fn create_phantom_data_type_path( generic_args: impl IntoIterator, ) -> TypePath { TypePath::new(Path::new( WithLeadingColons::Yes, [ PathSegment::new(format_ident!("std"), None), PathSegment::new(format_ident!("marker"), None), PathSegment::new( format_ident!("PhantomData"), Some(AngleBracketedGenericArguments::new( WithColons::Yes, generic_args, )), ), ], )) } fn create_bogus_generics( generic_params: &Punctuated, ) -> Vec { generic_params .iter() .filter_map(|generic_param| match generic_param { GenericParam::Type(_) => { Some(GenericArgument::Type(Type::Tuple(create_unit_type_tuple()))) } GenericParam::Lifetime(_) => Some(GenericArgument::Lifetime( Lifetime::create(format_ident!("static")), )), GenericParam::Const(_) => None, }) .collect() } fn receiver_to_mock_self_type(receiver: &Option, mock: Ident) -> Option { receiver.as_ref().map(|receiver| { let self_type = Type::Path(TypePath::new(Path::new( WithLeadingColons::No, [PathSegment::new(mock, None)], ))); if let Some((_, lifetime)) = &receiver.reference { return Type::Reference(TypeReference::new( lifetime.clone(), receiver.mutability.into(), self_type, )); } self_type }) }