From 1019924a29527eba2c8ec8bd976ece6ed76075b0 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sun, 25 Feb 2024 23:25:03 +0100 Subject: feat(ecs): add support for multiple system queries & local components --- Cargo.lock | 18 ++++ ecs/Cargo.toml | 6 ++ ecs/build.rs | 98 +++++++++++++++++ ecs/examples/multiple_queries.rs | 87 +++++++++++++++ ecs/examples/with_local.rs | 10 +- ecs/src/component.rs | 54 +++++++++- ecs/src/lib.rs | 73 ++++++++++++- ecs/src/system.rs | 226 +++++++++++++++++++++++++++++++++------ ecs/src/system/stateful.rs | 203 +++++++++++++++++++++++++---------- ecs/src/system/util.rs | 13 +++ 10 files changed, 690 insertions(+), 98 deletions(-) create mode 100644 ecs/build.rs create mode 100644 ecs/examples/multiple_queries.rs create mode 100644 ecs/src/system/util.rs diff --git a/Cargo.lock b/Cargo.lock index 884fa3f..0498fd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -122,9 +122,18 @@ dependencies = [ name = "ecs" version = "0.1.0" dependencies = [ + "itertools", + "proc-macro2", + "quote", "seq-macro", ] +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + [[package]] name = "engine" version = "0.1.0" @@ -216,6 +225,15 @@ dependencies = [ "png", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "khronos_api" version = "3.1.0" diff --git a/ecs/Cargo.toml b/ecs/Cargo.toml index cd25673..8cfdb71 100644 --- a/ecs/Cargo.toml +++ b/ecs/Cargo.toml @@ -5,3 +5,9 @@ edition = "2021" [dependencies] seq-macro = "0.3.5" + +[build-dependencies] +quote = "1.0.35" +proc-macro2 = "1.0.78" +itertools = "0.12.1" + diff --git a/ecs/build.rs b/ecs/build.rs new file mode 100644 index 0000000..1fbde2a --- /dev/null +++ b/ecs/build.rs @@ -0,0 +1,98 @@ +use std::collections::HashSet; +use std::path::PathBuf; + +use itertools::Itertools; +use proc_macro2::{Delimiter, Group, Ident}; +use quote::{format_ident, quote, ToTokens}; + +const CNT: usize = 4; + +fn main() +{ + let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); + + let impls = (0..CNT).flat_map(create_input_filter_impls).join("\n"); + + std::fs::write(out_dir.join("system_input_impls.rs"), impls).unwrap(); +} + +fn create_input_filter_impls(cnt: usize) -> Vec +{ + let mut present = HashSet::new(); + + let elements = (0..cnt) + .map(|_| ElementKind::Element) + .chain(vec![ElementKind::Excluded; cnt]) + .permutations(cnt) + .filter(|combination| { + if present.contains(combination) { + return false; + } + + present.insert(combination.clone()); + + true + }) + .map(|elements| { + elements + .into_iter() + .enumerate() + .map(|(index, element)| match element { + ElementKind::Element => { + IdentOrTuple::Ident(format_ident!("Elem{index}")) + } + ElementKind::Excluded => IdentOrTuple::Tuple, + }) + .collect::>() + }) + .collect::>(); + + elements + .into_iter() + .map(create_single_input_filter_impl) + .collect() +} + +fn create_single_input_filter_impl( + elements: Vec, +) -> proc_macro2::TokenStream +{ + let ident_elements = elements + .iter() + .filter(|element| matches!(element, IdentOrTuple::Ident(_))) + .collect::>(); + + quote! { + impl<#(#ident_elements: Input,)*> InputFilter for (#(#elements,)*) { + type Filtered = (#(#ident_elements,)*); + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum ElementKind +{ + Element, + Excluded, +} + +#[derive(Debug)] +enum IdentOrTuple +{ + Ident(Ident), + Tuple, +} + +impl ToTokens for IdentOrTuple +{ + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) + { + match self { + Self::Ident(ident) => ident.to_tokens(tokens), + Self::Tuple => { + Group::new(Delimiter::Parenthesis, proc_macro2::TokenStream::new()) + .to_tokens(tokens) + } + } + } +} diff --git a/ecs/examples/multiple_queries.rs b/ecs/examples/multiple_queries.rs new file mode 100644 index 0000000..a4a5d2d --- /dev/null +++ b/ecs/examples/multiple_queries.rs @@ -0,0 +1,87 @@ +use std::fmt::Display; + +use ecs::{Query, World}; + +struct Health +{ + health: u32, +} + +enum AttackStrength +{ + Strong, + Weak, +} + +struct EnemyName +{ + name: String, +} + +impl Display for EnemyName +{ + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + self.name.fmt(formatter) + } +} + +fn say_hello( + mut query: Query<(AttackStrength,)>, + mut enemy_query: Query<(Health, EnemyName)>, +) +{ + for (attack_strength,) in query.iter_mut() { + for (health, enemy_name) in enemy_query.iter_mut() { + let damage = match attack_strength { + AttackStrength::Strong => 20, + AttackStrength::Weak => 10, + }; + + if health.health <= damage { + println!("Enemy '{enemy_name}' died"); + + health.health = 0; + + continue; + } + + health.health -= damage; + + println!("Enemy '{enemy_name}' took {damage} damage"); + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +enum Event +{ + Start, +} + +fn main() +{ + let mut world = World::::new(); + + world.register_system(Event::Start, say_hello); + + world.create_entity(( + Health { health: 100 }, + EnemyName { name: "Big spider".to_string() }, + )); + + world.create_entity(( + Health { health: 30 }, + EnemyName { name: "Small goblin".to_string() }, + )); + + world.create_entity(( + Health { health: 30 }, + EnemyName { name: "Headcrab".to_string() }, + )); + + world.create_entity((AttackStrength::Strong,)); + world.create_entity((AttackStrength::Weak,)); + + world.emit(&Event::Start); +} diff --git a/ecs/examples/with_local.rs b/ecs/examples/with_local.rs index d7af0e0..0bd8f66 100644 --- a/ecs/examples/with_local.rs +++ b/ecs/examples/with_local.rs @@ -1,5 +1,5 @@ use ecs::component::Local; -use ecs::system::{Into, System}; +use ecs::system::{Input as SystemInput, Into, System}; use ecs::{Query, World}; struct SomeData @@ -17,6 +17,8 @@ struct SayHelloState cnt: usize, } +impl SystemInput for SayHelloState {} + fn say_hello(mut query: Query<(SomeData, String)>, mut state: Local) { for (data, text) in query.iter_mut() { @@ -50,14 +52,16 @@ fn main() world.register_system( Event::Update, - say_hello.into_system().initialize(SayHelloState { cnt: 0 }), + say_hello + .into_system() + .initialize((SayHelloState { cnt: 0 },)), ); world.register_system( Event::Update, say_whats_up .into_system() - .initialize(SayHelloState { cnt: 0 }), + .initialize((SayHelloState { cnt: 0 },)), ); world.create_entity(( diff --git a/ecs/src/component.rs b/ecs/src/component.rs index bead3b5..59b737e 100644 --- a/ecs/src/component.rs +++ b/ecs/src/component.rs @@ -4,6 +4,9 @@ use std::ops::{Deref, DerefMut}; use seq_macro::seq; +use crate::system::{Input as SystemInput, Param as SystemParam, System}; +use crate::ComponentStorage; + pub trait Component: Any { #[doc(hidden)] @@ -114,20 +117,65 @@ seq!(C in 0..=64 { /// Holds a component which is local to a single system. #[derive(Debug)] -pub struct Local<'world, Value> +pub struct Local<'world, Value: SystemInput> { value: &'world mut Value, } impl<'world, Value> Local<'world, Value> +where + Value: SystemInput, { - pub(crate) fn new(value: &'world mut Value) -> Self + fn new(value: &'world mut Value) -> Self { Self { value } } } +unsafe impl<'world, Value: 'static> SystemParam<'world> for Local<'world, Value> +where + Value: SystemInput, +{ + type Flags = (); + type Input = Value; + + fn initialize(system: &mut impl System, input: Self::Input) + { + system.set_local_component(input); + } + + fn new( + system: &'world mut impl System, + _component_storage: &'world mut ComponentStorage, + ) -> Self + { + let local_component = system + .get_local_component_mut::() + .expect("Local component is uninitialized"); + + Self::new(local_component) + } + + fn is_compatible>() -> bool + { + let other_comparable = Other::get_comparable(); + + let Some(other_type_id) = other_comparable.downcast_ref::() else { + return true; + }; + + TypeId::of::() != *other_type_id + } + + fn get_comparable() -> Box + { + Box::new(TypeId::of::()) + } +} + impl<'world, Value> Deref for Local<'world, Value> +where + Value: SystemInput, { type Target = Value; @@ -138,6 +186,8 @@ impl<'world, Value> Deref for Local<'world, Value> } impl<'world, Value> DerefMut for Local<'world, Value> +where + Value: SystemInput, { fn deref_mut(&mut self) -> &mut Self::Target { diff --git a/ecs/src/lib.rs b/ecs/src/lib.rs index df46a5a..84009e0 100644 --- a/ecs/src/lib.rs +++ b/ecs/src/lib.rs @@ -1,6 +1,6 @@ #![deny(clippy::all, clippy::pedantic)] -use std::any::TypeId; +use std::any::{Any, TypeId}; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::hash::Hash; @@ -8,7 +8,12 @@ use std::marker::PhantomData; use std::slice::IterMut as SliceIterMut; use crate::component::{Component, Sequence as ComponentSequence}; -use crate::system::{System, TypeErased as TypeErasedSystem}; +use crate::system::{ + NoInitParamFlag as NoInitSystemParamFlag, + Param as SystemParam, + System, + TypeErased as TypeErasedSystem, +}; pub mod component; pub mod system; @@ -142,6 +147,70 @@ where } } +unsafe impl<'world, Comps> SystemParam<'world> for Query<'world, Comps> +where + Comps: ComponentSequence, +{ + type Flags = NoInitSystemParamFlag; + type Input = (); + + fn initialize(_system: &mut impl System, _input: Self::Input) + { + } + + fn new( + _system: &'world mut impl System, + component_storage: &'world mut ComponentStorage, + ) -> Self + { + Self::new(component_storage) + } + + fn is_compatible>() -> bool + { + let other_comparable = Other::get_comparable(); + + let Some(other_query_component_ids) = + other_comparable.downcast_ref::() + else { + return true; + }; + + !other_query_component_ids.contains_component_in::() + } + + fn get_comparable() -> Box + { + Box::new(QueryComponentIds { + component_type_ids: Comps::type_ids(), + }) + } +} + +#[derive(Debug)] +struct QueryComponentIds +{ + component_type_ids: Vec, +} + +impl QueryComponentIds +{ + fn contains_component_in(&self) -> bool + where + OtherComps: ComponentSequence, + { + let other_component_type_ids = OtherComps::type_ids() + .into_iter() + .collect::>(); + + // TODO: Make this a bit smarter. Queries with a same component can be compatible + // if one of the queries have a component the other one does not have + self.component_type_ids + .iter() + .any(|component_type_id| other_component_type_ids.contains(component_type_id)) + } +} + pub struct QueryComponentIter<'world, Comps> { entity_iter: SliceIterMut<'world, Entity>, diff --git a/ecs/src/system.rs b/ecs/src/system.rs index 22e2215..d90e0a2 100644 --- a/ecs/src/system.rs +++ b/ecs/src/system.rs @@ -1,16 +1,21 @@ -use std::any::Any; +use std::any::{Any, TypeId}; use std::convert::Infallible; use std::fmt::Debug; +use std::mem::{transmute_copy, ManuallyDrop}; +use std::ptr::addr_of_mut; -use crate::component::Sequence as ComponentSequence; -use crate::{ComponentStorage, Query}; +use seq_macro::seq; + +use crate::component::Component; +use crate::system::util::check_params_are_compatible; +use crate::ComponentStorage; pub mod stateful; +mod util; + pub trait System: 'static { - type Query<'a>; - type Input; #[must_use] @@ -19,39 +24,90 @@ pub trait System: 'static fn run(&mut self, component_storage: &mut ComponentStorage); fn into_type_erased(self) -> TypeErased; -} -impl System)> for Func -where - Func: Fn(Query) + 'static, - Comps: ComponentSequence, -{ - type Input = Infallible; - type Query<'a> = Query<'a, Comps>; + fn get_local_component_mut( + &mut self, + ) -> Option<&mut LocalComponent>; - fn initialize(self, _input: Self::Input) -> Self - { - self - } - - fn run(&mut self, component_storage: &mut ComponentStorage) - { - self(Query::new(component_storage)); - } + fn set_local_component( + &mut self, + local_component: LocalComponent, + ); +} - fn into_type_erased(self) -> TypeErased - { - TypeErased { - data: Box::new(self), - func: Box::new(|data, component_storage| { - let me = data.downcast_mut::().unwrap(); - - me.run(component_storage); - }), - } - } +macro_rules! impl_system { + ($c: tt) => { + seq!(I in 0..=$c { + impl<'world, Func, #(TParam~I,)*> System + for Func + where + Func: Fn(#(TParam~I,)*) + Copy + 'static, + #(TParam~I: Param<'world, Flags = NoInitParamFlag>,)* + { + type Input = Infallible; + + fn initialize(self, _input: Self::Input) -> Self + { + self + } + + fn run(&mut self, component_storage: &mut ComponentStorage) + { + #( + check_params_are_compatible!(I, TParam~I, $c); + )* + + let func = *self; + + func(#({ + // SAFETY: All parameters are compatible so this is fine + let this = unsafe { + &mut *addr_of_mut!(*self) + }; + + // SAFETY: All parameters are compatible so this is fine + let component_storage = unsafe { + &mut *addr_of_mut!(*component_storage) + }; + + TParam~I::new(this, component_storage) + },)*); + } + + fn into_type_erased(self) -> TypeErased + { + TypeErased { + data: Box::new(self), + func: Box::new(|data, component_storage| { + let me = data.downcast_mut::().unwrap(); + + me.run(component_storage); + }), + } + } + + fn get_local_component_mut( + &mut self, + ) -> Option<&mut LocalComponent> + { + panic!("System does not have any local components"); + } + + fn set_local_component( + &mut self, + _local_component: LocalComponent, + ) { + panic!("System does not have any local components"); + } + } + }); + }; } +seq!(C in 0..=4 { + impl_system!(C); +}); + pub trait Into { type System; @@ -83,3 +139,107 @@ impl Debug for TypeErased /// Function in [`TypeErased`] used to run the system. type TypeErasedFunc = dyn Fn(&mut dyn Any, &mut ComponentStorage); + +/// A parameter to a [`System`]. +/// +/// # Safety +/// The `is_compatible` function is used for safety so it must be implemented properly. +pub unsafe trait Param<'world> +{ + type Input; + type Flags; + + fn initialize(system: &mut impl System, input: Self::Input); + + fn new( + system: &'world mut impl System, + component_storage: &'world mut ComponentStorage, + ) -> Self; + + fn is_compatible>() -> bool; + + fn get_comparable() -> Box; +} + +pub struct NoInitParamFlag {} + +/// A type which can be used as input to a [`System`]. +pub trait Input: 'static {} + +pub trait InputFilter +{ + type Filtered: FilteredInputs; +} + +pub trait FilteredInputs +{ + type InOptions: OptionInputs; + + fn into_in_options(self) -> Self::InOptions; +} + +macro_rules! impl_filtered_inputs { + ($cnt: tt) => { + seq!(I in 0..$cnt { + impl<#(Input~I: Input,)*> FilteredInputs for (#(Input~I,)*) { + type InOptions = (#(Option,)*); + + fn into_in_options(self) -> Self::InOptions { + #![allow(clippy::unused_unit)] + (#(Some(self.I),)*) + } + } + }); + }; +} + +seq!(N in 0..4 { + impl_filtered_inputs!(N); +}); + +pub trait OptionInputs +{ + fn take(&mut self) -> TakeOptionInputResult; +} + +macro_rules! impl_option_inputs { + ($cnt: tt) => { + seq!(I in 0..$cnt { + impl<#(Input~I: 'static,)*> OptionInputs for (#(Option,)*) { + fn take(&mut self) -> TakeOptionInputResult { + #( + if TypeId::of::() == TypeId::of::() { + let input = match self.I.take() { + Some(input) => ManuallyDrop::new(input), + None => { + return TakeOptionInputResult::AlreadyTaken; + } + }; + + return TakeOptionInputResult::Found( + // SAFETY: It can be transmuted safely since it is the + // same type and the type is 'static + unsafe { transmute_copy(&input) } + ); + } + )* + + TakeOptionInputResult::NotFound + } + } + }); + }; +} + +seq!(N in 0..4 { + impl_option_inputs!(N); +}); + +pub enum TakeOptionInputResult +{ + Found(Input), + NotFound, + AlreadyTaken, +} + +include!(concat!(env!("OUT_DIR"), "/system_input_impls.rs")); diff --git a/ecs/src/system/stateful.rs b/ecs/src/system/stateful.rs index b641cf2..9b2f279 100644 --- a/ecs/src/system/stateful.rs +++ b/ecs/src/system/stateful.rs @@ -1,67 +1,154 @@ -use std::marker::PhantomData; +use std::any::{type_name, TypeId}; +use std::collections::HashMap; +use std::ptr::addr_of_mut; -use crate::component::Local; -use crate::system::{System, TypeErased}; -use crate::{ComponentStorage, Query}; +use seq_macro::seq; + +use crate::component::Component; +use crate::system::util::check_params_are_compatible; +use crate::system::{ + FilteredInputs, + InputFilter, + Into as IntoSystem, + OptionInputs, + Param, + System, + TakeOptionInputResult, + TypeErased, +}; +use crate::ComponentStorage; /// A stateful system. -pub struct Stateful +pub struct Stateful { func: Func, - local_component: Option, - _comps_pd: PhantomData, + local_components: HashMap>, } -impl - System, Local)> - for Stateful -where - Func: Fn(Query, Local), -{ - type Input = LocalComponent; - type Query<'a> = Query<'a, Comps>; - - fn initialize(mut self, input: Self::Input) -> Self - { - self.local_component = Some(input); - - self - } - - fn run(&mut self, component_storage: &mut ComponentStorage) - { - (self.func)( - Query::new(component_storage), - Local::new(self.local_component.as_mut().unwrap()), - ); - } - - fn into_type_erased(self) -> TypeErased - { - TypeErased { - data: Box::new(self), - func: Box::new(move |data, component_storage| { - let this = data.downcast_mut::().unwrap(); - - this.run(component_storage); - }), - } - } -} +macro_rules! impl_system { + ($c: tt) => { + seq!(I in 0..=$c { + impl<'world, Func, #(TParam~I,)*> System + for Stateful + where + Func: Fn(#(TParam~I,)*) + Copy + 'static, + #(TParam~I: Param<'world>,)* + #(TParam~I::Input: 'static,)* + (#(TParam~I::Input,)*): InputFilter + { + type Input = <(#(TParam~I::Input,)*) as InputFilter>::Filtered; -impl - crate::system::Into, Local)> for Func -where - Func: Fn(Query, Local), -{ - type System = Stateful; - - fn into_system(self) -> Self::System - { - Self::System { - func: self, - local_component: None, - _comps_pd: PhantomData, - } - } + fn initialize(mut self, input: Self::Input) -> Self + { + let mut option_input = input.into_in_options(); + + #( + if TypeId::of::() != TypeId::of::<()>() { + let input = match option_input.take::() { + TakeOptionInputResult::Found(input) => input, + TakeOptionInputResult::NotFound => { + panic!( + "Parameter input {} not found", + type_name::() + ); + } + TakeOptionInputResult::AlreadyTaken => { + panic!( + concat!( + "Parameter {} is already initialized. ", + "System cannot contain multiple inputs with ", + "the same type", + ), + type_name::() + ); + + } + }; + + TParam~I::initialize( + &mut self, + input + ); + } + )* + + self + } + + fn run(&mut self, component_storage: &mut ComponentStorage) + { + #( + check_params_are_compatible!(I, TParam~I, $c); + )* + + let func = self.func; + + func(#({ + // SAFETY: All parameters are compatible so this is fine + let this = unsafe { + &mut *addr_of_mut!(*self) + }; + + // SAFETY: All parameters are compatible so this is fine + let component_storage = unsafe { + &mut *addr_of_mut!(*component_storage) + }; + + TParam~I::new(this, component_storage) + },)*); + } + + fn into_type_erased(self) -> TypeErased + { + TypeErased { + data: Box::new(self), + func: Box::new(|data, component_storage| { + let me = data.downcast_mut::().unwrap(); + + me.run(component_storage); + }), + } + } + + fn get_local_component_mut( + &mut self, + ) -> Option<&mut LocalComponent> + { + self.local_components + .get_mut(&TypeId::of::())? + .downcast_mut() + } + + fn set_local_component( + &mut self, + local_component: LocalComponent, + ) + { + self.local_components + .insert(TypeId::of::(), + Box::new(local_component)); + } + } + + impl IntoSystem + for Func + where + Func: Fn(#(TParam~I,)*) + Copy + 'static, + { + type System = Stateful; + + fn into_system(self) -> Self::System + { + Self::System { + func: self, + local_components: HashMap::new(), + } + } + } + }); + }; } + +seq!(C in 0..4 { + impl_system!(C); +}); diff --git a/ecs/src/system/util.rs b/ecs/src/system/util.rs new file mode 100644 index 0000000..9d04f1d --- /dev/null +++ b/ecs/src/system/util.rs @@ -0,0 +1,13 @@ +macro_rules! check_params_are_compatible { + ($excluded_index: tt, $param: ident, $cnt: tt) => { + seq!(N in 0..=$cnt { + if N != $excluded_index { + if !$param::is_compatible::() { + panic!("Atleast two parameters are incompatible"); + } + } + }) + }; +} + +pub(crate) use check_params_are_compatible; -- cgit v1.2.3-18-g5258