#![deny(clippy::all, clippy::pedantic)]
use std::alloc::{alloc, dealloc, handle_alloc_error, realloc, Layout};
use std::any::{Any, TypeId};
use std::cmp::max;
use std::mem::MaybeUninit;
use std::ptr::NonNull;

use crate::util::MaybeUninitByteSlice;

mod util;

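/// An owned, type-erased pointer to a single heap-allocated value.
///
/// Used to pass field values of any type to [`MultiVec::push`]. Note that dropping
/// an `OwnedAnyPtr` frees its allocation but does *not* run the value's destructor;
/// the value's bytes are expected to be moved into a `MultiVec`, which then becomes
/// responsible for dropping them.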
pub struct OwnedAnyPtr
{
    ptr: *mut dyn Any,
    drop_in_place: unsafe fn(NonNull<MaybeUninit<u8>>),
}

impl OwnedAnyPtr
{
    pub fn new<Value: Any>(value: Value) -> Self
    {
        Self::from_boxed(Box::new(value))
    }

    pub fn from_boxed<Value: Any>(boxed_value: Box<Value>) -> Self
    {
        Self {
            ptr: Box::into_raw(boxed_value),
            drop_in_place: |ptr| unsafe {
                std::ptr::drop_in_place(ptr.cast::<Value>().as_ptr())
            },
        }
    }

    pub fn as_ptr(&self) -> *const dyn Any
    {
        self.ptr
    }

    pub fn size(&self) -> usize
    {
        size_of_val(unsafe { &*self.ptr })
    }

    pub fn alignment(&self) -> usize
    {
        align_of_val(unsafe { &*self.ptr })
    }

    pub fn id(&self) -> TypeId
    {
        unsafe { &*self.ptr }.type_id()
    }
}

impl Drop for OwnedAnyPtr
{
    fn drop(&mut self)
    {
        if self.size() == 0 {
            return;
        }

        unsafe {
            dealloc(
                self.ptr.cast::<u8>(),
                Layout::from_size_align(self.size(), self.alignment()).unwrap(),
            );
        }
    }
}

/// A list of items, stored as one list per field. This data structure stores a
/// separate list for every field of the item type, reducing memory usage if the item
/// type contains padding, and improving cache usage when only certain fields are
/// needed while iterating.
///
/// Inspired by Zig's `MultiArrayList`.
///
/// Note: All of the lists are stored in the same allocation.
///
/// For example, if you store three values of the following struct:
/// ```
/// struct Person
/// {
///     first_name: String,
///     age: u8,
/// }
/// ```
///
/// It would be stored like this in memory:
/// ```text
/// first_name, first_name, first_name,
/// age, age, age,
/// ```
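///
/// A minimal usage sketch (the crate name `multi_vec` here is an assumption):
/// ```ignore
/// use multi_vec::{MultiVec, OwnedAnyPtr};
///
/// let mut people = MultiVec::new();
///
/// // One type-erased value per field, in field order.
/// people.push([
///     OwnedAnyPtr::new(String::from("Ada")),
///     OwnedAnyPtr::new(36u8),
/// ]);
///
/// // Field 0 is the contiguous array of every `first_name`.
/// let first_names = people.get_field(0);
/// assert_eq!(first_names.as_slice::<String>()[0], "Ada");
/// ```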
#[derive(Debug)]
pub struct MultiVec
{
    ptr: NonNull<MaybeUninit<u8>>,
    field_arr_byte_offsets: Vec<usize>,
    field_metadata: Vec<FieldMetadata>,
    length: usize,
    capacity: usize,
    layout: Option<Layout>,
}

impl MultiVec
{
    fn get_min_non_zero_cap(fields: impl AsRef<[OwnedAnyPtr]>) -> usize
    {
        let total_size = fields
            .as_ref()
            .iter()
            .fold(0usize, |acc, field| acc + field.size());

        // The following is borrowed from std's RawVec implementation:
        // Skip to:
        // - 8 if the element size is 1, because any heap allocator is likely to round
        //   up a request of less than 8 bytes to at least 8 bytes.
        // - 4 if elements are moderate-sized (<= 1 KiB).
        // - 1 otherwise, to avoid wasting too much space for very short Vecs.
        if total_size == 1 {
            8
        } else if total_size <= 1024 {
            4
        } else {
            1
        }
    }

    /// Returns a new `MultiVec`. This function does not allocate any memory.
    #[must_use]
    pub const fn new() -> Self
    {
        Self {
            ptr: NonNull::dangling(),
            field_arr_byte_offsets: Vec::new(),
            field_metadata: Vec::new(),
            length: 0,
            capacity: 0,
            layout: None,
        }
    }

    ///// Returns a new `MultiVec` with a capacity for `capacity` items. This function
    ///// will allocate memory.
    //#[must_use]
    //pub fn with_capacity(capacity: usize) -> Self
    //{
    //    let mut this = Self {
    //        _pd: PhantomData,
    //        ptr: NonNull::dangling(),
    //        field_arr_byte_offsets: Vec::new(),
    //        length: 0,
    //        capacity: 0,
    //        layout: None,
    //    };
    //
    //    this.do_first_alloc(capacity);
    //
    //    this
    //}

    /// Pushes an item to the `MultiVec`.
    ///
    /// ## Note on performance
    /// Pushing can be pretty slow. Since all of the field lists are stored in the same
    /// allocation, whenever a push requires the `MultiVec` to grow, every list except
    /// the first has to be moved to a new location so that the lists do not overlap.
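    ///
    /// ## Panics
    /// Panics if the number of fields does not match previous pushes, or if `fields`
    /// is empty on the first push.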
    pub fn push(
        &mut self,
        fields: impl AsRef<[OwnedAnyPtr]> + IntoIterator<Item = OwnedAnyPtr>,
    )
    {
        if self.capacity != 0 {
            assert_eq!(fields.as_ref().len(), self.field_arr_byte_offsets.len());

            if self.capacity == self.length {
                self.grow_amortized(1, &fields);
            }

            self.write_item(self.length, fields);

            self.length += 1;

            return;
        }

        self.field_metadata = fields
            .as_ref()
            .iter()
            .map(|field| FieldMetadata {
                size: field.size(),
                type_id: field.id(),
                drop_in_place: field.drop_in_place,
            })
            .collect();

        self.do_first_alloc(1, &fields);

        self.write_item(0, fields);

        self.length = 1;
    }

    ///// Returns a field of the item with the given index.
    /////
    ///// This function is equivalent to doing `.get_all().get(index)`
    //#[must_use]
    //pub fn get<FieldSel>(
    //    &self,
    //    index: usize,
    //) -> Option<&<FieldSel as ItemFieldSelection<ItemT>>::Field>
    //where
    //    FieldSel: ItemFieldSelection<ItemT>,
    //{
    //    if index >= self.length {
    //        return None;
    //    }
    //
    //    let field_metadata = FieldSel::metadata();
    //
    //    let field_arr_byte_offset = self.field_arr_byte_offsets[FieldSel::INDEX];
    //
    //    let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };
    //
    //    let field_ptr = unsafe { field_arr_ptr.add(field_metadata.size * index) };
    //
    //    Some(unsafe { field_ptr.cast().as_ref() })
    //}

    /// Returns a slice containing the specified field of all items.
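    ///
    /// ## Panics
    /// Panics if `field_index` is out of bounds. Field indices follow the order in
    /// which the fields were passed to [`Self::push`].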
    #[must_use]
    pub fn get_field(&self, field_index: usize) -> FieldSlice<'_>
    {
        let field_arr_byte_offset = self.field_arr_byte_offsets[field_index];

        let field_metadata = &self.field_metadata[field_index];

        let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };

        let bytes = unsafe {
            std::slice::from_raw_parts(
                field_arr_ptr.as_ptr().cast(),
                self.len() * field_metadata.size,
            )
        };

        FieldSlice { bytes, field_metadata }
    }

    /// Returns the number of items stored in this `MultiVec`.
    #[must_use]
    pub fn len(&self) -> usize
    {
        self.length
    }

    /// Returns how many items this `MultiVec` has capacity for.
    #[must_use]
    pub fn capacity(&self) -> usize
    {
        self.capacity
    }

    /// Returns whether this `MultiVec` is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool
    {
        self.length == 0
    }

    fn grow_amortized(&mut self, additional: usize, fields: impl AsRef<[OwnedAnyPtr]>)
    {
        let required_cap = self.capacity.checked_add(additional).unwrap();

        // This guarantees exponential growth. The doubling cannot overflow
        // because `cap <= isize::MAX` and the type of `cap` is `usize`.
        let new_capacity = max(self.capacity * 2, required_cap);
        let new_capacity = max(Self::get_min_non_zero_cap(&fields), new_capacity);

        let layout = self.layout.unwrap();

        let (new_layout, new_field_arr_byte_offsets) =
            Self::create_layout(new_capacity, &fields);

        let Some(new_ptr) = NonNull::new(if layout.size() == 0 {
            std::ptr::dangling_mut()
        } else {
            unsafe { realloc(self.ptr.as_ptr().cast::<u8>(), layout, new_layout.size()) }
        }) else {
            handle_alloc_error(new_layout);
        };

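        // Relocate every field array to its new byte offset, back to front. The new
        // offsets are greater than or equal to the old ones, so moving the
        // highest-offset array first guarantees that no source bytes are overwritten
        // before they are copied; `ptr::copy` handles the remaining overlap. Note
        // that `self.capacity` is still the old capacity at this point.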
        for (field_index, field) in fields.as_ref().iter().enumerate().rev() {
            let old_field_arr_byte_offset = self.field_arr_byte_offsets[field_index];
            let new_field_arr_byte_offset = new_field_arr_byte_offsets[field_index];

            let old_field_arr_ptr =
                unsafe { new_ptr.byte_add(old_field_arr_byte_offset) };

            let new_field_arr_ptr =
                unsafe { new_ptr.byte_add(new_field_arr_byte_offset) };

            unsafe {
                std::ptr::copy(
                    old_field_arr_ptr.as_ptr(),
                    new_field_arr_ptr.as_ptr(),
                    field.size() * self.capacity,
                );
            }
        }

        self.ptr = new_ptr.cast::<MaybeUninit<u8>>();
        self.layout = Some(new_layout);
        self.capacity = new_capacity;
        self.field_arr_byte_offsets = new_field_arr_byte_offsets;
    }

    fn do_first_alloc(&mut self, capacity: usize, fields: impl AsRef<[OwnedAnyPtr]>)
    {
        let (layout, field_arr_byte_offsets) = Self::create_layout(capacity, fields);

        let Some(ptr) = NonNull::new(if layout.size() == 0 {
            std::ptr::dangling_mut()
        } else {
            unsafe { alloc(layout) }
        }) else {
            handle_alloc_error(layout);
        };

        self.ptr = ptr.cast::<MaybeUninit<u8>>();
        self.capacity = capacity;
        self.field_arr_byte_offsets = field_arr_byte_offsets;
        self.layout = Some(layout);
    }

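    /// Computes a single combined layout holding one `length`-element array per
    /// field, padded so that every array is properly aligned. Returns the layout
    /// together with the byte offset of each field's array within it.
    ///
    /// Panics if `fields` is empty or if the total size would overflow `isize::MAX`.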
    fn create_layout(
        length: usize,
        fields: impl AsRef<[OwnedAnyPtr]>,
    ) -> (Layout, Vec<usize>)
    {
        let mut field_iter = fields.as_ref().iter();

        let first_field = field_iter.next().unwrap();

        let mut layout =
            array_layout(first_field.size(), first_field.alignment(), length).unwrap();

        let mut field_arr_byte_offsets = Vec::with_capacity(fields.as_ref().len());

        field_arr_byte_offsets.push(0);

        for field in field_iter {
            let (new_layout, array_byte_offset) = layout
                .extend(array_layout(field.size(), field.alignment(), length).unwrap())
                .unwrap();

            layout = new_layout;

            field_arr_byte_offsets.push(array_byte_offset);
        }

        (layout, field_arr_byte_offsets)
    }

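    /// Writes one item at `index` by copying each field's bytes out of its
    /// `OwnedAnyPtr`. The iterator drops each `OwnedAnyPtr` afterwards, which frees
    /// its allocation without running the value's destructor, so ownership of the
    /// value is transferred into the `MultiVec` exactly once.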
    fn write_item(&mut self, index: usize, fields: impl IntoIterator<Item = OwnedAnyPtr>)
    {
        for (field_index, item_field) in fields.into_iter().enumerate() {
            let field_size = item_field.size();

            let field_arr_byte_offset = self.field_arr_byte_offsets[field_index];

            let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };

            let field_dst_ptr = unsafe { field_arr_ptr.add(field_size * index) };

            let item_field_ptr = item_field.as_ptr().cast::<u8>();

            unsafe {
                std::ptr::copy_nonoverlapping(
                    item_field_ptr,
                    field_dst_ptr.as_ptr().cast::<u8>(),
                    field_size,
                );
            }
        }
    }
}

//impl<ItemT> FromIterator<ItemT> for MultiVec<ItemT>
//where
//    ItemT: Item,
//{
//    fn from_iter<ItemIter: IntoIterator<Item = ItemT>>(iter: ItemIter) -> Self
//    {
//        let iter = iter.into_iter();
//
//        let initial_capacity =
//            max(Self::MIN_NON_ZERO_CAP, iter.size_hint().0.saturating_add(1));
//
//        let mut this = Self::with_capacity(initial_capacity);
//
//        for item in iter {
//            if this.capacity == this.length {
//                this.grow_amortized(1);
//            }
//
//            this.write_item(this.length, item);
//
//            this.length += 1;
//        }
//
//        this
//    }
//}

impl Default for MultiVec
{
    fn default() -> Self
    {
        Self::new()
    }
}

impl Drop for MultiVec
{
    fn drop(&mut self)
    {
        assert_eq!(self.field_metadata.len(), self.field_arr_byte_offsets.len());

        for index in 0..self.length {
            for (field_index, field_metadata) in self.field_metadata.iter().enumerate() {
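                // Zero-sized fields occupy no storage. Note that this also means a
                // ZST's destructor, if it has one, is never run here.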
                if field_metadata.size == 0 {
                    continue;
                }

                let field_arr_byte_offset = self.field_arr_byte_offsets[field_index];

                let field_arr_ptr = unsafe { self.ptr.byte_add(field_arr_byte_offset) };

                let field_ptr = unsafe { field_arr_ptr.add(field_metadata.size * index) };

                unsafe {
                    (field_metadata.drop_in_place)(field_ptr);
                }
            }
        }

        if let Some(layout) = self.layout {
            if layout.size() == 0 {
                return;
            }

            unsafe {
                std::alloc::dealloc(self.ptr.as_ptr().cast::<u8>(), layout);
            }
        }
    }
}

pub struct FieldSlice<'mv>
{
    bytes: &'mv [MaybeUninit<u8>],
    field_metadata: &'mv FieldMetadata,
}

impl FieldSlice<'_>
{
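    /// Reinterprets the raw field bytes as a typed slice.
    ///
    /// ## Panics
    /// Panics if `Item` is not the type this field was pushed with.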
    pub fn as_slice<Item: 'static>(&self) -> &[Item]
    {
        assert_eq!(TypeId::of::<Item>(), self.field_metadata.type_id);

        unsafe { self.bytes.cast::<Item>() }
    }
}

#[derive(Debug)]
struct FieldMetadata
{
    size: usize,
    type_id: TypeId,
    drop_in_place: unsafe fn(NonNull<MaybeUninit<u8>>),
}

#[inline]
const fn array_layout(
    element_size: usize,
    align: usize,
    n: usize,
) -> Result<Layout, CoolLayoutError>
{
    // We need to check two things about the size:
    //  - That the total size won't overflow a `usize`, and
    //  - That the total size still fits in an `isize`.
    // By using division we can check them both with a single threshold.
    // That'd usually be a bad idea; in std the element size and alignment are
    // compile-time constants that let the compiler fold all of it, but here
    // they are runtime values, so the division actually executes.
    if element_size != 0 && n > max_size_for_align(align) / element_size {
        return Err(CoolLayoutError);
    }

    // SAFETY: We just checked that we won't overflow `usize` when we multiply.
    let array_size = unsafe { element_size.unchecked_mul(n) };

    // SAFETY: We just checked above that the `array_size` will not
    // exceed `isize::MAX` even when rounded up to the alignment. The
    // alignment comes from `align_of_val`, which always returns a
    // power of two.
    unsafe { Ok(Layout::from_size_align_unchecked(array_size, align)) }
}

#[allow(clippy::inline_always)]
#[inline(always)]
const fn max_size_for_align(align: usize) -> usize
{
    // (power-of-two implies align != 0.)

    // Rounded up size is:
    //   size_rounded_up = (size + align - 1) & !(align - 1);
    //
    // We know from above that align != 0. If adding (align - 1)
    // does not overflow, then rounding up will be fine.
    //
    // Conversely, &-masking with !(align - 1) will subtract off
    // only low-order-bits. Thus if overflow occurs with the sum,
    // the &-mask cannot subtract enough to undo that overflow.
    //
    // Above implies that checking for summation overflow is both
    // necessary and sufficient.
    isize::MAX as usize - (align - 1)
}

#[derive(Debug)]
struct CoolLayoutError;

#[cfg(test)]
mod tests
{
    use std::any::TypeId;
    use std::ptr::NonNull;

    use crate::{FieldMetadata, MultiVec, OwnedAnyPtr};

    #[test]
    fn single_push_works()
    {
        let mut multi_vec = MultiVec::new();

        multi_vec.push([OwnedAnyPtr::new::<u32>(123), OwnedAnyPtr::new::<u16>(654)]);

        assert_eq!(multi_vec.capacity, 1);
        assert_eq!(multi_vec.length, 1);

        assert_eq!(multi_vec.field_arr_byte_offsets, [0, size_of::<u32>()]);

        assert_eq!(
            unsafe {
                std::slice::from_raw_parts::<u32>(multi_vec.ptr.as_ptr().cast(), 1)
            },
            [123]
        );

        assert_eq!(
            unsafe {
                std::slice::from_raw_parts::<u16>(
                    multi_vec.ptr.as_ptr().byte_add(size_of::<u32>()).cast(),
                    1,
                )
            },
            [654]
        );
    }

    #[test]
    fn multiple_pushes_works()
    {
        let mut multi_vec = MultiVec::new();

        multi_vec.push([OwnedAnyPtr::new(u32::MAX / 2), OwnedAnyPtr::new::<u16>(654)]);
        multi_vec.push([OwnedAnyPtr::new::<u32>(765), OwnedAnyPtr::new::<u16>(u16::MAX / 3)]);
        multi_vec.push([OwnedAnyPtr::new(u32::MAX / 5), OwnedAnyPtr::new::<u16>(337)]);

        assert_eq!(multi_vec.capacity, 4);
        assert_eq!(multi_vec.length, 3);

        assert_eq!(multi_vec.field_arr_byte_offsets, [0, size_of::<u32>() * 4]);

        assert_eq!(
            unsafe {
                std::slice::from_raw_parts::<u32>(multi_vec.ptr.as_ptr().cast(), 3)
            },
            [u32::MAX / 2, 765, u32::MAX / 5]
        );

        assert_eq!(
            unsafe {
                std::slice::from_raw_parts::<u16>(
                    multi_vec.ptr.as_ptr().byte_add(size_of::<u32>() * 4).cast(),
                    3,
                )
            },
            [654, u16::MAX / 3, 337]
        );
    }

    #[test]
    fn push_all_zero_sized_fields_works()
    {
        struct ZeroSizedThing;

        let mut multi_vec = MultiVec::new();

        multi_vec.push([OwnedAnyPtr::new(()), OwnedAnyPtr::new(ZeroSizedThing)]);
        multi_vec.push([OwnedAnyPtr::new(()), OwnedAnyPtr::new(ZeroSizedThing)]);
    }

    //#[test]
    //fn multiple_pushes_in_preallocated_works()
    //{
    //    let mut multi_vec = MultiVec::<Foo>::with_capacity(2);
    //
    //    multi_vec.push(Foo { num_a: 83710000, num_b: 654 });
    //    multi_vec.push(Foo { num_a: 765, num_b: u16::MAX / 7 });
    //
    //    assert_eq!(multi_vec.capacity, 2);
    //    assert_eq!(multi_vec.length, 2);
    //
    //    assert_eq!(multi_vec.field_arr_byte_offsets, [0, size_of::<u32>() * 2]);
    //
    //    assert_eq!(
    //        unsafe {
    //            std::slice::from_raw_parts::<u32>(multi_vec.ptr.as_ptr().cast(), 2)
    //        },
    //        [83710000, 765]
    //    );
    //
    //    assert_eq!(
    //        unsafe {
    //            std::slice::from_raw_parts::<u16>(
    //                multi_vec.ptr.as_ptr().byte_add(size_of::<u32>() * 2).cast(),
    //                2,
    //            )
    //        },
    //        [654, u16::MAX / 7]
    //    );
    //}

    //#[test]
    //fn get_works()
    //{
    //    let mut multi_vec = MultiVec::<Foo>::new();
    //
    //    #[repr(packed)]
    //    #[allow(dead_code)]
    //    struct Data
    //    {
    //        num_a: [u32; 3],
    //        num_b: [u16; 3],
    //    }
    //
    //    let data = Data {
    //        num_a: [u32::MAX - 3000, 901, 5560000],
    //        num_b: [20210, 7120, 1010],
    //    };
    //
    //    multi_vec.ptr = NonNull::from(&data).cast();
    //    multi_vec.field_arr_byte_offsets = vec![0, size_of::<u32>() * 3];
    //    multi_vec.length = 3;
    //    multi_vec.capacity = 3;
    //
    //    assert_eq!(
    //        multi_vec.get::<FooFieldNumA>(0).copied(),
    //        Some(u32::MAX - 3000)
    //    );
    //    assert_eq!(multi_vec.get::<FooFieldNumB>(0).copied(), Some(20210));
    //
    //    assert_eq!(multi_vec.get::<FooFieldNumA>(1).copied(), Some(901));
    //    assert_eq!(multi_vec.get::<FooFieldNumB>(1).copied(), Some(7120));
    //
    //    assert_eq!(multi_vec.get::<FooFieldNumA>(2).copied(), Some(5560000));
    //    assert_eq!(multi_vec.get::<FooFieldNumB>(2).copied(), Some(1010));
    //}

    //#[test]
    //fn from_iter_works()
    //{
    //    let multi_vec = MultiVec::<Foo>::from_iter([
    //        Foo { num_a: 456456, num_b: 9090 },
    //        Foo { num_a: 79541, num_b: 2233 },
    //        Foo { num_a: 1761919, num_b: u16::MAX - 75 },
    //        Foo { num_a: u32::MAX / 9, num_b: 8182 },
    //    ]);
    //
    //    assert_eq!(multi_vec.length, 4);
    //    assert_eq!(multi_vec.capacity, 5);
    //
    //    assert_eq!(multi_vec.field_arr_byte_offsets, [0, size_of::<u32>() * 5]);
    //
    //    assert_eq!(
    //        unsafe {
    //            std::slice::from_raw_parts::<u32>(multi_vec.ptr.as_ptr().cast(), 4)
    //        },
    //        [456456, 79541, 1761919, u32::MAX / 9]
    //    );
    //
    //    assert_eq!(
    //        unsafe {
    //            std::slice::from_raw_parts::<u16>(
    //                multi_vec.ptr.as_ptr().byte_add(size_of::<u32>() * 5).cast(),
    //                4,
    //            )
    //        },
    //        [9090, 2233, u16::MAX - 75, 8182]
    //    );
    //}

    #[test]
    fn get_field_works()
    {
        struct Data
        {
            _a: [u32; 3],
            _b: [u16; 3],
        }

        let mut data = Data {
            _a: [u32::MAX - 3000, 901, 5560000],
            _b: [20210, 7120, 1010],
        };

        let mut multi_vec = MultiVec::new();

        multi_vec.ptr = NonNull::from(&mut data).cast();
        multi_vec.field_arr_byte_offsets = vec![0, size_of::<u32>() * 3];
        multi_vec.field_metadata = vec![
            FieldMetadata {
                size: size_of::<u32>(),
                type_id: TypeId::of::<u32>(),
                drop_in_place: |_| {},
            },
            FieldMetadata {
                size: size_of::<u16>(),
                type_id: TypeId::of::<u16>(),
                drop_in_place: |_| {},
            },
        ];
        multi_vec.length = 3;
        multi_vec.capacity = 3;

        assert_eq!(
            multi_vec.get_field(0).as_slice::<u32>(),
            [u32::MAX - 3000, 901, 5560000]
        );

        assert_eq!(
            multi_vec.get_field(1).as_slice::<u16>(),
            [20210, 7120, 1010]
        );
    }
}