From 9863b431950c681225f8774af244a56adbd18937 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sun, 19 May 2024 21:12:07 +0200 Subject: feat(ecs): add support for optional query components --- ecs/examples/optional_component.rs | 86 ++++++++++++++++++++++++++++++ ecs/src/component.rs | 106 +++++++++++++++++++++++++++++++++---- ecs/src/lib.rs | 18 +++++-- ecs/src/query.rs | 14 +++-- ecs/src/system.rs | 36 ++++++++++++- 5 files changed, 239 insertions(+), 21 deletions(-) create mode 100644 ecs/examples/optional_component.rs (limited to 'ecs') diff --git a/ecs/examples/optional_component.rs b/ecs/examples/optional_component.rs new file mode 100644 index 0000000..e47bf2e --- /dev/null +++ b/ecs/examples/optional_component.rs @@ -0,0 +1,86 @@ +use ecs::event::Event; +use ecs::{Component, Query, World}; + +#[derive(Debug, Component)] +struct PettingCapacity +{ + capacity_left: u32, +} + +#[derive(Debug, Clone, Copy, Component)] +enum Aggressivity +{ + High, + Medium, + Low, +} + +#[derive(Debug, Component)] +pub struct CatName +{ + name: String, +} + +fn pet_cats(query: Query<(CatName, PettingCapacity, Option)>) +{ + for (cat_name, mut petting_capacity, aggressivity) in &query { + let Some(aggressivity) = aggressivity else { + println!("Aggressivity of cat {} is unknown. Skipping", cat_name.name); + continue; + }; + + if let Aggressivity::High = *aggressivity { + println!("Cat {} is aggressive. Skipping", cat_name.name); + continue; + } + + if petting_capacity.capacity_left == 0 { + println!( + "Cat {} have had enough of being petted. Skipping", + cat_name.name + ); + continue; + } + + println!("Petting cat {}", cat_name.name); + + petting_capacity.capacity_left -= 1; + } +} + +#[derive(Debug)] +struct PettingTime; + +impl Event for PettingTime {} + +fn main() +{ + let mut world = World::new(); + + world.register_system(PettingTime, pet_cats); + + world.create_entity(( + CatName { name: "Jasper".to_string() }, + Aggressivity::Medium, + PettingCapacity { capacity_left: 5 }, + )); + + world.create_entity(( + CatName { name: "Otto".to_string() }, + PettingCapacity { capacity_left: 9 }, + )); + + world.create_entity(( + CatName { name: "Carrie".to_string() }, + PettingCapacity { capacity_left: 2 }, + Aggressivity::High, + )); + + world.create_entity(( + CatName { name: "Tommy".to_string() }, + PettingCapacity { capacity_left: 1 }, + Aggressivity::Low, + )); + + world.emit(PettingTime); +} diff --git a/ecs/src/component.rs b/ecs/src/component.rs index 7a61f39..604f54d 100644 --- a/ecs/src/component.rs +++ b/ecs/src/component.rs @@ -1,4 +1,4 @@ -use std::any::{Any, TypeId}; +use std::any::{type_name, Any, TypeId}; use std::fmt::Debug; use seq_macro::seq; @@ -11,6 +11,15 @@ pub mod local; pub trait Component: SystemInput + Any + TypeName { + /// The component type in question. Will usually be `Self` + type Component + where + Self: Sized; + + type RefMut<'component> + where + Self: Sized; + fn drop_last(&self) -> bool; #[doc(hidden)] @@ -54,6 +63,43 @@ impl TypeName for Box } } +impl Component for Option +where + ComponentT: Component, +{ + type Component = ComponentT; + type RefMut<'component> = Option>; + + fn drop_last(&self) -> bool + { + self.as_ref() + .map(|component| component.drop_last()) + .unwrap_or_default() + } + + fn as_any_mut(&mut self) -> &mut dyn Any + { + self + } + + fn as_any(&self) -> &dyn Any + { + self + } +} + +impl TypeName for Option +where + ComponentT: Component, +{ + fn type_name(&self) -> &'static str + { + type_name::() + } +} + +impl SystemInput for Option where ComponentT: Component {} + /// A sequence of components. pub trait Sequence { @@ -63,18 +109,60 @@ pub trait Sequence fn into_vec(self) -> Vec>; - fn type_ids() -> Vec; + fn type_ids() -> Vec<(TypeId, IsOptional)>; fn from_components<'component>( components: impl Iterator>>, ) -> Self::Refs<'component>; } +/// Whether or not a `Component` is optional. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IsOptional +{ + Yes, + No, +} + +impl From for IsOptional +{ + fn from(is_optional: bool) -> Self + { + if is_optional { + return IsOptional::Yes; + } + + IsOptional::No + } +} + +/// Returns whether the given component type is a optional component. +/// +/// Will return `true` if the component is a [`Option`]. +pub fn is_optional() -> bool +{ + if TypeId::of::() == TypeId::of::>() { + return true; + } + + false +} + +pub trait FromOptionalComponent<'comp> +{ + fn from_optional_component( + optional_component: Option>>, + ) -> Self; +} + macro_rules! inner { ($c: tt) => { seq!(I in 0..=$c { - impl<#(Comp~I: Component,)*> Sequence for (#(Comp~I,)*) { - type Refs<'component> = (#(ComponentRefMut<'component, Comp~I>,)*) + impl<#(Comp~I: Component,)*> Sequence for (#(Comp~I,)*) + where + #(for<'comp> Comp~I::RefMut<'comp>: FromOptionalComponent<'comp>,)* + { + type Refs<'component> = (#(Comp~I::RefMut<'component>,)*) where Self: 'component; fn into_vec(self) -> Vec> @@ -82,11 +170,11 @@ macro_rules! inner { Vec::from_iter([#(Box::new(self.I) as Box,)*]) } - fn type_ids() -> Vec + fn type_ids() -> Vec<(TypeId, IsOptional)> { vec![ #( - TypeId::of::(), + (TypeId::of::(), is_optional::().into()), )* ] } @@ -105,7 +193,7 @@ macro_rules! inner { .expect("Failed to acquire read-write component lock"); #( - if comp_ref.is::() { + if comp_ref.is::() { comp_~I = Some(comp_ref); continue; } @@ -113,9 +201,7 @@ macro_rules! inner { } (#( - ComponentRefMut::new( - comp_~I.unwrap(), - ), + Comp~I::RefMut::from_optional_component(comp_~I), )*) } } diff --git a/ecs/src/lib.rs b/ecs/src/lib.rs index 009ff21..c2fc9c7 100644 --- a/ecs/src/lib.rs +++ b/ecs/src/lib.rs @@ -10,7 +10,11 @@ use std::sync::Arc; use std::vec::Drain; use crate::actions::Action; -use crate::component::{Component, Sequence as ComponentSequence}; +use crate::component::{ + Component, + IsOptional as ComponentIsOptional, + Sequence as ComponentSequence, +}; use crate::event::{Event, Id as EventId, Ids, Sequence as EventSequence}; use crate::extension::{Collector as ExtensionCollector, Extension}; use crate::lock::Lock; @@ -281,7 +285,7 @@ impl ComponentStorage fn find_entity_with_components( &self, start_index: usize, - component_type_ids: &[TypeId], + component_type_ids: &[(TypeId, ComponentIsOptional)], ) -> Option<(usize, &Entity)> { // TODO: This is a really dumb and slow way to do this. Refactor the world @@ -294,9 +298,13 @@ impl ComponentStorage .map(|component| component.id) .collect::>(); - if component_type_ids.iter().all(|component_type_id| { - entity_components.contains(component_type_id) - }) { + if component_type_ids + .iter() + .filter(|(_, is_optional)| *is_optional == ComponentIsOptional::No) + .all(|(component_type_id, _)| { + entity_components.contains(component_type_id) + }) + { return true; } diff --git a/ecs/src/query.rs b/ecs/src/query.rs index 683dde7..61e2797 100644 --- a/ecs/src/query.rs +++ b/ecs/src/query.rs @@ -3,7 +3,10 @@ use std::collections::HashSet; use std::marker::PhantomData; use std::sync::{Arc, Weak}; -use crate::component::Sequence as ComponentSequence; +use crate::component::{ + IsOptional as ComponentIsOptional, + Sequence as ComponentSequence, +}; use crate::lock::{Lock, ReadGuard}; use crate::system::{ NoInitParamFlag as NoInitSystemParamFlag, @@ -184,7 +187,7 @@ pub struct ComponentIter<'world, Comps> { component_storage: &'world ComponentStorage, current_entity_index: usize, - component_type_ids: Vec, + component_type_ids: Vec<(TypeId, ComponentIsOptional)>, comps_pd: PhantomData, } @@ -216,7 +219,7 @@ where #[derive(Debug)] struct QueryComponentIds { - component_type_ids: Vec, + component_type_ids: Vec<(TypeId, ComponentIsOptional)>, } impl QueryComponentIds @@ -227,10 +230,13 @@ impl QueryComponentIds { let other_component_type_ids = OtherComps::type_ids() .into_iter() + .map(|(type_id, _)| type_id) .collect::>(); self.component_type_ids .iter() - .all(|component_type_id| other_component_type_ids.contains(component_type_id)) + .all(|(component_type_id, _)| { + other_component_type_ids.contains(component_type_id) + }) } } diff --git a/ecs/src/system.rs b/ecs/src/system.rs index 3440a57..3c44148 100644 --- a/ecs/src/system.rs +++ b/ecs/src/system.rs @@ -1,4 +1,4 @@ -use std::any::Any; +use std::any::{type_name, Any}; use std::convert::Infallible; use std::fmt::Debug; use std::marker::PhantomData; @@ -8,7 +8,7 @@ use std::ptr::addr_of; use seq_macro::seq; -use crate::component::Component; +use crate::component::{Component, FromOptionalComponent}; use crate::lock::WriteGuard; use crate::system::util::check_params_are_compatible; use crate::tuple::{FilterElement as TupleFilterElement, With as TupleWith}; @@ -208,6 +208,38 @@ impl<'a, ComponentT: Component> ComponentRefMut<'a, ComponentT> } } +impl<'component, ComponentT: Component> FromOptionalComponent<'component> + for ComponentRefMut<'component, ComponentT> +{ + fn from_optional_component( + inner: Option>>, + ) -> Self + { + Self { + inner: inner.unwrap_or_else(|| { + panic!( + "Component {} was not found in entity", + type_name::() + ); + }), + _ph: PhantomData, + } + } +} + +impl<'comp, ComponentT> FromOptionalComponent<'comp> + for Option> +where + ComponentT: Component, +{ + fn from_optional_component( + optional_component: Option>>, + ) -> Self + { + optional_component.map(|component| ComponentRefMut::new(component)) + } +} + impl<'a, ComponentT: Component> Deref for ComponentRefMut<'a, ComponentT> { type Target = ComponentT; -- cgit v1.2.3-18-g5258