//! Macros for Ridicule, a mocking library supporting non-static generics.
#![deny(clippy::all, clippy::pedantic, missing_docs)]
use proc_macro::TokenStream;
use proc_macro2::Ident;
use proc_macro_error::{proc_macro_error, ResultExt};
use quote::{format_ident, quote};
use syn::token::Brace;
use syn::{
parse,
Block,
FnArg,
GenericArgument,
GenericParam,
Generics,
ImplItem,
ImplItemMethod,
ItemTrait,
Path,
PathArguments,
PathSegment,
ReturnType,
TraitItem,
Type,
TypeBareFn,
TypeParamBound,
Visibility,
WherePredicate,
};
use crate::expectation::Expectation;
use crate::mock::Mock;
use crate::mock_input::MockInput;
use crate::syn_ext::{PathExt, PathSegmentExt, WithLeadingColons};
use crate::util::create_path;
mod expectation;
mod mock;
mod mock_input;
mod syn_ext;
mod util;
/// Creates a mock.
///
/// # Examples
/// ```
/// use ridicule::mock;
///
/// trait Foo
/// {
/// fn bar(&self, a: A) -> B;
/// }
///
/// mock! {
/// MockFoo {}
///
/// impl Foo for MockFoo
/// {
/// fn bar(&self, a: A) -> B;
/// }
/// }
///
/// fn main()
/// {
/// let mut mock_foo = MockFoo::new();
///
/// unsafe {
/// mock_foo
/// .expect_bar()
/// .returning(|foo, a: u32| format!("Hello {a}"));
/// }
///
/// assert_eq!(mock_foo.bar::(123), "Hello 123");
/// }
/// ```
#[proc_macro]
#[proc_macro_error]
pub fn mock(input_stream: TokenStream) -> TokenStream
{
let input = parse::(input_stream.clone()).unwrap_or_abort();
let mock_ident = input.mock;
let mock_mod_ident = format_ident!("__{mock_ident}");
let method_items =
get_type_replaced_impl_item_methods(input.item_impl.items, &mock_ident);
let mock = Mock::new(
mock_ident.clone(),
input.mocked_trait,
&method_items,
input.item_impl.generics.clone(),
);
let expectations = method_items.iter().map(|item_method| {
Expectation::new(
&mock_ident,
item_method,
input.item_impl.generics.params.clone(),
)
});
quote! {
mod #mock_mod_ident {
use super::*;
#mock
#(#expectations)*
}
use #mock_mod_ident::#mock_ident;
}
.into()
}
/// Creates a mock automatically.
#[proc_macro_attribute]
#[proc_macro_error]
pub fn automock(_: TokenStream, input_stream: TokenStream) -> TokenStream
{
let item_trait = parse::(input_stream).unwrap_or_abort();
let mock_ident = format_ident!("Mock{}", item_trait.ident);
let mock_mod_ident = format_ident!("__{mock_ident}");
let method_items = get_type_replaced_impl_item_methods(
item_trait.items.iter().filter_map(|item| match item {
TraitItem::Method(item_method) => Some(ImplItem::Method(ImplItemMethod {
attrs: item_method.attrs.clone(),
vis: Visibility::Inherited,
defaultness: None,
sig: item_method.sig.clone(),
block: Block {
brace_token: Brace::default(),
stmts: vec![],
},
})),
_ => None,
}),
&mock_ident,
);
let mock = Mock::new(
mock_ident.clone(),
Path::new(
WithLeadingColons::No,
[PathSegment::new(item_trait.ident.clone(), None)],
),
&method_items,
item_trait.generics.clone(),
);
let expectations = method_items.iter().map(|item_method| {
Expectation::new(&mock_ident, item_method, item_trait.generics.params.clone())
});
let visibility = &item_trait.vis;
quote! {
#item_trait
mod #mock_mod_ident {
use super::*;
#mock
#(#expectations)*
}
#visibility use #mock_mod_ident::#mock_ident;
}
.into()
}
fn get_type_replaced_impl_item_methods(
impl_items: impl IntoIterator- ,
mock_ident: &Ident,
) -> Vec
{
let target_path = create_path!(Self);
let replacement_path = Path::new(
WithLeadingColons::No,
[PathSegment::new(mock_ident.clone(), None)],
);
impl_items
.into_iter()
.filter_map(|item| match item {
ImplItem::Method(mut item_method) => {
item_method.sig.inputs = item_method
.sig
.inputs
.into_iter()
.map(|fn_arg| match fn_arg {
FnArg::Typed(mut typed_arg) => {
typed_arg.ty = Box::new(replace_path_in_type(
*typed_arg.ty,
&target_path,
&replacement_path,
));
FnArg::Typed(typed_arg)
}
FnArg::Receiver(receiver) => FnArg::Receiver(receiver),
})
.collect();
item_method.sig.output = match item_method.sig.output {
ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
r_arrow,
Box::new(replace_path_in_type(
*return_type,
&target_path,
&replacement_path,
)),
),
ReturnType::Default => ReturnType::Default,
};
item_method.sig.generics = replace_path_in_generics(
item_method.sig.generics,
&target_path,
&replacement_path,
);
Some(item_method)
}
_ => None,
})
.collect()
}
fn replace_path_in_generics(
mut generics: Generics,
target_path: &Path,
replacement_path: &Path,
) -> Generics
{
generics.params = generics
.params
.into_iter()
.map(|generic_param| match generic_param {
GenericParam::Type(mut type_param) => {
type_param.bounds = type_param
.bounds
.into_iter()
.map(|bound| {
replace_type_param_bound_paths(
bound,
target_path,
replacement_path,
)
})
.collect();
GenericParam::Type(type_param)
}
generic_param => generic_param,
})
.collect();
generics.where_clause = generics.where_clause.map(|mut where_clause| {
where_clause.predicates = where_clause
.predicates
.into_iter()
.map(|predicate| match predicate {
WherePredicate::Type(mut predicate_type) => {
predicate_type.bounded_ty = replace_path_in_type(
predicate_type.bounded_ty,
target_path,
replacement_path,
);
predicate_type.bounds = predicate_type
.bounds
.into_iter()
.map(|bound| {
replace_type_param_bound_paths(
bound,
target_path,
replacement_path,
)
})
.collect();
WherePredicate::Type(predicate_type)
}
predicate => predicate,
})
.collect();
where_clause
});
generics
}
fn replace_path_in_type(ty: Type, target_path: &Path, replacement_path: &Path) -> Type
{
match ty {
Type::Ptr(mut type_ptr) => {
type_ptr.elem = Box::new(replace_path_in_type(
*type_ptr.elem,
target_path,
replacement_path,
));
Type::Ptr(type_ptr)
}
Type::Path(mut type_path) => {
if &type_path.path == target_path {
type_path.path = replacement_path.clone();
} else {
type_path.path =
replace_path_args(type_path.path, target_path, replacement_path);
}
Type::Path(type_path)
}
Type::Array(mut type_array) => {
type_array.elem = Box::new(replace_path_in_type(
*type_array.elem,
target_path,
replacement_path,
));
Type::Array(type_array)
}
Type::Group(mut type_group) => {
type_group.elem = Box::new(replace_path_in_type(
*type_group.elem,
target_path,
replacement_path,
));
Type::Group(type_group)
}
Type::BareFn(type_bare_fn) => Type::BareFn(replace_type_bare_fn_type_paths(
type_bare_fn,
target_path,
replacement_path,
)),
Type::Paren(mut type_paren) => {
type_paren.elem = Box::new(replace_path_in_type(
*type_paren.elem,
target_path,
replacement_path,
));
Type::Paren(type_paren)
}
Type::Slice(mut type_slice) => {
type_slice.elem = Box::new(replace_path_in_type(
*type_slice.elem,
target_path,
replacement_path,
));
Type::Slice(type_slice)
}
Type::Tuple(mut type_tuple) => {
type_tuple.elems = type_tuple
.elems
.into_iter()
.map(|elem_type| {
replace_path_in_type(elem_type, target_path, replacement_path)
})
.collect();
Type::Tuple(type_tuple)
}
Type::Reference(mut type_reference) => {
type_reference.elem = Box::new(replace_path_in_type(
*type_reference.elem,
target_path,
replacement_path,
));
Type::Reference(type_reference)
}
Type::TraitObject(mut type_trait_object) => {
type_trait_object.bounds = type_trait_object
.bounds
.into_iter()
.map(|bound| match bound {
TypeParamBound::Trait(mut trait_bound) => {
trait_bound.path = replace_path_args(
trait_bound.path,
target_path,
replacement_path,
);
TypeParamBound::Trait(trait_bound)
}
TypeParamBound::Lifetime(lifetime) => {
TypeParamBound::Lifetime(lifetime)
}
})
.collect();
Type::TraitObject(type_trait_object)
}
other_type => other_type,
}
}
fn replace_path_args(mut path: Path, target_path: &Path, replacement_path: &Path)
-> Path
{
path.segments = path
.segments
.into_iter()
.map(|mut segment| {
segment.arguments = match segment.arguments {
PathArguments::AngleBracketed(mut generic_args) => {
generic_args.args = generic_args
.args
.into_iter()
.map(|generic_arg| match generic_arg {
GenericArgument::Type(ty) => GenericArgument::Type(
replace_path_in_type(ty, target_path, replacement_path),
),
GenericArgument::Binding(mut binding) => {
binding.ty = replace_path_in_type(
binding.ty,
target_path,
replacement_path,
);
GenericArgument::Binding(binding)
}
generic_arg => generic_arg,
})
.collect();
PathArguments::AngleBracketed(generic_args)
}
PathArguments::Parenthesized(mut generic_args) => {
generic_args.inputs = generic_args
.inputs
.into_iter()
.map(|input_ty| {
replace_path_in_type(input_ty, target_path, replacement_path)
})
.collect();
generic_args.output = match generic_args.output {
ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
r_arrow,
Box::new(replace_path_in_type(
*return_type,
target_path,
replacement_path,
)),
),
ReturnType::Default => ReturnType::Default,
};
PathArguments::Parenthesized(generic_args)
}
PathArguments::None => PathArguments::None,
};
segment
})
.collect();
path
}
fn replace_type_bare_fn_type_paths(
mut type_bare_fn: TypeBareFn,
target_path: &Path,
replacement_path: &Path,
) -> TypeBareFn
{
type_bare_fn.inputs = type_bare_fn
.inputs
.into_iter()
.map(|mut bare_fn_arg| {
bare_fn_arg.ty =
replace_path_in_type(bare_fn_arg.ty, target_path, replacement_path);
bare_fn_arg
})
.collect();
type_bare_fn.output = match type_bare_fn.output {
ReturnType::Type(r_arrow, return_type) => ReturnType::Type(
r_arrow,
Box::new(replace_path_in_type(
*return_type,
target_path,
replacement_path,
)),
),
ReturnType::Default => ReturnType::Default,
};
type_bare_fn
}
/// Replaces `target_path` with `replacement_path` in a single bound.
///
/// A trait bound whose path equals `target_path` is replaced as a whole. Any other trait
/// bound only has its path's generic arguments rewritten. Lifetime bounds are returned
/// unchanged.
fn replace_type_param_bound_paths(
    type_param_bound: TypeParamBound,
    target_path: &Path,
    replacement_path: &Path,
) -> TypeParamBound
{
    match type_param_bound {
        TypeParamBound::Lifetime(lifetime) => TypeParamBound::Lifetime(lifetime),
        TypeParamBound::Trait(mut trait_bound) => {
            let is_target = &trait_bound.path == target_path;

            trait_bound.path = if is_target {
                replacement_path.clone()
            } else {
                replace_path_args(trait_bound.path, target_path, replacement_path)
            };

            TypeParamBound::Trait(trait_bound)
        }
    }
}