From ad2a6d1dc517407939ed022bba9f3352efc678ce Mon Sep 17 00:00:00 2001
From: HampusM <hampus@hampusmat.com>
Date: Sun, 19 Mar 2023 16:29:23 +0100
Subject: feat: add call count expectations to expectations

---
 examples/generic_method.rs |  22 +++--
 macros/src/expectation.rs  | 226 ++++++++++++++++++++++++++++++++++++++-------
 macros/src/mock.rs         |  10 +-
 macros/src/util.rs         |  13 ++-
 src/lib.rs                 |   7 ++
 5 files changed, 229 insertions(+), 49 deletions(-)

diff --git a/examples/generic_method.rs b/examples/generic_method.rs
index 8dc4650..995c67d 100644
--- a/examples/generic_method.rs
+++ b/examples/generic_method.rs
@@ -1,8 +1,10 @@
+use std::fmt::Display;
+
 use ridicule::mock;
 
 trait Foo
 {
-    fn bar<Baz>(&self, num: u128) -> Baz;
+    fn bar<Baz: Display>(&self, num: u128) -> Baz;
 }
 
 mock! {
@@ -10,7 +12,7 @@ mock! {
 
     impl Foo for MockFoo
     {
-        fn bar<Baz>(&self, num: u128) -> Baz;
+        fn bar<Baz: Display>(&self, num: u128) -> Baz;
     }
 }
 
@@ -18,11 +20,14 @@ fn main()
 {
     let mut mock_foo = MockFoo::new();
 
-    mock_foo.expect_bar().returning(|_me, num| {
-        println!("bar was called with {num}");
+    mock_foo
+        .expect_bar()
+        .returning(|_me, num| {
+            println!("bar was called with {num}");
 
-        "Hello".to_string()
-    });
+            "Hello".to_string()
+        })
+        .times(3);
 
     mock_foo.expect_bar().returning(|_me, num| {
         println!("bar was called with {num}");
@@ -31,6 +36,11 @@ fn main()
     });
 
     assert_eq!(mock_foo.bar::<String>(123), "Hello".to_string());
+    assert_eq!(mock_foo.bar::<String>(123), "Hello".to_string());
+    assert_eq!(mock_foo.bar::<String>(123), "Hello".to_string());
+
+    // Would panic
+    // mock_foo.bar::<String>(123);
 
     assert_eq!(mock_foo.bar::<u8>(456), 128);
 }
diff --git a/macros/src/expectation.rs b/macros/src/expectation.rs
index 7d8f1a7..436d571 100644
--- a/macros/src/expectation.rs
+++ b/macros/src/expectation.rs
@@ -16,6 +16,9 @@ use syn::{
     ImplItemMethod,
     ItemStruct,
     Lifetime,
+    Pat,
+    PatIdent,
+    PatType,
     Path,
     PathSegment,
     Receiver,
@@ -50,6 +53,7 @@ use crate::util::{create_path, create_unit_type_tuple};
 pub struct Expectation
 {
     ident: Ident,
+    method_ident: Ident,
     method_generics: Generics,
     generic_params: Punctuated<GenericParam, Token![,]>,
     receiver: Option<Receiver>,
@@ -104,6 +108,7 @@ impl Expectation
 
         Self {
             ident,
+            method_ident: item_method.sig.ident.clone(),
             method_generics: item_method.sig.generics.clone(),
             generic_params,
             receiver,
@@ -166,10 +171,74 @@ impl Expectation
             })
             .collect()
     }
+
+    fn create_struct(
+        ident: Ident,
+        generics: Generics,
+        phantom_fields: &[PhantomField],
+        returning_fn: &Type,
+    ) -> ItemStruct
+    {
+        ItemStruct {
+            attrs: vec![Attribute::new(
+                AttributeStyle::Outer,
+                create_path!(allow),
+                quote! { (non_camel_case_types, non_snake_case) },
+            )],
+            vis: Visibility::new_pub_crate(),
+            struct_token: <Token![struct]>::default(),
+            ident,
+            generics: generics.strip_where_clause_and_bounds(),
+            fields: Fields::Named(FieldsNamed {
+                brace_token: Brace::default(),
+                named: [
+                    Field {
+                        attrs: vec![],
+                        vis: Visibility::Inherited,
+                        ident: Some(format_ident!("returning")),
+                        colon_token: Some(<Token![:]>::default()),
+                        ty: Type::Path(TypePath::new(Path::new(
+                            WithLeadingColons::No,
+                            [PathSegment::new(
+                                format_ident!("Option"),
+                                Some(AngleBracketedGenericArguments::new(
+                                    WithColons::No,
+                                    [GenericArgument::Type(returning_fn.clone())],
+                                )),
+                            )],
+                        ))),
+                    },
+                    Field {
+                        attrs: vec![],
+                        vis: Visibility::Inherited,
+                        ident: Some(format_ident!("call_cnt")),
+                        colon_token: Some(<Token![:]>::default()),
+                        ty: Type::Path(TypePath::new(create_path!(
+                            ::std::sync::atomic::AtomicU32
+                        ))),
+                    },
+                    Field {
+                        attrs: vec![],
+                        vis: Visibility::Inherited,
+                        ident: Some(format_ident!("call_cnt_expectation")),
+                        colon_token: Some(<Token![:]>::default()),
+                        ty: Type::Path(TypePath::new(create_path!(
+                            ::ridicule::__private::CallCountExpectation
+                        ))),
+                    },
+                ]
+                .into_iter()
+                .chain(phantom_fields.iter().cloned().map(Field::from))
+                .collect(),
+            }),
+            semi_token: None,
+        }
+    }
 }
 
 impl ToTokens for Expectation
 {
+    #[allow(clippy::too_many_lines)]
     fn to_tokens(&self, tokens: &mut TokenStream)
     {
         let generics = {
@@ -199,40 +268,47 @@ impl ToTokens for Expectation
             self.return_type.clone(),
         ));
 
-        let expectation_struct = ItemStruct {
-            attrs: vec![Attribute::new(
-                AttributeStyle::Outer,
-                create_path!(allow),
-                quote! { (non_camel_case_types, non_snake_case) },
-            )],
-            vis: Visibility::new_pub_crate(),
-            struct_token: <Token![struct]>::default(),
-            ident: self.ident.clone(),
-            generics: generics.clone().strip_where_clause_and_bounds(),
-            fields: Fields::Named(FieldsNamed {
-                brace_token: Brace::default(),
-                named: [Field {
+        let args = opt_self_type
+            .iter()
+            .chain(self.arg_types.iter())
+            .enumerate()
+            .map(|(index, ty)| {
+                FnArg::Typed(PatType {
                     attrs: vec![],
-                    vis: Visibility::Inherited,
-                    ident: Some(format_ident!("returning")),
-                    colon_token: Some(<Token![:]>::default()),
-                    ty: Type::Path(TypePath::new(Path::new(
-                        WithLeadingColons::No,
-                        [PathSegment::new(
-                            format_ident!("Option"),
-                            Some(AngleBracketedGenericArguments::new(
-                                WithColons::No,
-                                [GenericArgument::Type(returning_fn.clone())],
-                            )),
-                        )],
-                    ))),
-                }]
-                .into_iter()
-                .chain(phantom_fields.iter().cloned().map(Field::from))
-                .collect(),
-            }),
-            semi_token: None,
-        };
+                    pat: Box::new(Pat::Ident(PatIdent {
+                        attrs: vec![],
+                        by_ref: None,
+                        mutability: None,
+                        ident: format_ident!("arg_{index}"),
+                        subpat: None,
+                    })),
+                    colon_token: <Token![:]>::default(),
+                    ty: Box::new(ty.clone()),
+                })
+            })
+            .collect::<Vec<_>>();
+
+        let arg_idents = opt_self_type
+            .iter()
+            .chain(self.arg_types.iter())
+            .enumerate()
+            .map(|(index, _)| format_ident!("arg_{index}"))
+            .collect::<Vec<_>>();
+
+        let return_type = &self.return_type;
+
+        let method_ident = &self.method_ident;
+
+        let expectation_struct = Self::create_struct(
+            self.ident.clone(),
+            generics.clone(),
+            phantom_fields,
+            &returning_fn,
+        );
+
+        let boundless_generics = generics.clone().strip_where_clause_and_bounds();
+
+        let (boundless_impl_generics, _, _) = boundless_generics.split_for_impl();
 
         quote! {
             #expectation_struct
@@ -242,6 +318,9 @@ impl ToTokens for Expectation
                 fn new() -> Self {
                     Self {
                         returning: None,
+                        call_cnt: ::std::sync::atomic::AtomicU32::new(0),
+                        call_cnt_expectation:
+                            ::ridicule::__private::CallCountExpectation::Unlimited,
                         #(#phantom_fields),*
                     }
                 }
@@ -257,6 +336,20 @@ impl ToTokens for Expectation
                     self
                 }
 
+                pub fn times(&mut self, cnt: u32) -> &mut Self {
+                    self.call_cnt_expectation =
+                        ::ridicule::__private::CallCountExpectation::Times(cnt);
+
+                    self
+                }
+
+                pub fn never(&mut self) -> &mut Self {
+                    self.call_cnt_expectation =
+                        ::ridicule::__private::CallCountExpectation::Never;
+
+                    self
+                }
+
                 #[allow(unused)]
                 fn strip_generic_params(
                     self,
@@ -264,6 +357,49 @@ impl ToTokens for Expectation
                 {
                     unsafe { std::mem::transmute(self) }
                 }
+
+                fn call_returning(&self, #(#args),*) #return_type
+                {
+                    let Some(returning) = &self.returning else {
+                        panic!(concat!(
+                            "Expectation for function",
+                            stringify!(#method_ident),
+                            " is missing a function to call")
+                        );
+                    };
+
+                    if matches!(
+                        self.call_cnt_expectation,
+                        ::ridicule::__private::CallCountExpectation::Never
+                    ) {
+                        panic!(
+                            "Expected function {} to never be called",
+                            stringify!(#method_ident)
+                        );
+                    }
+
+                    if let ::ridicule::__private::CallCountExpectation::Times(
+                        times
+                    ) = self.call_cnt_expectation {
+                        if times == self.call_cnt.load(
+                            ::std::sync::atomic::Ordering::Relaxed
+                        ) {
+                            panic!(
+                                concat!(
+                                    "Expected function {} to be called {} times. Was ",
+                                    "called {} times"
+                                ),
+                                stringify!(#method_ident),
+                                times,
+                                times + 1
+                            );
+                        }
+                    }
+
+                    self.call_cnt.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+
+                    (returning)(#(#arg_idents),*)
+                }
             }
 
             impl #ident<#(#bogus_generics),*> {
@@ -291,6 +427,30 @@ impl ToTokens for Expectation
                     unsafe { &mut *(self as *mut Self).cast() }
                 }
             }
+
+            impl #boundless_impl_generics Drop for #ident #ty_generics
+            {
+                fn drop(&mut self) {
+                    let call_cnt =
+                        self.call_cnt.load(::std::sync::atomic::Ordering::Relaxed);
+
+                    if let ::ridicule::__private::CallCountExpectation::Times(
+                        times
+                    ) = self.call_cnt_expectation {
+                        if call_cnt != times {
+                            panic!(
+                                concat!(
+                                    "Expected function {} to be called {} times. Was ",
+                                    "called {} times"
+                                ),
+                                stringify!(#method_ident),
+                                times,
+                                call_cnt
+                            );
+                        }
+                    }
+                }
+            }
         }
         .to_tokens(tokens);
     }
diff --git a/macros/src/mock.rs b/macros/src/mock.rs
index 8828b17..d2eb451 100644
--- a/macros/src/mock.rs
+++ b/macros/src/mock.rs
@@ -247,15 +247,7 @@ fn create_mock_function(
                     ))
                     .with_generic_params::<#(#type_param_idents,)*>();
 
-                let Some(returning) = &expectation.returning else {
-                    panic!(concat!(
-                        "Expectation for function",
-                        stringify!(#func_ident),
-                        " is missing a function to call")
-                    );
-                };
-
-                returning(#(#args),*)
+                expectation.call_returning(#(#args),*)
             }
         })
         .unwrap_or_abort(),
diff --git a/macros/src/util.rs b/macros/src/util.rs
index 363051f..43779c1 100644
--- a/macros/src/util.rs
+++ b/macros/src/util.rs
@@ -14,7 +14,18 @@ macro_rules! create_path {
     ($($segment: ident)::+) => {
         Path::new(
             WithLeadingColons::No,
-            [$(PathSegment::new(format_ident!(stringify!($segment)), None))+],
+            [$(
+                PathSegment::new(format_ident!(stringify!($segment)), None)
+            ),+],
+        )
+    };
+
+    (::$($segment: ident)::+) => {
+        ::syn::Path::new(
+            WithLeadingColons::Yes,
+            [$(
+                ::syn::PathSegment::new(format_ident!(stringify!($segment)), None)
+            ),+],
         )
     };
 }
diff --git a/src/lib.rs b/src/lib.rs
index 3011354..3a1d983 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -26,4 +26,11 @@ pub mod __private
             }
         }
     }
+
+    pub enum CallCountExpectation
+    {
+        Never,
+        Unlimited,
+        Times(u32),
+    }
 }
-- 
cgit v1.2.3-18-g5258