use proc_macro2::{Group, Ident, TokenStream, TokenTree}; pub use self::index_path::IndexPath; #[allow(clippy::module_name_repetitions)] pub trait TokenStreamExt { fn find_all_ident(&self, target_ident: &Ident) -> Vec; fn replace_tokens( &self, index_paths_to_replace: &[IndexPath], substitution: &TokenTree, ) -> TokenStream; } impl TokenStreamExt for TokenStream { fn find_all_ident(&self, target_ident: &Ident) -> Vec { let mut found_indices = Vec::new(); recurse_find_all_ident( self.clone(), target_ident, &mut found_indices, &IndexPath::new(), ); found_indices } fn replace_tokens( &self, index_paths_to_replace: &[IndexPath], substitution: &TokenTree, ) -> TokenStream { self.clone() .into_iter() .enumerate() .map(|(index, mut token_tree)| { for index_path in index_paths_to_replace .iter() .filter(|path| path.indices()[0] == index) { token_tree = match token_tree { TokenTree::Ident(_) => substitution.clone(), TokenTree::Group(group) => TokenTree::Group(Group::new( group.delimiter(), recurse_replace_tokens( group.stream(), substitution, &index_path.indices()[1..], ), )), tt => tt, } } token_tree }) .collect() } } fn recurse_find_all_ident( input_stream: TokenStream, target_ident: &Ident, found_indices: &mut Vec, current_index_path: &IndexPath, ) { for (index, token_tree) in input_stream.into_iter().enumerate() { match token_tree { TokenTree::Ident(ident) if &ident == target_ident => { let mut index_path = current_index_path.clone(); index_path.push_index(index); found_indices.push(index_path); } TokenTree::Group(group) => { let mut index_path = current_index_path.clone(); index_path.push_index(index); recurse_find_all_ident( group.stream(), target_ident, found_indices, &index_path, ); } _ => {} } } } fn recurse_replace_tokens( token_stream: TokenStream, substitution: &TokenTree, indices: &[usize], ) -> TokenStream { token_stream .into_iter() .enumerate() .map(|(index, token_tree)| match token_tree { TokenTree::Ident(_) if index == indices[0] => substitution.clone(), TokenTree::Group(group) if index == indices[0] => { TokenTree::Group(Group::new( group.delimiter(), recurse_replace_tokens(group.stream(), substitution, &indices[1..]), )) } tt => tt, }) .collect() } mod index_path { #[derive(Debug, Clone, PartialEq, Eq)] pub struct IndexPath { indices: Vec, } impl IndexPath { pub fn new() -> Self { Self { indices: Vec::new(), } } pub fn push_index(&mut self, index: usize) { self.indices.push(index); } pub fn indices(&self) -> &[usize] { &self.indices } } impl From for IndexPath where IntoIter: IntoIterator, { fn from(value: IntoIter) -> Self { Self { indices: value.into_iter().collect(), } } } } #[cfg(test)] mod tests { use proc_macro2::Span; use quote::quote; use super::*; #[test] fn find_all_ident_works() { assert_eq!( quote! { let abc = xyz; } .find_all_ident(&Ident::new("xyz", Span::call_site())), vec![IndexPath::from([3])] ); assert_eq!( quote! { let abc = (xyz, "123"); } .find_all_ident(&Ident::new("xyz", Span::call_site())), vec![IndexPath::from([3, 0])] ); assert_eq!( quote! { return ("123", (yo, 180, xyz)); } .find_all_ident(&Ident::new("xyz", Span::call_site())), vec![IndexPath::from([1, 2, 4])] ); } #[test] fn find_all_ident_works_with_multiple() { assert_eq!( quote! { unsafe { functions::xyz = FunctionPtr::new_initialized( get_proc_addr(stringify!(xyz)) ); } } .find_all_ident(&Ident::new("xyz", Span::call_site()),), vec![IndexPath::from([1, 3]), IndexPath::from([1, 9, 1, 2, 0])] ); } #[test] fn recurse_replace_tokens_works() { assert_eq!( recurse_replace_tokens( quote! { let abc = xyz; }, &TokenTree::Ident(Ident::new("foo", Span::call_site())), &[3] ) .to_string(), quote! { let abc = foo; } .to_string() ); assert_eq!( recurse_replace_tokens( quote! { let abc = (xyz, "123"); }, &TokenTree::Ident(Ident::new("foo", Span::call_site())), &[3, 0] ) .to_string(), quote! { let abc = (foo, "123"); } .to_string() ); assert_eq!( recurse_replace_tokens( quote! { let abc = (hello, "123").iter_map(|_| xyz); }, &TokenTree::Ident(Ident::new("foo", Span::call_site())), &[6, 3] ) .to_string(), quote! { let abc = (hello, "123").iter_map(|_| foo); } .to_string() ); } }