summaryrefslogtreecommitdiff
path: root/macros
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2023-03-18 21:26:54 +0100
committerHampusM <hampus@hampusmat.com>2023-03-18 21:26:54 +0100
commit2d964b39da09ad82eccf09abdea73967bbff76f2 (patch)
treed5b43196d2402e62559e999adb65ef99f584eaf7 /macros
parent43e0bdb4cc598f199eacb63f755f30dc2108146b (diff)
feat: add support for generic traits
Diffstat (limited to 'macros')
-rw-r--r--macros/src/expectation.rs61
-rw-r--r--macros/src/lib.rs17
-rw-r--r--macros/src/mock.rs77
3 files changed, 112 insertions, 43 deletions
diff --git a/macros/src/expectation.rs b/macros/src/expectation.rs
index 4fc1451..d35ef97 100644
--- a/macros/src/expectation.rs
+++ b/macros/src/expectation.rs
@@ -50,7 +50,8 @@ use crate::util::{create_path, create_unit_type_tuple};
pub struct Expectation
{
ident: Ident,
- generics: Generics,
+ method_generics: Generics,
+ generic_params: Punctuated<GenericParam, Token![,]>,
receiver: Option<Receiver>,
mock: Ident,
arg_types: Vec<Type>,
@@ -60,12 +61,24 @@ pub struct Expectation
impl Expectation
{
- pub fn new(mock: &Ident, item_method: &ImplItemMethod) -> Self
+ pub fn new(
+ mock: &Ident,
+ item_method: &ImplItemMethod,
+ generic_params: Punctuated<GenericParam, Token![,]>,
+ ) -> Self
{
let ident = create_expectation_ident(mock, &item_method.sig.ident);
- let phantom_fields =
- Self::create_phantom_fields(&item_method.sig.generics.params);
+ let phantom_fields = Self::create_phantom_fields(
+ &item_method
+ .sig
+ .generics
+ .params
+ .clone()
+ .into_iter()
+ .chain(generic_params.clone())
+ .collect(),
+ );
let receiver =
item_method
@@ -91,7 +104,8 @@ impl Expectation
Self {
ident,
- generics: item_method.sig.generics.clone(),
+ method_generics: item_method.sig.generics.clone(),
+ generic_params,
receiver,
mock: mock.clone(),
arg_types,
@@ -158,9 +172,17 @@ impl ToTokens for Expectation
{
fn to_tokens(&self, tokens: &mut TokenStream)
{
- let generic_params = &self.generics.params;
+ let generics = {
+ let mut generics = self.method_generics.clone();
- let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
+ generics.params.extend(self.generic_params.clone());
+
+ generics
+ };
+
+ let generic_params = &generics.params;
+
+ let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let bogus_generics = create_bogus_generics(generic_params);
@@ -186,7 +208,7 @@ impl ToTokens for Expectation
vis: Visibility::new_pub_crate(),
struct_token: <Token![struct]>::default(),
ident: self.ident.clone(),
- generics: self.generics.clone().without_where_clause(),
+ generics: generics.clone().without_where_clause(),
fields: Fields::Named(FieldsNamed {
brace_token: Brace::default(),
named: [Field {
@@ -206,13 +228,7 @@ impl ToTokens for Expectation
))),
}]
.into_iter()
- .chain(phantom_fields.iter().map(|phantom_field| Field {
- attrs: vec![],
- vis: Visibility::Inherited,
- ident: Some(phantom_field.field.clone()),
- colon_token: Some(<Token![:]>::default()),
- ty: Type::Path(phantom_field.type_path.clone()),
- }))
+ .chain(phantom_fields.iter().cloned().map(Field::from))
.collect(),
}),
semi_token: None,
@@ -285,6 +301,7 @@ pub fn create_expectation_ident(mock: &Ident, method: &Ident) -> Ident
format_ident!("{mock}Expectation_{method}")
}
+#[derive(Clone)]
struct PhantomField
{
field: Ident,
@@ -303,6 +320,20 @@ impl ToTokens for PhantomField
}
}
+impl From<PhantomField> for Field
+{
+ fn from(phantom_field: PhantomField) -> Self
+ {
+ Self {
+ attrs: vec![],
+ vis: Visibility::Inherited,
+ ident: Some(phantom_field.field.clone()),
+ colon_token: Some(<Token![:]>::default()),
+ ty: Type::Path(phantom_field.type_path),
+ }
+ }
+}
+
fn create_phantom_field_ident(ident: &Ident, kind: &PhantomFieldKind) -> Ident
{
match kind {
diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 36c6ad7..8106a8c 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -34,11 +34,20 @@ pub fn mock(input_stream: TokenStream) -> TokenStream
})
.collect::<Vec<_>>();
- let mock = Mock::new(mock_ident.clone(), input.mocked_trait, &method_items);
+ let mock = Mock::new(
+ mock_ident.clone(),
+ input.mocked_trait,
+ &method_items,
+ input.item_impl.generics.clone(),
+ );
- let expectations = method_items
- .iter()
- .map(|item_method| Expectation::new(&mock_ident, item_method));
+ let expectations = method_items.iter().map(|item_method| {
+ Expectation::new(
+ &mock_ident,
+ item_method,
+ input.item_impl.generics.params.clone(),
+ )
+ });
quote! {
mod #mock_mod_ident {
diff --git a/macros/src/mock.rs b/macros/src/mock.rs
index 1b6dd67..8828b17 100644
--- a/macros/src/mock.rs
+++ b/macros/src/mock.rs
@@ -1,12 +1,14 @@
use proc_macro2::{Ident, Span, TokenStream};
use proc_macro_error::{abort_call_site, ResultExt};
use quote::{format_ident, quote, ToTokens};
+use syn::punctuated::Punctuated;
use syn::{
parse2,
AngleBracketedGenericArguments,
FnArg,
GenericArgument,
GenericParam,
+ Generics,
ImplItemMethod,
Lifetime,
Pat,
@@ -45,12 +47,17 @@ pub struct Mock
mocked_trait: Path,
expectations_fields: Vec<ExpectationsField>,
item_methods: Vec<ImplItemMethod>,
+ generics: Generics,
}
impl Mock
{
- pub fn new(ident: Ident, mocked_trait: Path, item_methods: &[ImplItemMethod])
- -> Self
+ pub fn new(
+ ident: Ident,
+ mocked_trait: Path,
+ item_methods: &[ImplItemMethod],
+ generics: Generics,
+ ) -> Self
{
let expectations_fields = item_methods
.iter()
@@ -60,6 +67,7 @@ impl Mock
.generics
.params
.iter()
+ .chain(generics.params.iter())
.filter_map(|generic_param| match generic_param {
GenericParam::Type(_) => Some(GenericArgument::Type(
Type::Tuple(create_unit_type_tuple()),
@@ -90,6 +98,7 @@ impl Mock
mocked_trait,
expectations_fields,
item_methods: item_methods.to_vec(),
+ generics,
}
}
}
@@ -103,21 +112,26 @@ impl ToTokens for Mock
mocked_trait,
expectations_fields,
item_methods,
+ generics,
} = self;
let expectations_field_idents = expectations_fields
.iter()
.map(|expectations_field| expectations_field.field_ident.clone());
- let mock_functions = item_methods
- .iter()
- .map(|item_method| create_mock_function(item_method.clone()));
+ let mock_functions = item_methods.iter().map(|item_method| {
+ create_mock_function(item_method.clone(), &generics.params)
+ });
let expect_functions = item_methods
.iter()
- .map(|item_method| create_expect_function(&self.ident, &item_method.clone()))
+ .map(|item_method| {
+ create_expect_function(&self.ident, &item_method.clone(), generics)
+ })
.collect::<Vec<_>>();
+ let (impl_generics, _, _) = generics.split_for_impl();
+
quote! {
pub struct #ident
{
@@ -138,7 +152,7 @@ impl ToTokens for Mock
#(#expect_functions)*
}
- impl #mocked_trait for #ident {
+ impl #impl_generics #mocked_trait for #ident {
#(
#mock_functions
)*
@@ -175,15 +189,24 @@ impl ToTokens for ExpectationsField
}
}
-fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
+fn create_mock_function(
+ item_method: ImplItemMethod,
+ generic_params: &Punctuated<GenericParam, Token![,]>,
+) -> ImplItemMethod
{
let func_ident = &item_method.sig.ident;
let type_param_idents = item_method
.sig
.generics
- .type_params()
- .map(|type_param| type_param.ident.clone())
+ .params
+ .clone()
+ .into_iter()
+ .chain(generic_params.clone())
+ .filter_map(|generic_param| match generic_param {
+ GenericParam::Type(type_param) => Some(type_param.ident),
+ _ => None,
+ })
.collect::<Vec<_>>();
let args = item_method
@@ -204,12 +227,6 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
let expectations_field = format_ident!("{func_ident}_expectations");
- let ids = quote! {
- let ids = vec![
- #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),*
- ];
- };
-
ImplItemMethod {
attrs: item_method.attrs,
vis: Visibility::Inherited,
@@ -217,7 +234,9 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
sig: item_method.sig.clone(),
block: parse2(quote! {
{
- #ids
+ let ids = vec![
+ #(::ridicule::__private::type_id::TypeID::of::<#type_param_idents>()),*
+ ];
let expectation = self
.#expectations_field
@@ -226,7 +245,7 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
"No expectation found for function ",
stringify!(#func_ident)
))
- .with_generic_params::<#(#type_param_idents),*>();
+ .with_generic_params::<#(#type_param_idents,)*>();
let Some(returning) = &expectation.returning else {
panic!(concat!(
@@ -243,11 +262,23 @@ fn create_mock_function(item_method: ImplItemMethod) -> ImplItemMethod
}
}
-fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplItemMethod
+fn create_expect_function(
+ mock: &Ident,
+ item_method: &ImplItemMethod,
+ generics: &Generics,
+) -> ImplItemMethod
{
+ let signature_generics = {
+ let mut sig_generics = item_method.sig.generics.clone();
+
+ sig_generics.params.extend(generics.params.clone());
+
+ sig_generics
+ };
+
let signature = Signature::new(
format_ident!("expect_{}", item_method.sig.ident),
- item_method.sig.generics.clone(),
+ signature_generics.clone(),
[FnArg::Receiver(Receiver {
attrs: vec![],
reference: Some((<Token![&]>::default(), None)),
@@ -263,7 +294,7 @@ fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplIte
create_expectation_ident(mock, &item_method.sig.ident),
Some(AngleBracketedGenericArguments::new(
WithColons::No,
- item_method.sig.generics.params.iter().map(|generic_param| {
+ signature_generics.params.iter().map(|generic_param| {
GenericArgument::from_generic_param(generic_param.clone())
}),
)),
@@ -272,9 +303,7 @@ fn create_expect_function(mock: &Ident, item_method: &ImplItemMethod) -> ImplIte
))),
);
- let type_param_idents = item_method
- .sig
- .generics
+ let type_param_idents = signature_generics
.type_params()
.map(|type_param| type_param.ident.clone())
.collect::<Vec<_>>();