use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, ToTokens}; use syn::punctuated::Punctuated; use syn::token::Brace; use syn::{ AngleBracketedGenericArguments, Attribute, BareFnArg, Field, Fields, FieldsNamed, FnArg, GenericArgument, GenericParam, Generics, ItemStruct, Lifetime, Path, PathSegment, Receiver, ReturnType, Token, TraitItemMethod, Type, TypeBareFn, TypePath, TypeReference, Visibility, }; use crate::syn_ext::{ AngleBracketedGenericArgumentsExt, AttributeExt, AttributeStyle, BareFnArgExt, GenericsExt, IsMut, LifetimeExt, PathExt, PathSegmentExt, TypeBareFnExt, TypePathExt, TypeReferenceExt, VisibilityExt, WithColons, WithLeadingColons, }; use crate::util::{create_path, create_unit_type_tuple}; pub struct Expectation { ident: Ident, generics: Generics, receiver: Option, mock: Ident, arg_types: Vec, return_type: ReturnType, phantom_fields: Vec, } impl Expectation { pub fn new(mock: &Ident, item_method: &TraitItemMethod) -> Self { let ident = create_expectation_ident(mock, &item_method.sig.ident); let phantom_fields = Self::create_phantom_fields(&item_method.sig.generics.params); let receiver = item_method .sig .inputs .first() .and_then(|first_arg| match first_arg { FnArg::Receiver(receiver) => Some(receiver.clone()), FnArg::Typed(_) => None, }); let arg_types = item_method .sig .inputs .iter() .filter_map(|arg| match arg { FnArg::Typed(typed_arg) => Some(*typed_arg.ty.clone()), FnArg::Receiver(_) => None, }) .collect::>(); let return_type = item_method.sig.output.clone(); Self { ident, generics: item_method.sig.generics.clone(), receiver, mock: mock.clone(), arg_types, return_type, phantom_fields, } } fn create_phantom_fields( generic_params: &Punctuated, ) -> Vec { generic_params .iter() .filter_map(|generic_param| match generic_param { GenericParam::Type(type_param) => { let type_param_ident = &type_param.ident; let field_ident = create_phantom_field_ident( type_param_ident, &PhantomFieldKind::Type, ); let ty = create_phantom_data_type_path([GenericArgument::Type( Type::Path(TypePath::new(Path::new( WithLeadingColons::No, [PathSegment::new(type_param_ident.clone(), None)], ))), )]); Some(PhantomField { field: field_ident, type_path: ty, }) } GenericParam::Lifetime(lifetime_param) => { let lifetime = &lifetime_param.lifetime; let field_ident = create_phantom_field_ident( &lifetime.ident, &PhantomFieldKind::Lifetime, ); let ty = create_phantom_data_type_path([GenericArgument::Type( Type::Reference(TypeReference::new( Some(lifetime.clone()), IsMut::No, Type::Tuple(create_unit_type_tuple()), )), )]); Some(PhantomField { field: field_ident, type_path: ty, }) } GenericParam::Const(_) => None, }) .collect() } } impl ToTokens for Expectation { fn to_tokens(&self, tokens: &mut TokenStream) { let generic_params = &self.generics.params; let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); let bogus_generics = create_bogus_generics(generic_params); let opt_self_type = receiver_to_mock_self_type(&self.receiver, self.mock.clone()); let ident = &self.ident; let phantom_fields = &self.phantom_fields; let returning_fn = Type::BareFn(TypeBareFn::new( opt_self_type .iter() .chain(self.arg_types.iter()) .map(|ty| BareFnArg::new(ty.clone())), self.return_type.clone(), )); let expectation_struct = ItemStruct { attrs: vec![Attribute::new( AttributeStyle::Outer, create_path!(allow), quote! { (non_camel_case_types, non_snake_case) }, )], vis: Visibility::new_pub_crate(), struct_token: ::default(), ident: self.ident.clone(), generics: self.generics.clone().without_where_clause(), fields: Fields::Named(FieldsNamed { brace_token: Brace::default(), named: [Field { attrs: vec![], vis: Visibility::Inherited, ident: Some(format_ident!("returning")), colon_token: Some(::default()), ty: Type::Path(TypePath::new(Path::new( WithLeadingColons::No, [PathSegment::new( format_ident!("Option"), Some(AngleBracketedGenericArguments::new( WithColons::No, [GenericArgument::Type(returning_fn.clone())], )), )], ))), }] .into_iter() .chain(phantom_fields.iter().map(|phantom_field| Field { attrs: vec![], vis: Visibility::Inherited, ident: Some(phantom_field.field.clone()), colon_token: Some(::default()), ty: Type::Path(phantom_field.type_path.clone()), })) .collect(), }), semi_token: None, }; quote! { #expectation_struct impl #impl_generics #ident #ty_generics #where_clause { fn new() -> Self { Self { returning: None, #(#phantom_fields),* } } #[allow(unused)] pub fn returning( &mut self, func: #returning_fn ) -> &mut Self { self.returning = Some(func); self } #[allow(unused)] fn strip_generic_params( self, ) -> #ident<#(#bogus_generics),*> { unsafe { std::mem::transmute(self) } } } impl #ident<#(#bogus_generics),*> { #[allow(unused)] fn with_generic_params<#generic_params>( &self, ) -> &#ident #ty_generics { // SAFETY: self is a pointer to a sane place, Rustc guarantees that // by it being a reference. The generic parameters doesn't affect // the size of self in any way, as they are only used in the function // pointer field "returning" unsafe { &*(self as *const Self).cast() } } #[allow(unused)] fn with_generic_params_mut<#generic_params>( &mut self, ) -> &mut #ident #ty_generics { // SAFETY: self is a pointer to a sane place, Rustc guarantees that // by it being a reference. The generic parameters doesn't affect // the size of self in any way, as they are only used in the function // pointer field "returning" unsafe { &mut *(self as *mut Self).cast() } } } } .to_tokens(tokens); } } pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident { format_ident!("{mock}Expectation_{method}") } struct PhantomField { field: Ident, type_path: TypePath, } impl ToTokens for PhantomField { fn to_tokens(&self, tokens: &mut TokenStream) { self.field.to_tokens(tokens); ::default().to_tokens(tokens); self.type_path.to_tokens(tokens); } } fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident { match kind { PhantomFieldKind::Type => format_ident!("{ident}_phantom"), PhantomFieldKind::Lifetime => format_ident!("{ident}_lt_phantom"), } } enum PhantomFieldKind { Type, Lifetime, } fn create_phantom_data_type_path( generic_args: impl IntoIterator, ) -> TypePath { TypePath::new(Path::new( WithLeadingColons::Yes, [ PathSegment::new(format_ident!("std"), None), PathSegment::new(format_ident!("marker"), None), PathSegment::new( format_ident!("PhantomData"), Some(AngleBracketedGenericArguments::new( WithColons::Yes, generic_args, )), ), ], )) } fn create_bogus_generics( generic_params: &Punctuated, ) -> Vec { generic_params .iter() .filter_map(|generic_param| match generic_param { GenericParam::Type(_) => { Some(GenericArgument::Type(Type::Tuple(create_unit_type_tuple()))) } GenericParam::Lifetime(_) => Some(GenericArgument::Lifetime( Lifetime::create(format_ident!("static")), )), GenericParam::Const(_) => None, }) .collect() } fn receiver_to_mock_self_type(receiver: &Option, mock: Ident) -> Option { receiver.as_ref().map(|receiver| { let self_type = Type::Path(TypePath::new(Path::new( WithLeadingColons::No, [PathSegment::new(mock, None)], ))); if let Some((_, lifetime)) = &receiver.reference { return Type::Reference(TypeReference::new( lifetime.clone(), receiver.mutability.into(), self_type, )); } self_type }) }