From e25f4f2af6e5e522f8b6ad0ca0b84f607a73cae2 Mon Sep 17 00:00:00 2001
From: HampusM <hampus@hampusmat.com>
Date: Sun, 16 Jun 2024 17:59:49 +0200
Subject: fix(ecs): prevent archetype creation from causing oob memory accesses

---
 ecs/src/component/storage.rs | 98 ++++++++++++++++++++++++--------------------
 ecs/src/query.rs             |  8 ++--
 2 files changed, 57 insertions(+), 49 deletions(-)

diff --git a/ecs/src/component/storage.rs b/ecs/src/component/storage.rs
index bd53da0..f174ea9 100644
--- a/ecs/src/component/storage.rs
+++ b/ecs/src/component/storage.rs
@@ -1,7 +1,7 @@
 use std::any::type_name;
 use std::collections::{HashMap, HashSet};
 use std::hash::{DefaultHasher, Hash, Hasher};
-use std::ptr::NonNull;
+use std::slice::Iter as SliceIter;
 
 use crate::component::{Component, Id as ComponentId, IsOptional as ComponentIsOptional};
 use crate::lock::Lock;
@@ -12,16 +12,39 @@ use crate::EntityComponent;
 pub struct ComponentStorage
 {
     archetypes: Vec<Archetype>,
-    archetype_lookup: HashMap<ArchetypeComponentsHash, Vec<NonNull<Archetype>>>,
+    archetype_lookup: HashMap<ArchetypeComponentsHash, Vec<usize>>,
     pending_archetype_lookup_entries: Vec<Vec<ComponentId>>,
 }
 
+#[derive(Debug)]
+pub struct ArchetypeRefIter<'component_storage>
+{
+    inner: SliceIter<'component_storage, usize>,
+    archetypes: &'component_storage Vec<Archetype>,
+}
+
+impl<'component_storage> Iterator for ArchetypeRefIter<'component_storage>
+{
+    type Item = &'component_storage Archetype;
+
+    fn next(&mut self) -> Option<Self::Item>
+    {
+        let archetype_index = *self.inner.next()?;
+
+        Some(
+            self.archetypes
+                .get(archetype_index)
+                .expect("Archetype index in archetype lookup entry was not found"),
+        )
+    }
+}
+
 impl ComponentStorage
 {
     pub fn find_entities(
         &self,
         component_ids: &[(ComponentId, ComponentIsOptional)],
-    ) -> Option<&[&Archetype]>
+    ) -> Option<ArchetypeRefIter<'_>>
     {
         let ids = component_ids
             .iter()
@@ -35,12 +58,9 @@ impl ComponentStorage
 
         self.archetype_lookup
             .get(&ArchetypeComponentsHash::new(ids))
-            .map(|archetypes|
-                // SAFETY: All NonNull<Archetype>s are references to items of the
-                // archetypes field and the items won't be dropped until the whole
-                // struct is dropped
-                unsafe {
-                nonnull_slice_to_ref_slice(archetypes.as_slice())
+            .map(|archetypes_indices| ArchetypeRefIter {
+                inner: archetypes_indices.iter(),
+                archetypes: &self.archetypes,
             })
     }
 
@@ -57,7 +77,7 @@ impl ComponentStorage
                 .join(", ")
         );
 
-        let archetypes = self
+        let archetypes_indices = self
             .archetype_lookup
             .entry(ArchetypeComponentsHash::new(
                 components
@@ -68,18 +88,18 @@ impl ComponentStorage
             .or_insert_with(|| {
                 self.archetypes.push(Archetype::default());
 
-                vec![NonNull::from(self.archetypes.last().unwrap())]
+                vec![self.archetypes.len() - 1]
             });
 
-        // SAFETY: All NonNull<Archetype>s are references to items of the
-        // archetypes field and the items won't be dropped until the whole
-        // struct is dropped
-        let archetype = unsafe {
-            archetypes
-                .first_mut()
-                .expect("Archetype has disappeared")
-                .as_mut()
-        };
+        let archetype = self
+            .archetypes
+            .get_mut(
+                archetypes_indices
+                    .first()
+                    .copied()
+                    .expect("No archetype index found"),
+            )
+            .expect("Archetype is gone");
 
         archetype
             .component_ids
@@ -111,22 +131,26 @@ impl ComponentStorage
                 .map(|component_id| *component_id)
                 .collect();
 
-            let matching_archetypes = self.archetypes.iter().filter_map(|archetype| {
-                if archetype.component_ids.is_superset(&components_set) {
-                    return Some(NonNull::from(archetype));
-                }
+            let matching_archetype_indices = self
+                .archetypes
+                .iter()
+                .enumerate()
+                .filter_map(|(index, archetype)| {
+                    if archetype.component_ids.is_superset(&components_set) {
+                        return Some(index);
+                    }
 
-                None
-            });
+                    None
+                });
 
-            let lookup_archetypes = self
+            let archetype_indices = self
                 .archetype_lookup
                 .entry(ArchetypeComponentsHash::new(
                     pending_entry.into_iter().copied().clone(),
                 ))
                 .or_default();
 
-            lookup_archetypes.extend(matching_archetypes);
+            archetype_indices.extend(matching_archetype_indices);
         }
     }
 }
