summaryrefslogtreecommitdiff
path: root/macros/src/expectation.rs
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2023-03-18 17:14:42 +0100
committerHampusM <hampus@hampusmat.com>2023-03-18 17:15:30 +0100
commitc48271aef7e6b0819c497f302127c161845a83d7 (patch)
treea18d7b5fc8e017b4b7e0917a55534b28a01fe57d /macros/src/expectation.rs
parent2ca8017deebe7bfe5aac368aead777a2c4910ca2 (diff)
refactor: rewrite the mock macro as a procedural macro
Diffstat (limited to 'macros/src/expectation.rs')
-rw-r--r--macros/src/expectation.rs376
1 files changed, 376 insertions, 0 deletions
diff --git a/macros/src/expectation.rs b/macros/src/expectation.rs
new file mode 100644
index 0000000..ff3d192
--- /dev/null
+++ b/macros/src/expectation.rs
@@ -0,0 +1,376 @@
+use proc_macro2::{Ident, TokenStream};
+use quote::{format_ident, quote, ToTokens};
+use syn::punctuated::Punctuated;
+use syn::token::Brace;
+use syn::{
+ AngleBracketedGenericArguments,
+ Attribute,
+ BareFnArg,
+ Field,
+ Fields,
+ FieldsNamed,
+ FnArg,
+ GenericArgument,
+ GenericParam,
+ Generics,
+ ItemStruct,
+ Lifetime,
+ Path,
+ PathSegment,
+ Receiver,
+ ReturnType,
+ Token,
+ TraitItemMethod,
+ Type,
+ TypeBareFn,
+ TypePath,
+ TypeReference,
+ Visibility,
+};
+
+use crate::syn_ext::{
+ AngleBracketedGenericArgumentsExt,
+ AttributeExt,
+ AttributeStyle,
+ BareFnArgExt,
+ GenericsExt,
+ IsMut,
+ LifetimeExt,
+ PathExt,
+ PathSegmentExt,
+ TypeBareFnExt,
+ TypePathExt,
+ TypeReferenceExt,
+ VisibilityExt,
+ WithColons,
+ WithLeadingColons,
+};
+use crate::util::{create_path, create_unit_type_tuple};
+
+pub struct Expectation
+{
+ ident: Ident,
+ generics: Generics,
+ receiver: Option<Receiver>,
+ mock: Ident,
+ arg_types: Vec<Type>,
+ return_type: ReturnType,
+ phantom_fields: Vec<PhantomField>,
+}
+
+impl Expectation
+{
+ pub fn new(mock: &Ident, item_method: &TraitItemMethod) -> Self
+ {
+ let ident = create_expectation_ident(mock, &item_method.sig.ident);
+
+ let phantom_fields =
+ Self::create_phantom_fields(&item_method.sig.generics.params);
+
+ let receiver =
+ item_method
+ .sig
+ .inputs
+ .first()
+ .and_then(|first_arg| match first_arg {
+ FnArg::Receiver(receiver) => Some(receiver.clone()),
+ FnArg::Typed(_) => None,
+ });
+
+ let arg_types = item_method
+ .sig
+ .inputs
+ .iter()
+ .filter_map(|arg| match arg {
+ FnArg::Typed(typed_arg) => Some(*typed_arg.ty.clone()),
+ FnArg::Receiver(_) => None,
+ })
+ .collect::<Vec<_>>();
+
+ let return_type = item_method.sig.output.clone();
+
+ Self {
+ ident,
+ generics: item_method.sig.generics.clone(),
+ receiver,
+ mock: mock.clone(),
+ arg_types,
+ return_type,
+ phantom_fields,
+ }
+ }
+
+ fn create_phantom_fields(
+ generic_params: &Punctuated<GenericParam, Token![,]>,
+ ) -> Vec<PhantomField>
+ {
+ generic_params
+ .iter()
+ .filter_map(|generic_param| match generic_param {
+ GenericParam::Type(type_param) => {
+ let type_param_ident = &type_param.ident;
+
+ let field_ident = create_phantom_field_ident(
+ type_param_ident,
+ &PhantomFieldKind::Type,
+ );
+
+ let ty = create_phantom_data_type_path([GenericArgument::Type(
+ Type::Path(TypePath::new(Path::new(
+ WithLeadingColons::No,
+ [PathSegment::new(type_param_ident.clone(), None)],
+ ))),
+ )]);
+
+ Some(PhantomField {
+ field: field_ident,
+ type_path: ty,
+ })
+ }
+ GenericParam::Lifetime(lifetime_param) => {
+ let lifetime = &lifetime_param.lifetime;
+
+ let field_ident = create_phantom_field_ident(
+ &lifetime.ident,
+ &PhantomFieldKind::Lifetime,
+ );
+
+ let ty = create_phantom_data_type_path([GenericArgument::Type(
+ Type::Reference(TypeReference::new(
+ Some(lifetime.clone()),
+ IsMut::No,
+ Type::Tuple(create_unit_type_tuple()),
+ )),
+ )]);
+
+ Some(PhantomField {
+ field: field_ident,
+ type_path: ty,
+ })
+ }
+ GenericParam::Const(_) => None,
+ })
+ .collect()
+ }
+}
+
+impl ToTokens for Expectation
+{
+ fn to_tokens(&self, tokens: &mut TokenStream)
+ {
+ let generic_params = &self.generics.params;
+
+ let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
+
+ let bogus_generics = create_bogus_generics(generic_params);
+
+ let opt_self_type = receiver_to_mock_self_type(&self.receiver, self.mock.clone());
+
+ let ident = &self.ident;
+ let phantom_fields = &self.phantom_fields;
+
+ let returning_fn = Type::BareFn(TypeBareFn::new(
+ opt_self_type
+ .iter()
+ .chain(self.arg_types.iter())
+ .map(|ty| BareFnArg::new(ty.clone())),
+ self.return_type.clone(),
+ ));
+
+ let expectation_struct = ItemStruct {
+ attrs: vec![Attribute::new(
+ AttributeStyle::Outer,
+ create_path!(allow),
+ quote! { (non_camel_case_types, non_snake_case) },
+ )],
+ vis: Visibility::new_pub_crate(),
+ struct_token: <Token![struct]>::default(),
+ ident: self.ident.clone(),
+ generics: self.generics.clone().without_where_clause(),
+ fields: Fields::Named(FieldsNamed {
+ brace_token: Brace::default(),
+ named: [Field {
+ attrs: vec![],
+ vis: Visibility::Inherited,
+ ident: Some(format_ident!("returning")),
+ colon_token: Some(<Token![:]>::default()),
+ ty: Type::Path(TypePath::new(Path::new(
+ WithLeadingColons::No,
+ [PathSegment::new(
+ format_ident!("Option"),
+ Some(AngleBracketedGenericArguments::new(
+ WithColons::No,
+ [GenericArgument::Type(returning_fn.clone())],
+ )),
+ )],
+ ))),
+ }]
+ .into_iter()
+ .chain(phantom_fields.iter().map(|phantom_field| Field {
+ attrs: vec![],
+ vis: Visibility::Inherited,
+ ident: Some(phantom_field.field.clone()),
+ colon_token: Some(<Token![:]>::default()),
+ ty: Type::Path(phantom_field.type_path.clone()),
+ }))
+ .collect(),
+ }),
+ semi_token: None,
+ };
+
+ quote! {
+ #expectation_struct
+
+ impl #impl_generics #ident #ty_generics #where_clause
+ {
+ fn new() -> Self {
+ Self {
+ returning: None,
+ #(#phantom_fields),*
+ }
+ }
+
+ #[allow(unused)]
+ pub fn returning(
+ &mut self,
+ func: #returning_fn
+ ) -> &mut Self
+ {
+ self.returning = Some(func);
+
+ self
+ }
+
+ #[allow(unused)]
+ fn strip_generic_params(
+ self,
+ ) -> #ident<#(#bogus_generics),*>
+ {
+ unsafe { std::mem::transmute(self) }
+ }
+ }
+
+ impl #ident<#(#bogus_generics),*> {
+ #[allow(unused)]
+ fn with_generic_params<#generic_params>(
+ &self,
+ ) -> &#ident #ty_generics
+ {
+ // SAFETY: self is a pointer to a sane place, Rustc guarantees that
+ // by it being a reference. The generic parameters doesn't affect
+ // the size of self in any way, as they are only used in the function
+ // pointer field "returning"
+ unsafe { &*(self as *const Self).cast() }
+ }
+
+ #[allow(unused)]
+ fn with_generic_params_mut<#generic_params>(
+ &mut self,
+ ) -> &mut #ident #ty_generics
+ {
+ // SAFETY: self is a pointer to a sane place, Rustc guarantees that
+ // by it being a reference. The generic parameters doesn't affect
+ // the size of self in any way, as they are only used in the function
+ // pointer field "returning"
+ unsafe { &mut *(self as *mut Self).cast() }
+ }
+ }
+ }
+ .to_tokens(tokens);
+ }
+}
+
+pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident
+{
+ format_ident!("{mock}Expectation_{method}")
+}
+
+struct PhantomField
+{
+ field: Ident,
+ type_path: TypePath,
+}
+
+impl ToTokens for PhantomField
+{
+ fn to_tokens(&self, tokens: &mut TokenStream)
+ {
+ self.field.to_tokens(tokens);
+
+ <Token![:]>::default().to_tokens(tokens);
+
+ self.type_path.to_tokens(tokens);
+ }
+}
+
+fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident
+{
+ match kind {
+ PhantomFieldKind::Type => format_ident!("{ident}_phantom"),
+ PhantomFieldKind::Lifetime => format_ident!("{ident}_lt_phantom"),
+ }
+}
+
+enum PhantomFieldKind
+{
+ Type,
+ Lifetime,
+}
+
+fn create_phantom_data_type_path(
+ generic_args: impl IntoIterator<Item = GenericArgument>,
+) -> TypePath
+{
+ TypePath::new(Path::new(
+ WithLeadingColons::Yes,
+ [
+ PathSegment::new(format_ident!("std"), None),
+ PathSegment::new(format_ident!("marker"), None),
+ PathSegment::new(
+ format_ident!("PhantomData"),
+ Some(AngleBracketedGenericArguments::new(
+ WithColons::Yes,
+ generic_args,
+ )),
+ ),
+ ],
+ ))
+}
+
+fn create_bogus_generics(
+ generic_params: &Punctuated<GenericParam, Token![,]>,
+) -> Vec<GenericArgument>
+{
+ generic_params
+ .iter()
+ .filter_map(|generic_param| match generic_param {
+ GenericParam::Type(_) => {
+ Some(GenericArgument::Type(Type::Tuple(create_unit_type_tuple())))
+ }
+ GenericParam::Lifetime(_) => Some(GenericArgument::Lifetime(
+ Lifetime::create(format_ident!("static")),
+ )),
+ GenericParam::Const(_) => None,
+ })
+ .collect()
+}
+
+fn receiver_to_mock_self_type(receiver: &Option<Receiver>, mock: Ident) -> Option<Type>
+{
+ receiver.as_ref().map(|receiver| {
+ let self_type = Type::Path(TypePath::new(Path::new(
+ WithLeadingColons::No,
+ [PathSegment::new(mock, None)],
+ )));
+
+ if let Some((_, lifetime)) = &receiver.reference {
+ return Type::Reference(TypeReference::new(
+ lifetime.clone(),
+ receiver.mutability.into(),
+ self_type,
+ ));
+ }
+
+ self_type
+ })
+}