summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2025-03-27 21:34:39 +0100
committerHampusM <hampus@hampusmat.com>2025-03-28 11:52:48 +0100
commit714c62e9833f0ebb18d838dd44d6d8ec8d01717f (patch)
treeefde31dcdedd5c8e2d264300fd40b60ee97c42a4 /src
parentfa20f1839448f9d5c7ccb9dcbabeb6d0785f6083 (diff)
refactor: improve drop fn & tests
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs230
-rw-r--r--src/util.rs11
2 files changed, 193 insertions, 48 deletions
diff --git a/src/lib.rs b/src/lib.rs
index dedd1de..0f49968 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -248,6 +248,30 @@ impl MultiVec
FieldSlice { bytes, field_metadata }
}
+ /// Returns a slice containing the specified field of all items.
+ #[must_use]
+ pub fn get_field_mut(&mut self, field_index: usize) -> FieldSliceMut<'_>
+ {
+ let field_arr_byte_offset = self.field_arr_byte_offsets[field_index];
+
+ let field_metadata = &self.field_metadata[field_index];
+
+ let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };
+
+ let bytes = unsafe {
+ std::slice::from_raw_parts_mut(
+ field_arr_ptr.as_ptr().cast(),
+ self.len() * field_metadata.size,
+ )
+ };
+
+ FieldSliceMut {
+ bytes,
+ len: self.len(),
+ field_metadata,
+ }
+ }
+
/// Returns the number of items stored in this `MultiVec`.
#[must_use]
pub fn len(&self) -> usize
@@ -428,20 +452,14 @@ impl Drop for MultiVec
{
assert_eq!(self.field_metadata.len(), self.field_arr_byte_offsets.len());
- for index in 0..self.length {
- for (field_index, field_metadata) in self.field_metadata.iter().enumerate() {
- if field_metadata.size == 0 {
- continue;
- }
-
- let field_arr_byte_offset = self.field_arr_byte_offsets[field_index];
-
- let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };
-
- let field_ptr = unsafe { field_arr_ptr.add(field_metadata.size * index) };
+ for field_index in 0..self.field_arr_byte_offsets.len() {
+ for field in self.get_field_mut(field_index).iter_mut() {
+ let field_ptr = field.bytes.as_mut_ptr();
unsafe {
- (field_metadata.drop_in_place)(field_ptr);
+ (field.field_metadata.drop_in_place)(
+ NonNull::new(field_ptr).unwrap(),
+ );
}
}
}
@@ -474,6 +492,83 @@ impl FieldSlice<'_>
}
}
+pub struct FieldSliceMut<'mv>
+{
+ bytes: &'mv mut [MaybeUninit<u8>],
+ len: usize,
+ field_metadata: &'mv FieldMetadata,
+}
+
+impl FieldSliceMut<'_>
+{
+ pub fn iter_mut(&mut self) -> FieldSliceIterMut<'_>
+ {
+ FieldSliceIterMut {
+ bytes: self.bytes,
+ index: 0,
+ len: self.len,
+ field_metadata: self.field_metadata,
+ }
+ }
+}
+
+pub struct FieldSliceIterMut<'mv>
+{
+ bytes: &'mv mut [MaybeUninit<u8>],
+ index: usize,
+ len: usize,
+ field_metadata: &'mv FieldMetadata,
+}
+
+impl<'mv> Iterator for FieldSliceIterMut<'mv>
+{
+ type Item = FieldMut<'mv>;
+
+ fn next(&mut self) -> Option<Self::Item>
+ {
+ let start_off = self.index * self.field_metadata.size;
+
+ if self.index >= self.len {
+ return None;
+ }
+
+ let field_bytes_a = self
+ .bytes
+ .get_mut(start_off..start_off + self.field_metadata.size)?;
+
+ let field_bytes = unsafe {
+ std::slice::from_raw_parts_mut(
+ field_bytes_a.as_mut_ptr(),
+ //self.bytes.as_mut_ptr().byte_add(start_off),
+ self.field_metadata.size,
+ )
+ };
+
+ self.index += 1;
+
+ Some(FieldMut {
+ bytes: field_bytes,
+ field_metadata: self.field_metadata,
+ })
+ }
+}
+
+pub struct FieldMut<'mv>
+{
+ bytes: &'mv mut [MaybeUninit<u8>],
+ field_metadata: &'mv FieldMetadata,
+}
+
+impl FieldMut<'_>
+{
+ pub fn cast_mut<T: 'static>(&mut self) -> &mut T
+ {
+ assert_eq!(TypeId::of::<T>(), self.field_metadata.type_id);
+
+ unsafe { &mut *self.bytes.as_mut_ptr().cast::<T>() }
+ }
+}
+
#[derive(Debug)]
struct FieldMetadata
{
@@ -539,10 +634,56 @@ struct CoolLayoutError;
mod tests
{
use std::any::TypeId;
+ use std::mem::offset_of;
use std::ptr::NonNull;
use crate::{FieldMetadata, MultiVec, OwnedAnyPtr};
+ macro_rules! multi_vec_with_data {
+ (
+ data = &mut $data: ident,
+ {
+ $($field_name: ident: $field_type: ty = $field_values: expr,)*
+ },
+ length = $length: literal
+ ) => {{
+ #[repr(C)]
+ #[derive(Debug)]
+ struct Data
+ {
+ $($field_name: [$field_type; $length],)*
+ }
+
+ $data = Data {
+ $($field_name: $field_values.map(|val| val.into()),)*
+ };
+
+ let mut multi_vec = MultiVec::new();
+
+ multi_vec.ptr = NonNull::from(&mut $data).cast();
+
+ std::mem::forget($data);
+
+ multi_vec.field_arr_byte_offsets =
+ vec![$(offset_of!(Data, $field_name),)*];
+
+ multi_vec.field_metadata = vec![$(
+ FieldMetadata {
+ size: size_of::<$field_type>(),
+ type_id: TypeId::of::<$field_type>(),
+ drop_in_place: |ptr| unsafe {
+ std::ptr::drop_in_place(ptr.cast::<$field_type>().as_ptr());
+ },
+ },
+ )*];
+
+ multi_vec.length = $length;
+ multi_vec.capacity = multi_vec.length;
+
+ multi_vec
+ }};
+ }
+
#[test]
fn single_push_works()
{
@@ -717,37 +858,17 @@ mod tests
//}
#[test]
- fn get_field_works()
+ fn get_field_works_when_two_fields()
{
- struct Data
- {
- _a: [u32; 3],
- _b: [u16; 3],
- }
-
- let mut data = Data {
- _a: [u32::MAX - 3000, 901, 5560000],
- _b: [20210, 7120, 1010],
- };
-
- let mut multi_vec = MultiVec::new();
-
- multi_vec.ptr = NonNull::from(&mut data).cast();
- multi_vec.field_arr_byte_offsets = vec![0, size_of::<u32>() * 3];
- multi_vec.field_metadata = vec![
- FieldMetadata {
- size: size_of::<u32>(),
- type_id: TypeId::of::<u32>(),
- drop_in_place: |_| {},
- },
- FieldMetadata {
- size: size_of::<u16>(),
- type_id: TypeId::of::<u16>(),
- drop_in_place: |_| {},
+ let mut data;
+ let multi_vec = multi_vec_with_data!(
+ data = &mut data,
+ {
+ _a: u32 = [u32::MAX - 3000, 901, 5560000],
+ _b: u16 = [20210u16, 7120, 1010],
},
- ];
- multi_vec.length = 3;
- multi_vec.capacity = 3;
+ length = 3
+ );
assert_eq!(
multi_vec.get_field(0).as_slice::<u32>(),
@@ -759,4 +880,31 @@ mod tests
[20210, 7120, 1010]
);
}
+
+ #[test]
+ fn get_field_works_when_three_fields()
+ {
+ let mut data;
+ let multi_vec = multi_vec_with_data!(
+ data = &mut data,
+ {
+ _a: u32 = [123u32, 888, 1910, 11144, 770077],
+ _b: String = ["No,", "I", "am", "your", "father"],
+ _c: u8 = [120, 88, 54, 3, 7],
+ },
+ length = 5
+ );
+
+ assert_eq!(
+ multi_vec.get_field(0).as_slice::<u32>(),
+ [123, 888, 1910, 11144, 770077]
+ );
+
+ assert_eq!(
+ multi_vec.get_field(1).as_slice::<String>(),
+ ["No,", "I", "am", "your", "father",]
+ );
+
+ assert_eq!(multi_vec.get_field(2).as_slice::<u8>(), [120, 88, 54, 3, 7]);
+ }
}
diff --git a/src/util.rs b/src/util.rs
index 6b7180a..70e114b 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -17,15 +17,12 @@ impl MaybeUninitByteSlice for &[MaybeUninit<u8>]
"Invalid item alignment"
);
- if size_of::<Item>() == 0 {
+ let new_len = self.len() / size_of::<Item>();
+
+ if new_len == 0 {
return &[];
}
- unsafe {
- std::slice::from_raw_parts(
- self.as_ptr().cast::<Item>(),
- self.len() / size_of::<Item>(),
- )
- }
+ unsafe { std::slice::from_raw_parts(self.as_ptr().cast::<Item>(), new_len) }
}
}