//! Macros for Ridicule, a mocking library supporting non-static generics. #![deny(clippy::all, clippy::pedantic, missing_docs)] use proc_macro::TokenStream; use proc_macro2::Ident; use proc_macro_error::{proc_macro_error, ResultExt}; use quote::{format_ident, quote}; use syn::token::Brace; use syn::{ parse, Block, FnArg, GenericArgument, GenericParam, Generics, ImplItem, ImplItemMethod, ItemTrait, Path, PathArguments, PathSegment, ReturnType, TraitItem, Type, TypeBareFn, TypeParamBound, Visibility, WherePredicate, }; use crate::expectation::Expectation; use crate::mock::Mock; use crate::mock_input::MockInput; use crate::syn_ext::{PathExt, PathSegmentExt, WithLeadingColons}; use crate::util::create_path; mod expectation; mod mock; mod mock_input; mod syn_ext; mod util; /// Creates a mock. /// /// # Examples /// ``` /// use ridicule::mock; /// /// trait Foo /// { /// fn bar(&self, a: A) -> B; /// } /// /// mock! { /// MockFoo {} /// /// impl Foo for MockFoo /// { /// fn bar(&self, a: A) -> B; /// } /// } /// /// fn main() /// { /// let mut mock_foo = MockFoo::new(); /// /// mock_foo /// .expect_bar() /// .returning(|foo, a: u32| format!("Hello {a}")); /// /// assert_eq!(mock_foo.bar::(123), "Hello 123"); /// } /// ``` #[proc_macro] #[proc_macro_error] pub fn mock(input_stream: TokenStream) -> TokenStream { let input = parse::(input_stream.clone()).unwrap_or_abort(); let mock_ident = input.mock; let mock_mod_ident = format_ident!("__{mock_ident}"); let method_items = get_type_replaced_impl_item_methods(input.item_impl.items, &mock_ident); let mock = Mock::new( mock_ident.clone(), input.mocked_trait, &method_items, input.item_impl.generics.clone(), ); let expectations = method_items.iter().map(|item_method| { Expectation::new( &mock_ident, item_method, input.item_impl.generics.params.clone(), ) }); quote! { mod #mock_mod_ident { use super::*; #mock #(#expectations)* } use #mock_mod_ident::#mock_ident; } .into() } /// Creates a mock automatically. #[proc_macro_attribute] #[proc_macro_error] pub fn automock(_: TokenStream, input_stream: TokenStream) -> TokenStream { let item_trait = parse::(input_stream).unwrap_or_abort(); let mock_ident = format_ident!("Mock{}", item_trait.ident); let mock_mod_ident = format_ident!("__{mock_ident}"); let method_items = get_type_replaced_impl_item_methods( item_trait.items.iter().filter_map(|item| match item { TraitItem::Method(item_method) => Some(ImplItem::Method(ImplItemMethod { attrs: item_method.attrs.clone(), vis: Visibility::Inherited, defaultness: None, sig: item_method.sig.clone(), block: Block { brace_token: Brace::default(), stmts: vec![], }, })), _ => None, }), &mock_ident, ); let mock = Mock::new( mock_ident.clone(), Path::new( WithLeadingColons::No, [PathSegment::new(item_trait.ident.clone(), None)], ), &method_items, item_trait.generics.clone(), ); let expectations = method_items.iter().map(|item_method| { Expectation::new(&mock_ident, item_method, item_trait.generics.params.clone()) }); let visibility = &item_trait.vis; quote! { #item_trait mod #mock_mod_ident { use super::*; #mock #(#expectations)* } #visibility use #mock_mod_ident::#mock_ident; } .into() } fn get_type_replaced_impl_item_methods( impl_items: impl IntoIterator, mock_ident: &Ident, ) -> Vec { let target_path = create_path!(Self); let replacement_path = Path::new( WithLeadingColons::No, [PathSegment::new(mock_ident.clone(), None)], ); impl_items .into_iter() .filter_map(|item| match item { ImplItem::Method(mut item_method) => { item_method.sig.inputs = item_method .sig .inputs .into_iter() .map(|fn_arg| match fn_arg { FnArg::Typed(mut typed_arg) => { typed_arg.ty = Box::new(replace_path_in_type( *typed_arg.ty, &target_path, &replacement_path, )); FnArg::Typed(typed_arg) } FnArg::Receiver(receiver) => FnArg::Receiver(receiver), }) .collect(); item_method.sig.output = match item_method.sig.output { ReturnType::Type(r_arrow, return_type) => ReturnType::Type( r_arrow, Box::new(replace_path_in_type( *return_type, &target_path, &replacement_path, )), ), ReturnType::Default => ReturnType::Default, }; item_method.sig.generics = replace_path_in_generics( item_method.sig.generics, &target_path, &replacement_path, ); Some(item_method) } _ => None, }) .collect() } fn replace_path_in_generics( mut generics: Generics, target_path: &Path, replacement_path: &Path, ) -> Generics { generics.params = generics .params .into_iter() .map(|generic_param| match generic_param { GenericParam::Type(mut type_param) => { type_param.bounds = type_param .bounds .into_iter() .map(|bound| { replace_type_param_bound_paths( bound, target_path, replacement_path, ) }) .collect(); GenericParam::Type(type_param) } generic_param => generic_param, }) .collect(); generics.where_clause = generics.where_clause.map(|mut where_clause| { where_clause.predicates = where_clause .predicates .into_iter() .map(|predicate| match predicate { WherePredicate::Type(mut predicate_type) => { predicate_type.bounded_ty = replace_path_in_type( predicate_type.bounded_ty, target_path, replacement_path, ); predicate_type.bounds = predicate_type .bounds .into_iter() .map(|bound| { replace_type_param_bound_paths( bound, target_path, replacement_path, ) }) .collect(); WherePredicate::Type(predicate_type) } predicate => predicate, }) .collect(); where_clause }); generics } fn replace_path_in_type(ty: Type, target_path: &Path, replacement_path: &Path) -> Type { match ty { Type::Ptr(mut type_ptr) => { type_ptr.elem = Box::new(replace_path_in_type( *type_ptr.elem, target_path, replacement_path, )); Type::Ptr(type_ptr) } Type::Path(mut type_path) => { if &type_path.path == target_path { type_path.path = replacement_path.clone(); } else { type_path.path = replace_path_args(type_path.path, target_path, replacement_path); } Type::Path(type_path) } Type::Array(mut type_array) => { type_array.elem = Box::new(replace_path_in_type( *type_array.elem, target_path, replacement_path, )); Type::Array(type_array) } Type::Group(mut type_group) => { type_group.elem = Box::new(replace_path_in_type( *type_group.elem, target_path, replacement_path, )); Type::Group(type_group) } Type::BareFn(type_bare_fn) => Type::BareFn(replace_type_bare_fn_type_paths( type_bare_fn, target_path, replacement_path, )), Type::Paren(mut type_paren) => { type_paren.elem = Box::new(replace_path_in_type( *type_paren.elem, target_path, replacement_path, )); Type::Paren(type_paren) } Type::Slice(mut type_slice) => { type_slice.elem = Box::new(replace_path_in_type( *type_slice.elem, target_path, replacement_path, )); Type::Slice(type_slice) } Type::Tuple(mut type_tuple) => { type_tuple.elems = type_tuple .elems .into_iter() .map(|elem_type| { replace_path_in_type(elem_type, target_path, replacement_path) }) .collect(); Type::Tuple(type_tuple) } Type::Reference(mut type_reference) => { type_reference.elem = Box::new(replace_path_in_type( *type_reference.elem, target_path, replacement_path, )); Type::Reference(type_reference) } Type::TraitObject(mut type_trait_object) => { type_trait_object.bounds = type_trait_object .bounds .into_iter() .map(|bound| match bound { TypeParamBound::Trait(mut trait_bound) => { trait_bound.path = replace_path_args( trait_bound.path, target_path, replacement_path, ); TypeParamBound::Trait(trait_bound) } TypeParamBound::Lifetime(lifetime) => { TypeParamBound::Lifetime(lifetime) } }) .collect(); Type::TraitObject(type_trait_object) } other_type => other_type, } } fn replace_path_args(mut path: Path, target_path: &Path, replacement_path: &Path) -> Path { path.segments = path .segments .into_iter() .map(|mut segment| { segment.arguments = match segment.arguments { PathArguments::AngleBracketed(mut generic_args) => { generic_args.args = generic_args .args .into_iter() .map(|generic_arg| match generic_arg { GenericArgument::Type(ty) => GenericArgument::Type( replace_path_in_type(ty, target_path, replacement_path), ), GenericArgument::Binding(mut binding) => { binding.ty = replace_path_in_type( binding.ty, target_path, replacement_path, ); GenericArgument::Binding(binding) } generic_arg => generic_arg, }) .collect(); PathArguments::AngleBracketed(generic_args) } PathArguments::Parenthesized(mut generic_args) => { generic_args.inputs = generic_args .inputs .into_iter() .map(|input_ty| { replace_path_in_type(input_ty, target_path, replacement_path) }) .collect(); generic_args.output = match generic_args.output { ReturnType::Type(r_arrow, return_type) => ReturnType::Type( r_arrow, Box::new(replace_path_in_type( *return_type, target_path, replacement_path, )), ), ReturnType::Default => ReturnType::Default, }; PathArguments::Parenthesized(generic_args) } PathArguments::None => PathArguments::None, }; segment }) .collect(); path } fn replace_type_bare_fn_type_paths( mut type_bare_fn: TypeBareFn, target_path: &Path, replacement_path: &Path, ) -> TypeBareFn { type_bare_fn.inputs = type_bare_fn .inputs .into_iter() .map(|mut bare_fn_arg| { bare_fn_arg.ty = replace_path_in_type(bare_fn_arg.ty, target_path, replacement_path); bare_fn_arg }) .collect(); type_bare_fn.output = match type_bare_fn.output { ReturnType::Type(r_arrow, return_type) => ReturnType::Type( r_arrow, Box::new(replace_path_in_type( *return_type, target_path, replacement_path, )), ), ReturnType::Default => ReturnType::Default, }; type_bare_fn } fn replace_type_param_bound_paths( type_param_bound: TypeParamBound, target_path: &Path, replacement_path: &Path, ) -> TypeParamBound { match type_param_bound { TypeParamBound::Trait(mut trait_bound) => { if &trait_bound.path == target_path { trait_bound.path = replacement_path.clone(); } else { trait_bound.path = replace_path_args(trait_bound.path, target_path, replacement_path); } TypeParamBound::Trait(trait_bound) } TypeParamBound::Lifetime(lifetime) => TypeParamBound::Lifetime(lifetime), } }