diff options
author | HampusM <hampus@hampusmat.com> | 2023-03-18 17:14:42 +0100 |
---|---|---|
committer | HampusM <hampus@hampusmat.com> | 2023-03-18 17:15:30 +0100 |
commit | c48271aef7e6b0819c497f302127c161845a83d7 (patch) | |
tree | a18d7b5fc8e017b4b7e0917a55534b28a01fe57d /macros/src/expectation.rs | |
parent | 2ca8017deebe7bfe5aac368aead777a2c4910ca2 (diff) |
refactor: rewrite the mock macro as a procedural macro
Diffstat (limited to 'macros/src/expectation.rs')
-rw-r--r-- | macros/src/expectation.rs | 376 |
1 files changed, 376 insertions, 0 deletions
diff --git a/macros/src/expectation.rs b/macros/src/expectation.rs new file mode 100644 index 0000000..ff3d192 --- /dev/null +++ b/macros/src/expectation.rs @@ -0,0 +1,376 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, ToTokens}; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::{ + AngleBracketedGenericArguments, + Attribute, + BareFnArg, + Field, + Fields, + FieldsNamed, + FnArg, + GenericArgument, + GenericParam, + Generics, + ItemStruct, + Lifetime, + Path, + PathSegment, + Receiver, + ReturnType, + Token, + TraitItemMethod, + Type, + TypeBareFn, + TypePath, + TypeReference, + Visibility, +}; + +use crate::syn_ext::{ + AngleBracketedGenericArgumentsExt, + AttributeExt, + AttributeStyle, + BareFnArgExt, + GenericsExt, + IsMut, + LifetimeExt, + PathExt, + PathSegmentExt, + TypeBareFnExt, + TypePathExt, + TypeReferenceExt, + VisibilityExt, + WithColons, + WithLeadingColons, +}; +use crate::util::{create_path, create_unit_type_tuple}; + +pub struct Expectation +{ + ident: Ident, + generics: Generics, + receiver: Option<Receiver>, + mock: Ident, + arg_types: Vec<Type>, + return_type: ReturnType, + phantom_fields: Vec<PhantomField>, +} + +impl Expectation +{ + pub fn new(mock: &Ident, item_method: &TraitItemMethod) -> Self + { + let ident = create_expectation_ident(mock, &item_method.sig.ident); + + let phantom_fields = + Self::create_phantom_fields(&item_method.sig.generics.params); + + let receiver = + item_method + .sig + .inputs + .first() + .and_then(|first_arg| match first_arg { + FnArg::Receiver(receiver) => Some(receiver.clone()), + FnArg::Typed(_) => None, + }); + + let arg_types = item_method + .sig + .inputs + .iter() + .filter_map(|arg| match arg { + FnArg::Typed(typed_arg) => Some(*typed_arg.ty.clone()), + FnArg::Receiver(_) => None, + }) + .collect::<Vec<_>>(); + + let return_type = item_method.sig.output.clone(); + + Self { + ident, + generics: item_method.sig.generics.clone(), + receiver, + mock: mock.clone(), + arg_types, + return_type, + phantom_fields, + } + } + + fn create_phantom_fields( + generic_params: &Punctuated<GenericParam, Token![,]>, + ) -> Vec<PhantomField> + { + generic_params + .iter() + .filter_map(|generic_param| match generic_param { + GenericParam::Type(type_param) => { + let type_param_ident = &type_param.ident; + + let field_ident = create_phantom_field_ident( + type_param_ident, + &PhantomFieldKind::Type, + ); + + let ty = create_phantom_data_type_path([GenericArgument::Type( + Type::Path(TypePath::new(Path::new( + WithLeadingColons::No, + [PathSegment::new(type_param_ident.clone(), None)], + ))), + )]); + + Some(PhantomField { + field: field_ident, + type_path: ty, + }) + } + GenericParam::Lifetime(lifetime_param) => { + let lifetime = &lifetime_param.lifetime; + + let field_ident = create_phantom_field_ident( + &lifetime.ident, + &PhantomFieldKind::Lifetime, + ); + + let ty = create_phantom_data_type_path([GenericArgument::Type( + Type::Reference(TypeReference::new( + Some(lifetime.clone()), + IsMut::No, + Type::Tuple(create_unit_type_tuple()), + )), + )]); + + Some(PhantomField { + field: field_ident, + type_path: ty, + }) + } + GenericParam::Const(_) => None, + }) + .collect() + } +} + +impl ToTokens for Expectation +{ + fn to_tokens(&self, tokens: &mut TokenStream) + { + let generic_params = &self.generics.params; + + let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + + let bogus_generics = create_bogus_generics(generic_params); + + let opt_self_type = receiver_to_mock_self_type(&self.receiver, self.mock.clone()); + + let ident = &self.ident; + let phantom_fields = &self.phantom_fields; + + let returning_fn = Type::BareFn(TypeBareFn::new( + opt_self_type + .iter() + .chain(self.arg_types.iter()) + .map(|ty| BareFnArg::new(ty.clone())), + self.return_type.clone(), + )); + + let expectation_struct = ItemStruct { + attrs: vec![Attribute::new( + AttributeStyle::Outer, + create_path!(allow), + quote! { (non_camel_case_types, non_snake_case) }, + )], + vis: Visibility::new_pub_crate(), + struct_token: <Token![struct]>::default(), + ident: self.ident.clone(), + generics: self.generics.clone().without_where_clause(), + fields: Fields::Named(FieldsNamed { + brace_token: Brace::default(), + named: [Field { + attrs: vec![], + vis: Visibility::Inherited, + ident: Some(format_ident!("returning")), + colon_token: Some(<Token![:]>::default()), + ty: Type::Path(TypePath::new(Path::new( + WithLeadingColons::No, + [PathSegment::new( + format_ident!("Option"), + Some(AngleBracketedGenericArguments::new( + WithColons::No, + [GenericArgument::Type(returning_fn.clone())], + )), + )], + ))), + }] + .into_iter() + .chain(phantom_fields.iter().map(|phantom_field| Field { + attrs: vec![], + vis: Visibility::Inherited, + ident: Some(phantom_field.field.clone()), + colon_token: Some(<Token![:]>::default()), + ty: Type::Path(phantom_field.type_path.clone()), + })) + .collect(), + }), + semi_token: None, + }; + + quote! { + #expectation_struct + + impl #impl_generics #ident #ty_generics #where_clause + { + fn new() -> Self { + Self { + returning: None, + #(#phantom_fields),* + } + } + + #[allow(unused)] + pub fn returning( + &mut self, + func: #returning_fn + ) -> &mut Self + { + self.returning = Some(func); + + self + } + + #[allow(unused)] + fn strip_generic_params( + self, + ) -> #ident<#(#bogus_generics),*> + { + unsafe { std::mem::transmute(self) } + } + } + + impl #ident<#(#bogus_generics),*> { + #[allow(unused)] + fn with_generic_params<#generic_params>( + &self, + ) -> &#ident #ty_generics + { + // SAFETY: self is a pointer to a sane place, Rustc guarantees that + // by it being a reference. The generic parameters doesn't affect + // the size of self in any way, as they are only used in the function + // pointer field "returning" + unsafe { &*(self as *const Self).cast() } + } + + #[allow(unused)] + fn with_generic_params_mut<#generic_params>( + &mut self, + ) -> &mut #ident #ty_generics + { + // SAFETY: self is a pointer to a sane place, Rustc guarantees that + // by it being a reference. The generic parameters doesn't affect + // the size of self in any way, as they are only used in the function + // pointer field "returning" + unsafe { &mut *(self as *mut Self).cast() } + } + } + } + .to_tokens(tokens); + } +} + +pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident +{ + format_ident!("{mock}Expectation_{method}") +} + +struct PhantomField +{ + field: Ident, + type_path: TypePath, +} + +impl ToTokens for PhantomField +{ + fn to_tokens(&self, tokens: &mut TokenStream) + { + self.field.to_tokens(tokens); + + <Token![:]>::default().to_tokens(tokens); + + self.type_path.to_tokens(tokens); + } +} + +fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident +{ + match kind { + PhantomFieldKind::Type => format_ident!("{ident}_phantom"), + PhantomFieldKind::Lifetime => format_ident!("{ident}_lt_phantom"), + } +} + +enum PhantomFieldKind +{ + Type, + Lifetime, +} + +fn create_phantom_data_type_path( + generic_args: impl IntoIterator<Item = GenericArgument>, +) -> TypePath +{ + TypePath::new(Path::new( + WithLeadingColons::Yes, + [ + PathSegment::new(format_ident!("std"), None), + PathSegment::new(format_ident!("marker"), None), + PathSegment::new( + format_ident!("PhantomData"), + Some(AngleBracketedGenericArguments::new( + WithColons::Yes, + generic_args, + )), + ), + ], + )) +} + +fn create_bogus_generics( + generic_params: &Punctuated<GenericParam, Token![,]>, +) -> Vec<GenericArgument> +{ + generic_params + .iter() + .filter_map(|generic_param| match generic_param { + GenericParam::Type(_) => { + Some(GenericArgument::Type(Type::Tuple(create_unit_type_tuple()))) + } + GenericParam::Lifetime(_) => Some(GenericArgument::Lifetime( + Lifetime::create(format_ident!("static")), + )), + GenericParam::Const(_) => None, + }) + .collect() +} + +fn receiver_to_mock_self_type(receiver: &Option<Receiver>, mock: Ident) -> Option<Type> +{ + receiver.as_ref().map(|receiver| { + let self_type = Type::Path(TypePath::new(Path::new( + WithLeadingColons::No, + [PathSegment::new(mock, None)], + ))); + + if let Some((_, lifetime)) = &receiver.reference { + return Type::Reference(TypeReference::new( + lifetime.clone(), + receiver.mutability.into(), + self_type, + )); + } + + self_type + }) +} |