@@ -168,21 +192,10 @@ impl ArchetypeComponentsHash
     }
 }
 
-/// Casts a `&[NonNull<Item>]` to a `&[&Item]`.
-///
-/// # Safety
-/// All items in the slice must be initialized, properly aligned and follow Rust's
-/// aliasing rules.
-const unsafe fn nonnull_slice_to_ref_slice<Item>(slice: &[NonNull<Item>]) -> &[&Item]
-{
-    unsafe { &*(std::ptr::from_ref(slice) as *const [&Item]) }
-}
-
 #[cfg(test)]
 mod tests
 {
     use std::collections::HashSet;
-    use std::ptr::addr_of;
 
     use ecs_macros::Component;
 
@@ -257,14 +270,11 @@ mod tests
             ]))
             .expect("Expected entry in archetype lookup map");
 
-        let archetype_from_lookup = lookup
+        let first_archetype_index = lookup
             .first()
             .expect("Expected archetype lookup to contain a archetype reference");
 
-        assert_eq!(
-            archetype_from_lookup.as_ptr() as usize,
-            addr_of!(*archetype) as usize
-        );
+        assert_eq!(*first_archetype_index, 0);
     }
 
     #[test]
diff --git a/ecs/src/query.rs b/ecs/src/query.rs
index a2edc4d..dcd0b0e 100644
--- a/ecs/src/query.rs
+++ b/ecs/src/query.rs
@@ -2,10 +2,9 @@ use std::any::{type_name, Any};
 use std::collections::HashSet;
 use std::iter::{Flatten, Map};
 use std::marker::PhantomData;
-use std::slice::Iter as SliceIter;
 use std::sync::{Arc, Weak};
 
-use crate::component::storage::Archetype;
+use crate::component::storage::{Archetype, ArchetypeRefIter};
 use crate::component::{
     Id as ComponentId,
     IsOptional as ComponentIsOptional,
@@ -46,7 +45,6 @@ where
                 .component_storage
                 .find_entities(&Comps::ids())
                 .unwrap_or_else(|| panic!("Could not find {:?}", type_name::<Comps>()))
-                .iter()
                 .map((|archetype| archetype.components.as_slice()) as ComponentIterMapFn)
                 .flatten(),
             comps_pd: PhantomData,
@@ -218,11 +216,11 @@ where
     }
 }
 
-type ComponentIterMapFn = for<'a> fn(&'a &'a Archetype) -> &'a [Vec<EntityComponent>];
+type ComponentIterMapFn = for<'a> fn(&'a Archetype) -> &'a [Vec<EntityComponent>];
 
 pub struct ComponentIter<'world, Comps>
 {
-    entities: Flatten<Map<SliceIter<'world, &'world Archetype>, ComponentIterMapFn>>,
+    entities: Flatten<Map<ArchetypeRefIter<'world>, ComponentIterMapFn>>,
     comps_pd: PhantomData<Comps>,
 }
 
-- 
cgit v1.2.3-18-g5258