//! OpenGL renderer.

use std::collections::HashMap;
use std::ffi::{c_void, CString};
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::ops::Deref;
use std::path::Path;
use std::process::abort;

use ecs::actions::Actions;
use ecs::component::local::Local;
use ecs::phase::{PRESENT as PRESENT_PHASE, START as START_PHASE};
use ecs::query::options::{Not, With};
use ecs::sole::Single;
use ecs::system::{Into as _, System};
use ecs::{Component, Query};

use crate::camera::{Active as ActiveCamera, Camera};
use crate::color::Color;
use crate::data_types::dimens::Dimens;
use crate::draw_flags::{DrawFlags, NoDraw, PolygonModeConfig};
use crate::lighting::{DirectionalLight, GlobalLight, PointLight};
use crate::material::{Flags as MaterialFlags, Material};
use crate::matrix::Matrix;
use crate::mesh::Mesh;
use crate::opengl::buffer::{Buffer, Usage as BufferUsage};
use crate::opengl::debug::{
    enable_debug_output,
    set_debug_message_callback,
    set_debug_message_control,
    MessageIdsAction,
    MessageSeverity,
    MessageSource,
    MessageType,
};
use crate::opengl::glsl::{
    preprocess as glsl_preprocess,
    PreprocessingError as GlslPreprocessingError,
};
use crate::opengl::shader::{
    Error as GlShaderError,
    Kind as ShaderKind,
    Program as GlShaderProgram,
    Shader as GlShader,
};
use crate::opengl::texture::{
    set_active_texture_unit,
    Texture as GlTexture,
    TextureUnit,
};
use crate::opengl::vertex_array::{
    DataType as VertexArrayDataType,
    PrimitiveKind,
    VertexArray,
};
use crate::opengl::{
    clear_buffers,
    enable,
    get_context_flags as get_opengl_context_flags,
    BufferClearMask,
    Capability,
    ContextFlags,
};
use crate::projection::{ClipVolume, Projection};
use crate::texture::{Id as TextureId, Texture};
use crate::transform::{Position, Scale};
use crate::util::{defer, Defer, RefOrValue};
use crate::vector::{Vec2, Vec3};
use crate::vertex::{AttributeComponentType, Vertex};
use crate::window::Window;

/// Components fetched for every entity the render system may draw.
///
/// A [`Mesh`] and a [`Material`] are required; every other component is
/// optional. [`GlObjects`] is present once the entity's mesh has been uploaded
/// to the GPU by an earlier frame.
type RenderableEntity<'a> = (
    &'a Mesh,
    &'a Material,
    &'a Option<MaterialFlags>,
    &'a Option<Position>,
    &'a Option<Scale>,
    &'a Option<DrawFlags>,
    &'a Option<GlObjects>,
);

/// ECS extension that wires the OpenGL renderer into the engine.
///
/// It registers a start-phase system that sets up the GL context and a
/// present-phase system that renders every frame.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct Extension {}

