From 9b6611cd11199346cbe1f14ad44930347f90dec2 Mon Sep 17 00:00:00 2001 From: HampusM Date: Sat, 22 Jun 2024 16:15:12 +0200 Subject: feat(ecs): add query options filter entities --- ecs/src/lib.rs | 4 ++- ecs/src/query.rs | 52 ++++++++++++++++++++++++-------------- ecs/src/query/options.rs | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 20 deletions(-) create mode 100644 ecs/src/query/options.rs diff --git a/ecs/src/lib.rs b/ecs/src/lib.rs index b883171..1f745b9 100644 --- a/ecs/src/lib.rs +++ b/ecs/src/lib.rs @@ -15,6 +15,7 @@ use crate::component::{Component, Id as ComponentId, Sequence as ComponentSequen use crate::event::{Event, Id as EventId, Ids, Sequence as EventSequence}; use crate::extension::{Collector as ExtensionCollector, Extension}; use crate::lock::Lock; +use crate::query::options::Options as QueryOptions; use crate::sole::Sole; use crate::system::{System, TypeErased as TypeErasedSystem}; use crate::type_name::TypeName; @@ -121,9 +122,10 @@ impl World drop(event); } - pub fn query(&self) -> Query + pub fn query(&self) -> Query where Comps: ComponentSequence, + OptionsT: QueryOptions, { Query::new(&self.data.component_storage) } diff --git a/ecs/src/query.rs b/ecs/src/query.rs index dc6b5f0..8a63256 100644 --- a/ecs/src/query.rs +++ b/ecs/src/query.rs @@ -1,6 +1,6 @@ use std::any::Any; use std::collections::HashSet; -use std::iter::{Flatten, Map}; +use std::iter::{Filter, Flatten, Map}; use std::marker::PhantomData; use std::sync::{Arc, Weak}; @@ -15,6 +15,7 @@ use crate::component::{ Sequence as ComponentSequence, }; use crate::lock::{Lock, ReadGuard}; +use crate::query::options::Options; use crate::system::{ NoInitParamFlag as NoInitSystemParamFlag, Param as SystemParam, @@ -22,19 +23,22 @@ use crate::system::{ }; use crate::{EntityComponent, WorldData}; +pub mod options; + #[derive(Debug)] -pub struct Query<'world, Comps> +pub struct Query<'world, Comps, OptionsT = ()> where Comps: ComponentSequence, { component_storage: ReadGuard<'world, ComponentStorage>, component_storage_lock: Weak>, - comps_pd: PhantomData, + _pd: PhantomData<(Comps, OptionsT)>, } -impl<'world, Comps> Query<'world, Comps> +impl<'world, Comps, OptionsT> Query<'world, Comps, OptionsT> where Comps: ComponentSequence, + OptionsT: Options, { /// Iterates over the entities matching this query. #[must_use] @@ -51,14 +55,15 @@ where .component_storage .find_entities(&Comps::ids()) .map((|archetype| archetype.components.as_slice()) as ComponentIterMapFn) - .flatten(), + .flatten() + .filter(|components| OptionsT::entity_filter(*components)), comps_pd: PhantomData, } } /// Returns a weak reference query to the same components. #[must_use] - pub fn to_weak_ref(&self) -> WeakRef + pub fn to_weak_ref(&self) -> WeakRef { WeakRef { component_storage: self.component_storage_lock.clone(), @@ -73,14 +78,15 @@ where .read_nonblock() .expect("Failed to acquire read-only component storage lock"), component_storage_lock: Arc::downgrade(component_storage), - comps_pd: PhantomData, + _pd: PhantomData, } } } -impl<'world, Comps> IntoIterator for &'world Query<'world, Comps> +impl<'world, Comps, OptionsT> IntoIterator for &'world Query<'world, Comps, OptionsT> where Comps: ComponentSequence, + OptionsT: Options, { type IntoIter = ComponentIter<'world, Comps>; type Item = Comps::Refs<'world>; @@ -91,9 +97,11 @@ where } } -unsafe impl<'world, Comps> SystemParam<'world> for Query<'world, Comps> +unsafe impl<'world, Comps, OptionsT> SystemParam<'world> + for Query<'world, Comps, OptionsT> where Comps: ComponentSequence, + OptionsT: Options, { type Flags = NoInitSystemParamFlag; type Input = (); @@ -161,15 +169,15 @@ where /// A entity query containing a weak reference to the world. #[derive(Debug)] -pub struct WeakRef +pub struct WeakRef where Comps: ComponentSequence, { component_storage: Weak>, - comps_pd: PhantomData, + comps_pd: PhantomData<(Comps, OptionsT)>, } -impl WeakRef +impl WeakRef where Comps: ComponentSequence, { @@ -177,7 +185,7 @@ where /// /// Returns [`None`] if the [`World`] has been dropped. #[must_use] - pub fn access(&self) -> Option> + pub fn access(&self) -> Option> { Some(Ref { component_storage: self.component_storage.upgrade()?, @@ -186,7 +194,7 @@ where } } -impl Clone for WeakRef +impl Clone for WeakRef where Comps: ComponentSequence, { @@ -202,20 +210,21 @@ where /// Intermediate between [`Query`] and [`WeakRefQuery`]. Contains a strong reference to /// the world which is not allowed direct access to. #[derive(Debug, Clone)] -pub struct Ref<'weak_ref, Comps> +pub struct Ref<'weak_ref, Comps, OptionsT> where Comps: ComponentSequence, { component_storage: Arc>, - _pd: PhantomData<&'weak_ref Comps>, + _pd: PhantomData<(&'weak_ref Comps, OptionsT)>, } -impl<'weak_ref, Comps> Ref<'weak_ref, Comps> +impl<'weak_ref, Comps, OptionsT> Ref<'weak_ref, Comps, OptionsT> where Comps: ComponentSequence, + OptionsT: Options, { #[must_use] - pub fn to_query(&self) -> Query<'_, Comps> + pub fn to_query(&self) -> Query<'_, Comps, OptionsT> { Query::new(&self.component_storage) } @@ -223,9 +232,14 @@ where type ComponentIterMapFn = for<'a> fn(&'a Archetype) -> &'a [Vec]; +type ComponentIterFilterFn = for<'a, 'b> fn(&'a &'b Vec) -> bool; + pub struct ComponentIter<'world, Comps> { - entities: Flatten, ComponentIterMapFn>>, + entities: Filter< + Flatten, ComponentIterMapFn>>, + ComponentIterFilterFn, + >, comps_pd: PhantomData, } diff --git a/ecs/src/query/options.rs b/ecs/src/query/options.rs new file mode 100644 index 0000000..d895073 --- /dev/null +++ b/ecs/src/query/options.rs @@ -0,0 +1,66 @@ +use std::collections::HashSet; +use std::marker::PhantomData; + +use crate::component::{Component, Id as ComponentId}; +use crate::EntityComponent; + +/// Query options. +pub trait Options +{ + fn entity_filter<'component>( + components: impl IntoIterator, + ) -> bool; +} + +impl Options for () +{ + fn entity_filter<'component>( + _: impl IntoIterator, + ) -> bool + { + true + } +} + +pub struct With +where + ComponentT: Component, +{ + _pd: PhantomData, +} + +impl Options for With +where + ComponentT: Component, +{ + fn entity_filter<'component>( + components: impl IntoIterator, + ) -> bool + { + let ids_set = components + .into_iter() + .map(|component| component.id) + .collect::>(); + + ids_set.contains(&ComponentId::of::()) + } +} + +pub struct Not +where + OptionsT: Options, +{ + _pd: PhantomData, +} + +impl Options for Not +where + OptionsT: Options, +{ + fn entity_filter<'component>( + components: impl IntoIterator, + ) -> bool + { + !OptionsT::entity_filter(components) + } +} -- cgit v1.2.3-18-g5258