summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2024-08-25 19:51:00 +0200
committerHampusM <hampus@hampusmat.com>2024-08-25 20:08:30 +0200
commit472215a06849919287b1d3f122c64c8e72532d41 (patch)
treef5de2c4c8cb4ae352772750fb7a2d363d85727d6 /src
parentd27c3b80361e1f0f84576ffcd2e223bb3a505282 (diff)
add implementation base
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs385
1 files changed, 385 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..998da2e
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,385 @@
+use std::alloc::{alloc, handle_alloc_error, Layout};
+use std::marker::PhantomData;
+use std::mem::{forget, needs_drop};
+use std::ptr::NonNull;
+
/// A list of `ItemT`. This data structure stores a list for every field of `ItemT`,
/// reducing memory usage if `ItemT` contains padding and improves memory cache usage if
/// only certain fields are needed when iterating.
///
/// Inspired by Zig's `MultiArrayList`.
///
/// Note: All of the lists are stored in the same allocation.
///
/// For example, if you have three of the following struct:
/// ```
/// struct Person
/// {
///     first_name: String,
///     age: u8,
/// }
/// ```
///
/// It would be stored like this in memory:
/// ```text
/// first_name, first_name, first_name,
/// age, age, age,
/// ```
#[derive(Debug)]
pub struct MultiVec<ItemT>
where
    ItemT: Item,
{
    // Marks this type as owning `ItemT` values (drop-check), without storing one.
    _pd: PhantomData<ItemT>,
    // Base of the single allocation holding every field array.
    // Dangling while `layout` is `None` (no allocation made yet).
    ptr: NonNull<u8>,
    // Byte offset of each field's array from `ptr`, indexed by field index.
    field_arr_byte_offsets: Vec<usize>,
    // Number of items currently stored.
    length: usize,
    // Layout of the current allocation; `None` before the first allocation.
    layout: Option<Layout>,
}
+
+impl<ItemT> MultiVec<ItemT>
+where
+ ItemT: Item,
+{
+ pub const fn new() -> Self
+ {
+ Self {
+ _pd: PhantomData,
+ ptr: NonNull::dangling(),
+ field_arr_byte_offsets: Vec::new(),
+ length: 0,
+ layout: None,
+ }
+ }
+
+ pub fn push(&mut self, item: ItemT)
+ {
+ if self.length != 0 {
+ todo!();
+ }
+
+ let (ptr, fields_arr_byte_offsets, layout) = Self::do_first_alloc();
+
+ self.ptr = ptr;
+ self.length = 1;
+ self.field_arr_byte_offsets = fields_arr_byte_offsets;
+ self.layout = Some(layout);
+
+ self.write_item(0, item);
+ }
+
+ pub fn get<FieldSel>(
+ &self,
+ index: usize,
+ ) -> &<FieldSel as ItemFieldSelection<ItemT>>::Field
+ where
+ FieldSel: ItemFieldSelection<ItemT>,
+ {
+ if index >= self.length {
+ panic!(
+ "Index {index} is out of bounds in MultiVec with length {}",
+ self.length
+ );
+ }
+
+ let field_metadata = FieldSel::metadata();
+
+ let field_arr_byte_offset = self.field_arr_byte_offsets[FieldSel::INDEX];
+
+ let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };
+
+ let field_ptr = unsafe { field_arr_ptr.add(field_metadata.size * index) };
+
+ unsafe { field_ptr.cast().as_ref() }
+ }
+
+ fn do_first_alloc() -> (NonNull<u8>, Vec<usize>, Layout)
+ {
+ let (layout, field_arr_byte_offsets) = Self::create_layout(1);
+
+ let Some(ptr) = NonNull::new(unsafe { alloc(layout) }) else {
+ handle_alloc_error(layout);
+ };
+
+ (ptr, field_arr_byte_offsets, layout)
+ }
+
+ fn create_layout(length: usize) -> (Layout, Vec<usize>)
+ {
+ let mut field_metadata_iter = ItemT::iter_field_metadata();
+
+ let first_field_metadata = field_metadata_iter.next().unwrap();
+
+ let mut layout =
+ array_layout(first_field_metadata.size, first_field_metadata.alignment, 1)
+ .unwrap();
+
+ let mut field_arr_byte_offsets = Vec::with_capacity(ItemT::FIELD_CNT);
+
+ field_arr_byte_offsets.push(0);
+
+ for field_metadata in field_metadata_iter {
+ let (new_layout, array_byte_offset) = layout
+ .extend(
+ array_layout(field_metadata.size, field_metadata.alignment, length)
+ .unwrap(),
+ )
+ .unwrap();
+
+ layout = new_layout;
+
+ field_arr_byte_offsets.push(array_byte_offset);
+ }
+
+ (layout, field_arr_byte_offsets)
+ }
+
+ fn write_item(&mut self, index: usize, item: ItemT)
+ {
+ for (field_index, field_metadata) in ItemT::iter_field_metadata().enumerate() {
+ let field_arr_byte_offset = self.field_arr_byte_offsets[field_index];
+
+ let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };
+
+ let field_ptr = unsafe { field_arr_ptr.add(field_metadata.size * index) };
+
+ unsafe {
+ std::ptr::copy_nonoverlapping(
+ (&item as *const ItemT).byte_add(field_metadata.offset) as *const u8,
+ field_ptr.as_ptr(),
+ field_metadata.size,
+ );
+ }
+ }
+
+ forget(item);
+ }
+}
+
impl<ItemT> Drop for MultiVec<ItemT>
where
    ItemT: Item,
{
    fn drop(&mut self)
    {
        // Drop every stored field value in place, but only when `ItemT`
        // actually has drop glue — plain-data items skip the whole walk.
        if needs_drop::<ItemT>() {
            for index in 0..self.length {
                for (field_index, field_metadata) in
                    ItemT::iter_field_metadata().enumerate()
                {
                    let field_arr_byte_offset = self.field_arr_byte_offsets[field_index];

                    // Start of this field's array inside the allocation.
                    let field_arr_ptr =
                        unsafe { self.ptr.byte_add(field_arr_byte_offset) };

                    // Address of this item's copy of the field.
                    let field_ptr =
                        unsafe { field_arr_ptr.add(field_metadata.size * index) };

                    // SAFETY: `index < self.length`, so `field_ptr` points at a
                    // field value initialized by `write_item`.
                    unsafe {
                        ItemT::drop_field_inplace(field_index, field_ptr.as_ptr());
                    }
                }
            }
        }

        // Free the allocation, if one was ever made. `layout` is `None` (and
        // `ptr` dangling) for a `MultiVec` that was never pushed to.
        if let Some(layout) = self.layout {
            unsafe {
                std::alloc::dealloc(self.ptr.as_ptr(), layout);
            }
        }
    }
}
+
/// Usable as an item of a [`MultiVec`].
///
/// # Safety
/// The iterator returned by `iter_field_metadata` must yield [`ItemFieldMetadata`] that
/// correctly represents fields of the implementor type, and `drop_field_inplace` must
/// drop (only) the field with the given index at the given address.
pub unsafe trait Item
{
    /// Iterator over the metadata of each of the implementor's fields.
    type FieldMetadataIter<'a>: Iterator<Item = &'a ItemFieldMetadata>;