impl ecs::extension::Extension for Extension
{
    fn collect(self, mut collector: ecs::extension::Collector<'_>)
    {
        // One-time setup: load GL function pointers and configure global state.
        collector.add_system(*START_PHASE, initialize);

        // Per-frame rendering. The renderer caches its shader program and
        // textures in a system-local `GlobalGlObjects`, starting out empty.
        collector.add_system(
            *PRESENT_PHASE,
            render
                .into_system()
                .initialize((GlobalGlObjects::default(),)),
        );
    }
}

/// Start-phase system that prepares OpenGL for rendering into `window`.
///
/// It makes the window's context current and loads the OpenGL function
/// pointers. It enables debug output when the context was created with the
/// debug flag. It sizes the viewport to the window and keeps it in sync on
/// framebuffer resizes. It enables depth testing and multisampling.
///
/// # Panics
/// Panics if the context cannot be made current or the window size cannot be
/// queried. Aborts the process if the address of an OpenGL function cannot be
/// resolved.
fn initialize(window: Single<Window>)
{
    window
        .make_context_current()
        .expect("Failed to make window context current");

    gl::load_with(|symbol| match window.get_proc_address(symbol) {
        Ok(addr) => addr as *const c_void,
        Err(err) => {
            // A diagnostic for a fatal condition belongs on stderr. Without the
            // full set of GL entry points nothing can be rendered, so abort.
            eprintln!(
                "FATAL ERROR: Failed to get address of OpenGL function {symbol}: {err}",
            );

            abort();
        }
    });

    if get_opengl_context_flags().contains(ContextFlags::DEBUG) {
        initialize_debug();
    }

    let window_size = window.size().expect("Failed to get window size");

    set_viewport(Vec2::ZERO, window_size);

    // Keep the viewport covering the whole framebuffer after resizes
    window.set_framebuffer_size_callback(|new_window_size| {
        set_viewport(Vec2::ZERO, new_window_size);
    });

    enable(Capability::DepthTest);
    enable(Capability::MultiSample);
}

/// Present-phase system that draws every entity with a [`Mesh`] and a
/// [`Material`] (excluding entities tagged [`NoDraw`]). It renders from the
/// point of view of the first active camera.
///
/// An entity without a [`GlObjects`] component gets its mesh uploaded. The
/// resulting objects are then attached to the entity through `actions`, so
/// later frames reuse them. The shader program and GL textures are cached in
/// the system-local [`GlobalGlObjects`].
///
/// # Panics
/// Panics if the default shader program cannot be created or the window size
/// cannot be queried. Also panics if a material uses more textures than there
/// are texture units.
#[allow(clippy::too_many_arguments)]
fn render(
    query: Query<RenderableEntity<'_>, Not<With<NoDraw>>>,
    point_light_query: Query<(&PointLight,)>,
    directional_lights: Query<(&DirectionalLight,)>,
    camera_query: Query<(&Camera, &Position, &ActiveCamera)>,
    window: Single<Window>,
    global_light: Single<GlobalLight>,
    mut gl_objects: Local<GlobalGlObjects>,
    mut actions: Actions,
)
{
    let Some((camera, camera_pos, _)) = camera_query.iter().next() else {
        tracing::warn!("No current camera. Nothing will be rendered");
        return;
    };

    let point_lights = point_light_query
        .iter()
        .map(|(point_light,)| point_light)
        .collect::<Vec<_>>();

    let directional_light_components = directional_lights.iter().collect::<Vec<_>>();

    // The set of directional lights is the same for every entity this frame, so
    // build the slice of plain references once instead of once per entity.
    let directional_lights = directional_light_components
        .iter()
        .map(|(dir_light,)| &**dir_light)
        .collect::<Vec<_>>();

    let GlobalGlObjects {
        shader_program,
        textures: gl_textures,
    } = &mut *gl_objects;

    let shader_program = shader_program.get_or_insert_with(|| {
        create_default_shader_program().expect("Failed to create default shader program")
    });

    clear_buffers(BufferClearMask::COLOR | BufferClearMask::DEPTH);

    for (
        euid,
        (mesh, material, material_flags, position, scale, draw_flags, gl_objects),
    ) in query.iter_with_euids()
    {
        let material_flags = material_flags
            .map(|material_flags| material_flags.clone())
            .unwrap_or_default();

        let gl_objs = match gl_objects.as_deref() {
            Some(gl_objs) => RefOrValue::Ref(gl_objs),
            None => RefOrValue::Value(Some(GlObjects::new(&mesh))),
        };

        // Newly created GL objects are handed to the entity via `actions`, so
        // the next frame takes the `Some` branch above and reuses them.
        defer!(|gl_objs| {
            if let RefOrValue::Value(opt_gl_objs) = gl_objs {
                actions.add_components(euid, (opt_gl_objs.take().unwrap(),));
            };
        });

        apply_transformation_matrices(
            Transformation {
                position: position.map(|pos| *pos).unwrap_or_default().position,
                scale: scale.map(|scale| *scale).unwrap_or_default().scale,
            },
            shader_program,
            &camera,
            &camera_pos,
            window.size().expect("Failed to get window size"),
        );

        apply_light(
            &material,
            &material_flags,
            &global_light,
            shader_program,
            point_lights.as_slice(),
            directional_lights.as_slice(),
            &camera_pos,
        );

        // Texture unit N holds the material's N-th texture
        for (index, texture) in material.textures.iter().enumerate() {
            let gl_texture = gl_textures
                .entry(texture.id())
                .or_insert_with(|| create_gl_texture(texture));

            let texture_unit = TextureUnit::from_num(index).expect("Too many textures");

            set_active_texture_unit(texture_unit);

            gl_texture.bind();
        }

        shader_program.activate();

        if let Some(draw_flags) = &draw_flags {
            crate::opengl::set_polygon_mode(
                draw_flags.polygon_mode_config.face,
                draw_flags.polygon_mode_config.mode,
            );
        }

        draw_mesh(gl_objs.get().unwrap());

        // Restore the default polygon mode so it does not leak into later draws
        if draw_flags.is_some() {
            let default_polygon_mode_config = PolygonModeConfig::default();

            crate::opengl::set_polygon_mode(
                default_polygon_mode_config.face,
                default_polygon_mode_config.mode,
            );
        }
    }
}

/// GL objects shared by every rendered entity.
///
/// The render system owns this as a `Local`, and its contents are created
/// lazily on first use.
#[derive(Debug, Default, Component)]
struct GlobalGlObjects
{
    /// The default shader program. It stays `None` until the first frame that
    /// has an active camera.
    shader_program: Option<GlShaderProgram>,

