diff options
| -rw-r--r-- | examples/generic_method.rs | 22 | ||||
| -rw-r--r-- | macros/src/expectation.rs | 226 | ||||
| -rw-r--r-- | macros/src/mock.rs | 10 | ||||
| -rw-r--r-- | macros/src/util.rs | 13 | ||||
| -rw-r--r-- | src/lib.rs | 7 | 
5 files changed, 229 insertions, 49 deletions
| diff --git a/examples/generic_method.rs b/examples/generic_method.rs index 8dc4650..995c67d 100644 --- a/examples/generic_method.rs +++ b/examples/generic_method.rs @@ -1,8 +1,10 @@ +use std::fmt::Display; +  use ridicule::mock;  trait Foo  { -    fn bar<Baz>(&self, num: u128) -> Baz; +    fn bar<Baz: Display>(&self, num: u128) -> Baz;  }  mock! { @@ -10,7 +12,7 @@ mock! {      impl Foo for MockFoo      { -        fn bar<Baz>(&self, num: u128) -> Baz; +        fn bar<Baz: Display>(&self, num: u128) -> Baz;      }  } @@ -18,11 +20,14 @@ fn main()  {      let mut mock_foo = MockFoo::new(); -    mock_foo.expect_bar().returning(|_me, num| { -        println!("bar was called with {num}"); +    mock_foo +        .expect_bar() +        .returning(|_me, num| { +            println!("bar was called with {num}"); -        "Hello".to_string() -    }); +            "Hello".to_string() +        }) +        .times(3);      mock_foo.expect_bar().returning(|_me, num| {          println!("bar was called with {num}"); @@ -31,6 +36,11 @@ fn main()      });      assert_eq!(mock_foo.bar::<String>(123), "Hello".to_string()); +    assert_eq!(mock_foo.bar::<String>(123), "Hello".to_string()); +    assert_eq!(mock_foo.bar::<String>(123), "Hello".to_string()); + +    // Would panic +    // mock_foo.bar::<String>(123);      assert_eq!(mock_foo.bar::<u8>(456), 128);  } diff --git a/macros/src/expectation.rs b/macros/src/expectation.rs index 7d8f1a7..436d571 100644 --- a/macros/src/expectation.rs +++ b/macros/src/expectation.rs @@ -16,6 +16,9 @@ use syn::{      ImplItemMethod,      ItemStruct,      Lifetime, +    Pat, +    PatIdent, +    PatType,      Path,      PathSegment,      Receiver, @@ -50,6 +53,7 @@ use crate::util::{create_path, create_unit_type_tuple};  pub struct Expectation  {      ident: Ident, +    method_ident: Ident,      method_generics: Generics,      generic_params: Punctuated<GenericParam, Token![,]>,      receiver: Option<Receiver>, @@ -104,6 +108,7 @@ impl Expectation          Self {              ident, +            method_ident: item_method.sig.ident.clone(),              method_generics: item_method.sig.generics.clone(),              generic_params,              receiver, @@ -166,10 +171,74 @@ impl Expectation              })              .collect()      } + +    fn create_struct( +        ident: Ident, +        generics: Generics, +        phantom_fields: &[PhantomField], +        returning_fn: &Type, +    ) -> ItemStruct +    { +        ItemStruct { +            attrs: vec![Attribute::new( +                AttributeStyle::Outer, +                create_path!(allow), +                quote! { (non_camel_case_types, non_snake_case) }, +            )], +            vis: Visibility::new_pub_crate(), +            struct_token: <Token![struct]>::default(), +            ident, +            generics: generics.strip_where_clause_and_bounds(), +            fields: Fields::Named(FieldsNamed { +                brace_token: Brace::default(), +                named: [ +                    Field { +                        attrs: vec![], +                        vis: Visibility::Inherited, +                        ident: Some(format_ident!("returning")), +                        colon_token: Some(<Token![:]>::default()), +                        ty: Type::Path(TypePath::new(Path::new( +                            WithLeadingColons::No, +                            [PathSegment::new( +                                format_ident!("Option"), +                                Some(AngleBracketedGenericArguments::new( +                                    WithColons::No, +                                    [GenericArgument::Type(returning_fn.clone())], +                                )), +                            )], +                        ))), +                    }, +                    Field { +                        attrs: vec![], +                        vis: Visibility::Inherited, +                        ident: Some(format_ident!("call_cnt")), +                        colon_token: Some(<Token![:]>::default()), +                        ty: Type::Path(TypePath::new(create_path!( +                            ::std::sync::atomic::AtomicU32 +                        ))), +                    }, +                    Field { +                        attrs: vec![], +                        vis: Visibility::Inherited, +                        ident: Some(format_ident!("call_cnt_expectation")), +                        colon_token: Some(<Token![:]>::default()), +                        ty: Type::Path(TypePath::new(create_path!( +                            ::ridicule::__private::CallCountExpectation +                        ))), +                    }, +                ] +                .into_iter() +                .chain(phantom_fields.iter().cloned().map(Field::from)) +                .collect(), +            }), +            semi_token: None, +        } +    }  }  impl ToTokens for Expectation  { +    #[allow(clippy::too_many_lines)]      fn to_tokens(&self, tokens: &mut TokenStream)      {          let generics = { @@ -199,40 +268,47 @@ impl ToTokens for Expectation              self.return_type.clone(),          )); -        let expectation_struct = ItemStruct { -            attrs: vec![Attribute::new( -                AttributeStyle::Outer, -                create_path!(allow), -                quote! { (non_camel_case_types, non_snake_case) }, -            )], -            vis: Visibility::new_pub_crate(), -            struct_token: <Token![struct]>::default(), -            ident: self.ident.clone(), -            generics: generics.clone().strip_where_clause_and_bounds(), -            fields: Fields::Named(FieldsNamed { -                brace_token: Brace::default(), -                named: [Field { +        let args = opt_self_type +            .iter() +            .chain(self.arg_types.iter()) +            .enumerate() +            .map(|(index, ty)| { +                FnArg::Typed(PatType {                      attrs: vec![], -                    vis: Visibility::Inherited, -                    ident: Some(format_ident!("returning")), -                    colon_token: Some(<Token![:]>::default()), -                    ty: Type::Path(TypePath::new(Path::new( -                        WithLeadingColons::No, -                        [PathSegment::new( -                            format_ident!("Option"), -                            Some(AngleBracketedGenericArguments::new( -                                WithColons::No, -                                [GenericArgument::Type(returning_fn.clone())], -                            )), -                        )], -                    ))), -                }] -                .into_iter() -                .chain(phantom_fields.iter().cloned().map(Field::from)) -                .collect(), -            }), -            semi_token: None, -        }; +                    pat: Box::new(Pat::Ident(PatIdent { +                        attrs: vec![], +                        by_ref: None, +                        mutability: None, +                        ident: format_ident!("arg_{index}"), +                        subpat: None, +                    })), +                    colon_token: <Token![:]>::default(), +                    ty: Box::new(ty.clone()), +                }) +            }) +            .collect::<Vec<_>>(); + +        let arg_idents = opt_self_type +            .iter() +            .chain(self.arg_types.iter()) +            .enumerate() +            .map(|(index, _)| format_ident!("arg_{index}")) +            .collect::<Vec<_>>(); + +        let return_type = &self.return_type; + +        let method_ident = &self.method_ident; + +        let expectation_struct = Self::create_struct( +            self.ident.clone(), +            generics.clone(), +            phantom_fields, +            &returning_fn, +        ); + +        let boundless_generics = generics.clone().strip_where_clause_and_bounds(); + +        let (boundless_impl_generics, _, _) = boundless_generics.split_for_impl();          quote! {              #expectation_struct @@ -242,6 +318,9 @@ impl ToTokens for Expectation                  fn new() -> Self {                      Self {                          returning: None, +                        call_cnt: ::std::sync::atomic::AtomicU32::new(0), +                        call_cnt_expectation: +                            ::ridicule::__private::CallCountExpectation::Unlimited,                          #(#phantom_fields),*                      }                  } @@ -257,6 +336,20 @@ impl ToTokens for Expectation                      self                  } +                pub fn times(&mut self, cnt: u32) -> &mut Self { +                    self.call_cnt_expectation = +                        ::ridicule::__private::CallCountExpectation::Times(cnt); + +                    self +                } + +                pub fn never(&mut self) -> &mut Self { +                    self.call_cnt_expectation = +                        ::ridicule::__private::CallCountExpectation::Never; + +                    self +                } +                  #[allow(unused)]                  fn strip_generic_params(                      self, @@ -264,6 +357,49 @@ impl ToTokens for Expectation                  {                      unsafe { std::mem::transmute(self) }                  } + +                fn call_returning(&self, #(#args),*) #return_type +                { +                    let Some(returning) = &self.returning else { +                        panic!(concat!( +                            "Expectation for function", +                            stringify!(#method_ident), +                            " is missing a function to call") +                        ); +                    }; + +                    if matches!( +                        self.call_cnt_expectation, +                        ::ridicule::__private::CallCountExpectation::Never +                    ) { +                        panic!( +                            "Expected function {} to never be called", +                            stringify!(#method_ident) +                        ); +                    } + +                    if let ::ridicule::__private::CallCountExpectation::Times( +                        times +                    ) = self.call_cnt_expectation { +                        if times == self.call_cnt.load( +                            ::std::sync::atomic::Ordering::Relaxed +                        ) { +                            panic!( +                                concat!( +                                    "Expected function {} to be called {} times. Was ", +                                    "called {} times" +                                ), +                                stringify!(#method_ident), +                                times, +                                times + 1 +                            ); +                        } +                    } + +                    self.call_cnt.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + +                    (returning)(#(#arg_idents),*) +                }              }              impl #ident<#(#bogus_generics),*> { @@ -291,6 +427,30 @@ impl ToTokens for Expectation                      unsafe { &mut *(self as *mut Self).cast() }                  }              } + +            impl #boundless_impl_generics Drop for #ident #ty_generics +            { +                fn drop(&mut self) { +                    let call_cnt = +                        self.call_cnt.load(::std::sync::atomic::Ordering::Relaxed); + +                    if let ::ridicule::__private::CallCountExpectation::Times( +                        times +                    ) = self.call_cnt_expectation { +                        if call_cnt != times { +                            panic!( +                                concat!( +                                    "Expected function {} to be called {} times. Was ", +                                    "called {} times" +                                ), +                                stringify!(#method_ident), +                                times, +                                call_cnt +                            ); +                        } +                    } +                } +            }          }          .to_tokens(tokens);      } diff --git a/macros/src/mock.rs b/macros/src/mock.rs index 8828b17..d2eb451 100644 --- a/macros/src/mock.rs +++ b/macros/src/mock.rs @@ -247,15 +247,7 @@ fn create_mock_function(                      ))                      .with_generic_params::<#(#type_param_idents,)*>(); -                let Some(returning) = &expectation.returning else { -                    panic!(concat!( -                        "Expectation for function", -                        stringify!(#func_ident), -                        " is missing a function to call") -                    ); -                }; - -                returning(#(#args),*) +                expectation.call_returning(#(#args),*)              }          })          .unwrap_or_abort(), diff --git a/macros/src/util.rs b/macros/src/util.rs index 363051f..43779c1 100644 --- a/macros/src/util.rs +++ b/macros/src/util.rs @@ -14,7 +14,18 @@ macro_rules! create_path {      ($($segment: ident)::+) => {          Path::new(              WithLeadingColons::No, -            [$(PathSegment::new(format_ident!(stringify!($segment)), None))+], +            [$( +                PathSegment::new(format_ident!(stringify!($segment)), None) +            ),+], +        ) +    }; + +    (::$($segment: ident)::+) => { +        ::syn::Path::new( +            WithLeadingColons::Yes, +            [$( +                ::syn::PathSegment::new(format_ident!(stringify!($segment)), None) +            ),+],          )      };  } @@ -26,4 +26,11 @@ pub mod __private              }          }      } + +    pub enum CallCountExpectation +    { +        Never, +        Unlimited, +        Times(u32), +    }  } | 