    /// The number of fields of the implementor type. Used as a capacity hint.
    const FIELD_CNT: usize;

    /// Returns an iterator yielding the metadata of every field, in field-index
    /// order.
    fn iter_field_metadata() -> Self::FieldMetadataIter<'static>;

    /// Drops, in place, the field with index `field_index` located at
    /// `field_ptr`.
    ///
    /// # Safety
    /// `field_ptr` must point to a valid, initialized value of that field's
    /// type.
    unsafe fn drop_field_inplace(field_index: usize, field_ptr: *mut u8);
}
+
/// Metadata describing a single field of an [`Item`] type.
pub struct ItemFieldMetadata
{
    /// Byte offset of the field inside the item type.
    pub offset: usize,
    /// Size in bytes of the field's type.
    pub size: usize,
    /// Alignment in bytes of the field's type.
    pub alignment: usize,
}
+
/// A field selection for `ItemT`.
///
/// # Safety
/// The constant `INDEX`, the type `Field` and the `ItemFieldMetadata` returned by the
/// `metadata` function must correctly represent a field of `ItemT`.
pub unsafe trait ItemFieldSelection<ItemT>
where
    ItemT: Item,
{
    /// Index of the selected field, matching its position in
    /// `ItemT::iter_field_metadata`.
    const INDEX: usize;

    /// The selected field's type.
    type Field;

    /// Returns the metadata of the selected field.
    fn metadata() -> &'static ItemFieldMetadata;
}
+
+#[inline]
+const fn array_layout(
+ element_size: usize,
+ align: usize,
+ n: usize,
+) -> Result<Layout, CoolLayoutError>
+{
+ // We need to check two things about the size:
+ // - That the total size won't overflow a `usize`, and
+ // - That the total size still fits in an `isize`.
+ // By using division we can check them both with a single threshold.
+ // That'd usually be a bad idea, but thankfully here the element size
+ // and alignment are constants, so the compiler will fold all of it.
+ if element_size != 0 && n > max_size_for_align(align) / element_size {
+ return Err(CoolLayoutError);
+ }
+
+ // SAFETY: We just checked that we won't overflow `usize` when we multiply.
+ // This is a useless hint inside this function, but after inlining this helps
+ // deduplicate checks for whether the overall capacity is zero (e.g., in RawVec's
+ // allocation path) before/after this multiplication.
+ let array_size = unsafe { element_size.unchecked_mul(n) };
+
+ // SAFETY: We just checked above that the `array_size` will not
+ // exceed `isize::MAX` even when rounded up to the alignment.
+ // And `Alignment` guarantees it's a power of two.
+ unsafe { Ok(Layout::from_size_align_unchecked(array_size, align)) }
+}
+
/// Largest size (in bytes) that still fits in `isize::MAX` after being rounded
/// up to `align`.
#[inline(always)]
const fn max_size_for_align(align: usize) -> usize
{
    // Rounding a size up to `align` adds at most `align - 1` bytes:
    //   size_rounded_up = (size + align - 1) & !(align - 1)
    //
    // So any size no greater than `isize::MAX - (align - 1)` is guaranteed to
    // stay within the `isize::MAX` limit after rounding. `align` is a nonzero
    // power of two, hence at least 1, so the arithmetic cannot underflow.
    (isize::MAX as usize) - align + 1
}
+
/// Error returned by `array_layout` when the requested array layout would
/// overflow `isize::MAX`.
#[derive(Debug)]
struct CoolLayoutError;
+
#[cfg(test)]
mod tests
{
    // Import `size_of`/`align_of` explicitly instead of relying on the
    // Rust 1.80+ prelude, keeping the tests buildable on older toolchains.
    use std::mem::{align_of, offset_of, size_of};
    use std::ptr::drop_in_place;