    /// GL textures, keyed by the ID of the engine texture each was created from.
    textures: HashMap<TextureId, GlTexture>,
}

/// Sets the OpenGL viewport to the rectangle of `size` whose origin is
/// `position`.
///
/// This is a thin wrapper over [`crate::opengl::set_viewport`].
fn set_viewport(position: Vec2<u32>, size: Dimens<u32>)
{
    crate::opengl::set_viewport(position, size)
}

/// Turns on OpenGL debug output and routes messages to
/// [`opengl_debug_message_cb`].
///
/// Called only when the context was created with the debug flag.
fn initialize_debug()
{
    // Order matters: output must be enabled before the callback and filter
    // settings have any effect.
    enable_debug_output();
    set_debug_message_callback(opengl_debug_message_cb);

    // NOTE(review): if `None` filters mean "don't care" (GL_DONT_CARE), this
    // disables *every* debug message, leaving the callback nothing to report.
    // Confirm that this is intended.
    set_debug_message_control(None, None, None, &[], MessageIdsAction::Disable);
}

/// Binds the vertex array of `gl_objects` and draws it as triangles.
///
/// Meshes that were uploaded with an index buffer use indexed drawing. Others
/// draw `element_cnt` vertices in order.
fn draw_mesh(gl_objects: &GlObjects)
{
    gl_objects.vertex_arr.bind();

    match &gl_objects.index_buffer {
        Some(_) => {
            VertexArray::draw_elements(PrimitiveKind::Triangles, 0, gl_objects.element_cnt);
        }
        None => {
            VertexArray::draw_arrays(PrimitiveKind::Triangles, 0, gl_objects.element_cnt);
        }
    }
}

/// Creates a GL texture holding the pixels of `texture`, configured with the
/// texture's properties.
fn create_gl_texture(texture: &Texture) -> GlTexture
{
    let mut gl_tex = GlTexture::new();

    let dimensions = *texture.dimensions();
    let image = texture.image();

    gl_tex.generate(dimensions, image.as_bytes(), texture.pixel_data_format());
    gl_tex.apply_properties(texture.properties());

    gl_tex
}

/// Vertex stage source of the default shader program (embedded at compile time).
const VERTEX_GLSL_SHADER_SRC: &str = include_str!("opengl/glsl/vertex.glsl");
/// Fragment stage source of the default shader program (embedded at compile time).
const FRAGMENT_GLSL_SHADER_SRC: &str = include_str!("opengl/glsl/fragment.glsl");

/// Supplied to the GLSL preprocessor as `vertex_data.glsl` by
/// `get_glsl_shader_content`.
const VERTEX_DATA_GLSL_SHADER_SRC: &str = include_str!("opengl/glsl/vertex_data.glsl");
/// Supplied to the GLSL preprocessor as `light.glsl` by `get_glsl_shader_content`.
const LIGHT_GLSL_SHADER_SRC: &str = include_str!("opengl/glsl/light.glsl");

fn create_default_shader_program() -> Result<GlShaderProgram, CreateShaderError>
{
    let mut vertex_shader = GlShader::new(ShaderKind::Vertex);

    vertex_shader.set_source(&*glsl_preprocess(
        VERTEX_GLSL_SHADER_SRC,
        &get_glsl_shader_content,
    )?)?;

    vertex_shader.compile()?;

    let mut fragment_shader = GlShader::new(ShaderKind::Fragment);

    fragment_shader.set_source(&*glsl_preprocess(
        FRAGMENT_GLSL_SHADER_SRC,
        &get_glsl_shader_content,
    )?)?;

    fragment_shader.compile()?;

    let mut gl_shader_program = GlShaderProgram::new();

    gl_shader_program.attach(&vertex_shader);
    gl_shader_program.attach(&fragment_shader);

    gl_shader_program.link()?;

    Ok(gl_shader_program)
}

/// Reasons building the default shader program can fail.
#[derive(Debug, thiserror::Error)]
enum CreateShaderError
{
    /// Setting a shader's source, compiling it or linking the program failed.
    #[error(transparent)]
    ShaderError(#[from] GlShaderError),

    /// The GLSL preprocessor failed on one of the shader sources.
    #[error(transparent)]
    PreprocessingError(#[from] GlslPreprocessingError),
}

/// Resolves a file requested by the GLSL preprocessor to its embedded source
/// bytes.
///
/// Paths are compared with [`Path`] equality, not string equality.
///
/// # Errors
/// Returns an error of kind [`IoErrorKind::NotFound`] for any path other than
/// `vertex_data.glsl` and `light.glsl`.
fn get_glsl_shader_content(path: &Path) -> Result<Vec<u8>, std::io::Error>
{
    let embedded_files = [
        ("vertex_data.glsl", VERTEX_DATA_GLSL_SHADER_SRC),
        ("light.glsl", LIGHT_GLSL_SHADER_SRC),
    ];

    embedded_files
        .iter()
        .find(|(file_name, _)| path == Path::new(*file_name))
        .map(|(_, source)| source.as_bytes().to_vec())
        .ok_or_else(|| {
            IoError::new(
                IoErrorKind::NotFound,
                format!("Content for shader file {} not found", path.display()),
            )
        })
}

/// Per-mesh OpenGL objects, created the first time an entity's mesh is drawn.
///
/// Field order is significant. Rust drops struct fields in declaration order,
/// and the buffers must live at least as long as the vertex array that
/// references them. `vertex_arr` is therefore declared first so it is dropped
/// before the buffers.
#[derive(Debug, Component)]
struct GlObjects
{
    vertex_arr: VertexArray,

    /// Never read directly. It is kept so the vertex buffer lives as long as
    /// `vertex_arr`.
    _vertex_buffer: Buffer<Vertex>,

    /// Element buffer bound to `vertex_arr` when the mesh has indices. It must
    /// also outlive the vertex array.
    index_buffer: Option<Buffer<u32>>,

    /// Number of indices (indexed meshes) or vertices (non-indexed meshes) to
    /// draw.
    element_cnt: u32,
}

impl GlObjects
{
    #[tracing::instrument(skip_all)]
    fn new(mesh: &Mesh) -> Self
    {
        tracing::trace!(
            "Creating vertex array, vertex buffer{}",
            if mesh.indices().is_some() {
                " and index buffer"
            } else {
                ""
            }
        );

        let mut vertex_arr = VertexArray::new();
        let mut vertex_buffer = Buffer::new();

        vertex_buffer.store(mesh.vertices(), BufferUsage::Static);

        vertex_arr.bind_vertex_buffer(0, &vertex_buffer, 0);

        let mut offset = 0u32;

        for attrib in Vertex::attrs() {
            vertex_arr.enable_attrib(attrib.index);

            vertex_arr.set_attrib_format(
                attrib.index,
                match attrib.component_type {
                    AttributeComponentType::Float => VertexArrayDataType::Float,
                },
                false,
                offset,
            );

            vertex_arr.set_attrib_vertex_buf_binding(attrib.index, 0);

            offset += attrib.component_size * attrib.component_cnt as u32;
        }

        if let Some(indices) = mesh.indices() {
            let mut index_buffer = Buffer::new();

            index_buffer.store(indices, BufferUsage::Static);

            vertex_arr.bind_element_buffer(&index_buffer);

            return Self {
                _vertex_buffer: vertex_buffer,
                index_buffer: Some(index_buffer),
                element_cnt: indices
                    .len()
                    .try_into()
                    .expect("Mesh index count does not fit into a 32-bit unsigned int"),
                vertex_arr,
            };
        }

        Self {
            _vertex_buffer: vertex_buffer,
            index_buffer: None,
            element_cnt: mesh
                .vertices()
                .len()
                .try_into()
                .expect("Mesh vertex count does not fit into a 32-bit unsigned int"),
            vertex_arr,
        }
    }
}

/// Uploads the `model`, `view` and `projection` matrix uniforms to
/// `gl_shader_program`.
///
/// A perspective projection uses the aspect ratio of `window_size`. An
/// orthographic projection is built from the camera position.
fn apply_transformation_matrices(
    transformation: Transformation,
    gl_shader_program: &mut GlShaderProgram,
    camera: &Camera,
    camera_pos: &Position,
    window_size: Dimens<u32>,
)
{
    let model_matrix = create_transformation_matrix(transformation);
    let view_matrix = create_view_matrix(camera, &camera_pos.position);

    // Window dimensions are assumed small enough (< 2^24) to convert to f32 exactly
    #[allow(clippy::cast_precision_loss)]
    let projection_matrix = match &camera.projection {
        Projection::Perspective(perspective) => {
            let aspect_ratio = window_size.width as f32 / window_size.height as f32;

            perspective.to_matrix_rh(aspect_ratio, ClipVolume::NegOneToOne)
        }
        Projection::Orthographic(orthographic) => {
            orthographic.to_matrix_rh(&camera_pos.position, ClipVolume::NegOneToOne)
        }
    };

    gl_shader_program.set_uniform_matrix_4fv(c"model", &model_matrix);
    gl_shader_program.set_uniform_matrix_4fv(c"view", &view_matrix);
    gl_shader_program.set_uniform_matrix_4fv(c"projection", &projection_matrix);
}

fn apply_light<PointLightHolder>(
    material: &Material,
    material_flags: &MaterialFlags,
    global_light: &GlobalLight,
    gl_shader_program: &mut GlShaderProgram,
    point_lights: &[PointLightHolder],
    directional_lights: &[&DirectionalLight],
    camera_pos: &Position,
) where
    PointLightHolder: Deref<Target = PointLight>,
{
    debug_assert!(
        point_lights.len() < 64,
        "Shader cannot handle more than 64 point lights"
    );

    debug_assert!(
        directional_lights.len() < 64,
        "Shader cannot handle more than 64 directional lights"
    );

    for (dir_light_index, dir_light) in directional_lights.iter().enumerate() {
        gl_shader_program.set_uniform_vec_3fv(
            &create_light_uniform_name(
                "directional_lights",
                dir_light_index,
                "direction",
            ),
            &dir_light.direction,
        );

        set_light_phong_uniforms(
            gl_shader_program,
            "directional_lights",
            dir_light_index,
            *dir_light,
        );
    }

    // There probably won't be more than 2147483648 directional lights
    #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
    gl_shader_program
        .set_uniform_1i(c"directional_light_cnt", directional_lights.len() as i32);

    for (point_light_index, point_light) in point_lights.iter().enumerate() {
        gl_shader_program.set_uniform_vec_3fv(
            &create_light_uniform_name("point_lights", point_light_index, "position"),
            &point_light.position,
        );

        set_light_phong_uniforms(
            gl_shader_program,
            "point_lights",
            point_light_index,
            &**point_light,
        );

        set_light_attenuation_uniforms(
            gl_shader_program,
            "point_lights",
            point_light_index,
            point_light,
        );
    }

    // There probably won't be more than 2147483648 point lights
    #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
    gl_shader_program.set_uniform_1i(c"point_light_cnt", point_lights.len() as i32);

    gl_shader_program.set_uniform_vec_3fv(
        c"material.ambient",
        &if material_flags.use_ambient_color {
            material.ambient.clone()
        } else {
            global_light.ambient.clone()
        }
        .into(),
    );

    gl_shader_program
        .set_uniform_vec_3fv(c"material.diffuse", &material.diffuse.clone().into());

    #[allow(clippy::cast_possible_wrap)]
    gl_shader_program
        .set_uniform_vec_3fv(c"material.specular", &material.specular.clone().into());

    let texture_map = material
        .textures
        .iter()
        .enumerate()
        .map(|(index, texture)| (texture.id(), index))
        .collect::<HashMap<_, _>>();

    #[allow(clippy::cast_possible_wrap)]
    gl_shader_program.set_uniform_1i(
        c"material.ambient_map",
        *texture_map.get(&material.ambient_map).unwrap() as i32,
    );

    #[allow(clippy::cast_possible_wrap)]
    gl_shader_program.set_uniform_1i(
        c"material.diffuse_map",
        *texture_map.get(&material.diffuse_map).unwrap() as i32,
    );

    #[allow(clippy::cast_possible_wrap)]
    gl_shader_program.set_uniform_1i(
        c"material.specular_map",
        *texture_map.get(&material.specular_map).unwrap() as i32,
    );

    gl_shader_program.set_uniform_1fv(c"material.shininess", material.shininess);

    gl_shader_program.set_uniform_vec_3fv(c"view_pos", &camera_pos.position);
}

/// Uploads the constant, linear and quadratic attenuation factors of `light`.
///
/// They are written to `<light_array>[<light_index>].attenuation_props.*`, in
/// that order.
fn set_light_attenuation_uniforms(
    gl_shader_program: &mut GlShaderProgram,
    light_array: &str,
    light_index: usize,
    light: &PointLight,
)
{
    let params = &light.attenuation_params;

    for (field, value) in [
        ("attenuation_props.constant", params.constant),
        ("attenuation_props.linear", params.linear),
        ("attenuation_props.quadratic", params.quadratic),
    ] {
        gl_shader_program.set_uniform_1fv(
            &create_light_uniform_name(light_array, light_index, field),
            value,
        );
    }
}

/// Uploads the diffuse and specular colors of `light`.
///
/// They are written to `<light_array>[<light_index>].phong.diffuse` and
/// `.phong.specular`, in that order.
fn set_light_phong_uniforms(
    gl_shader_program: &mut GlShaderProgram,
    light_array: &str,
    light_index: usize,
    light: &impl Light,
)
{
    let colors = [
        ("phong.diffuse", light.diffuse()),
        ("phong.specular", light.specular()),
    ];

    for (field, color) in colors {
        gl_shader_program.set_uniform_vec_3fv(
            &create_light_uniform_name(light_array, light_index, field),
            &color.clone().into(),
        );
    }
}

trait Light
{
    fn diffuse(&self) -> &Color<f32>;
    fn specular(&self) -> &Color<f32>;
}

impl Light for PointLight
{
    fn diffuse(&self) -> &Color<f32>
    {
        &self.diffuse
    }

    fn specular(&self) -> &Color<f32>
    {
        &self.specular
    }
}

impl Light for DirectionalLight
{
    fn diffuse(&self) -> &Color<f32>
    {
        &self.diffuse
    }

    fn specular(&self) -> &Color<f32>
    {
        &self.specular
    }
}

/// Builds the name of a field of an element in a light uniform array, such as
/// `point_lights[3].position`.
///
/// # Panics
/// Panics if `light_array` or `light_field` contains an interior NUL byte,
/// because such a name cannot be represented as a C string.
fn create_light_uniform_name(
    light_array: &str,
    light_index: usize,
    light_field: &str,
) -> CString
{
    // The checked constructor rejects interior NULs. `from_vec_with_nul_unchecked`
    // would silently build an invalid CString from caller-supplied text.
    CString::new(format!("{light_array}[{light_index}].{light_field}"))
        .expect("Light uniform name must not contain interior NUL bytes")
}

/// Builds a view matrix that looks from `camera_pos` at `camera.target`, with
/// `camera.global_up` as the up direction.
fn create_view_matrix(camera: &Camera, camera_pos: &Vec3<f32>) -> Matrix<f32, 4, 4>
{
    let mut view_matrix = Matrix::new();

    view_matrix.look_at(camera_pos, &camera.target, &camera.global_up);

    view_matrix
}

/// OpenGL debug-output callback that forwards messages to `tracing`.
///
/// Notifications are dropped. Errors are logged at ERROR, followed by a TRACE
/// backtrace when one could be captured. Messages of type `Other` are logged at
/// INFO, and every other type at WARN.
#[tracing::instrument(skip_all)]
fn opengl_debug_message_cb(
    source: MessageSource,
    ty: MessageType,
    id: u32,
    severity: MessageSeverity,
    message: &str,
)
{
    use std::backtrace::{Backtrace, BacktraceStatus};

    use tracing::{event, Level};

    // Emits an event at the given level carrying every field of the message
    macro_rules! create_event {
        ($level: expr) => {
            event!($level, ?source, ?ty, id, ?severity, message);
        };
    }

    if let MessageSeverity::Notification = severity {
        return;
    }

    match ty {
        MessageType::Error => {
            create_event!(Level::ERROR);

            let backtrace = Backtrace::capture();

            if let BacktraceStatus::Captured = backtrace.status() {
                event!(Level::TRACE, "{backtrace}");
            }
        }
        MessageType::Other => {
            create_event!(Level::INFO);
        }
        _ => {
            create_event!(Level::WARN);
        }
    }
}

/// Translation and scale of a renderable entity.
///
/// [`create_transformation_matrix`] turns it into the model matrix.
#[derive(Debug)]
struct Transformation
{
    /// Translation applied to the entity.
    position: Vec3<f32>,

    /// Per-axis scale applied to the entity.
    scale: Vec3<f32>,
}

fn create_transformation_matrix(transformation: Transformation) -> Matrix<f32, 4, 4>
{
    let mut matrix = Matrix::new_identity();

    matrix.translate(&transformation.position);
    matrix.scale(&transformation.scale);

    matrix
}