    use crate::{Item, ItemFieldMetadata, ItemFieldSelection, MultiVec};

    struct Foo
    {
        num_a: u32,
        num_b: u16,
    }

    // Marker types selecting the individual fields of `Foo`.
    struct FooFieldNumA;
    struct FooFieldNumB;

    unsafe impl ItemFieldSelection<Foo> for FooFieldNumA
    {
        type Field = u32;

        const INDEX: usize = 0;

        fn metadata() -> &'static ItemFieldMetadata
        {
            &FOO_FIELD_METADATA[0]
        }
    }

    unsafe impl ItemFieldSelection<Foo> for FooFieldNumB
    {
        type Field = u16;

        const INDEX: usize = 1;

        fn metadata() -> &'static ItemFieldMetadata
        {
            &FOO_FIELD_METADATA[1]
        }
    }

    struct FooFieldMetadataIter<'a>
    {
        iter: std::slice::Iter<'a, ItemFieldMetadata>,
    }

    impl<'a> Iterator for FooFieldMetadataIter<'a>
    {
        type Item = &'a ItemFieldMetadata;

        fn next(&mut self) -> Option<Self::Item>
        {
            self.iter.next()
        }
    }

    static FOO_FIELD_METADATA: [ItemFieldMetadata; 2] = [
        ItemFieldMetadata {
            offset: offset_of!(Foo, num_a),
            size: size_of::<u32>(),
            alignment: align_of::<u32>(),
        },
        ItemFieldMetadata {
            offset: offset_of!(Foo, num_b),
            size: size_of::<u16>(),
            alignment: align_of::<u16>(),
        },
    ];

    unsafe impl Item for Foo
    {
        type FieldMetadataIter<'a> = FooFieldMetadataIter<'a>;

        const FIELD_CNT: usize = 2;

        fn iter_field_metadata() -> Self::FieldMetadataIter<'static>
        {
            FooFieldMetadataIter { iter: FOO_FIELD_METADATA.iter() }
        }

        unsafe fn drop_field_inplace(field_index: usize, field_ptr: *mut u8)
        {
            if field_index == 0 {
                unsafe { drop_in_place::<u32>(field_ptr.cast()) }
            } else if field_index == 1 {
                unsafe { drop_in_place::<u16>(field_ptr.cast()) }
            }
        }
    }

    #[test]
    fn works()
    {
        let mut multi_vec = MultiVec::<Foo>::new();

        multi_vec.push(Foo { num_a: 123, num_b: 654 });

        // The test previously only printed the value; assert both fields round-trip.
        assert_eq!(*multi_vec.get::<FooFieldNumA>(0), 123);
        assert_eq!(*multi_vec.get::<FooFieldNumB>(0), 654);
    }

    #[test]
    #[should_panic]
    fn get_out_of_bounds_panics()
    {
        let multi_vec = MultiVec::<Foo>::new();

        // Empty list: any index is out of bounds and must panic.
        let _ = multi_vec.get::<FooFieldNumA>(0);
    }
}