use std::borrow::Cow;
use std::fmt::{Display, Formatter};
use std::path::{Path, PathBuf};
use std::string::FromUtf8Error;

/// Directive that is replaced by the contents of the file whose path follows it
/// after a single space, in double quotes, e.g. `#preinclude "common.glsl"`.
const PREINCLUDE_DIRECTIVE: &str = "#preinclude";

/// Expands every `#preinclude "<path>"` directive in a GLSL shader source.
///
/// Each directive is replaced by the contents of the named file, obtained by
/// calling `read_file` with the path exactly as written between the quotes.
/// Included files may contain directives of their own, which are expanded too.
/// A leading `#version` line of an included file is dropped so it does not
/// clash with the version declared by the including shader.
///
/// Content without any directive is returned as-is (still borrowed if it was
/// passed in borrowed).
///
/// # Errors
///
/// Returns a [`PreprocessingError`] when a directive is malformed, when
/// `read_file` fails, or when an included file is not valid UTF-8.
pub fn preprocess<'content>(
    shader_content: impl Into<Cow<'content, str>>,
    read_file: &impl Fn(&Path) -> Result<Vec<u8>, std::io::Error>,
) -> Result<Cow<'content, str>, PreprocessingError>
{
    // Errors found directly in the caller's content refer to the "original file".
    let top_level_path = SpanPath::Original;

    do_preprocess(shader_content, top_level_path, read_file)
}

/// Expands all `#preinclude` directives in `shader_content`, recursing into
/// included files.
///
/// * `shader_path` – identifies `shader_content` in error spans
///   ([`SpanPath::Original`] for the caller's content, the directive's path
///   for included files).
/// * `read_file` – loads an included file given the path written in the
///   directive.
///
/// The content is scanned once from left to right. Text between directives is
/// copied verbatim, and each directive is replaced by the already-expanded
/// included text, which is not scanned again. Because parsing always happens
/// on the unmodified `shader_content`, error spans refer to the line numbers of
/// that content, and no offset for previously inserted lines is needed.
///
/// NOTE(review): includes are not checked for cycles; a file that (directly or
/// indirectly) includes itself recurses without bound.
///
/// # Errors
///
/// Propagates errors from parsing a directive, from `read_file`, and from
/// UTF-8 decoding of included files (including those of nested includes).
fn do_preprocess<'content>(
    shader_content: impl Into<Cow<'content, str>>,
    shader_path: SpanPath<'_>,
    read_file: &impl Fn(&Path) -> Result<Vec<u8>, std::io::Error>,
) -> Result<Cow<'content, str>, PreprocessingError>
{
    let shader_content = shader_content.into();

    if !shader_content.contains(PREINCLUDE_DIRECTIVE) {
        // Nothing to expand: hand the input back without copying it.
        return Ok(shader_content);
    }

    let mut preprocessed = String::with_capacity(shader_content.len());

    // Byte offset into `shader_content` up to which the input has been
    // consumed (copied verbatim or replaced by an include).
    let mut copied_until = 0;

    while let Some(relative_start) = shader_content[copied_until..].find(PREINCLUDE_DIRECTIVE)
    {
        let preinclude_start = copied_until + relative_start;

        // The directive is parsed in the unmodified input, so its span needs
        // no line offset.
        let replacement_job =
            handle_preinclude(&shader_content, &shader_path, preinclude_start, 0)?;

        let included = read_included_shader(&replacement_job.path, read_file)?;

        let included_preprocessed = do_preprocess(
            &included,
            SpanPath::Path(replacement_job.path.as_path().into()),
            read_file,
        )?;

        preprocessed.push_str(&shader_content[copied_until..replacement_job.start_index]);
        preprocessed.push_str(&included_preprocessed);

        copied_until = replacement_job.end_index;
    }

    // Everything after the last directive.
    preprocessed.push_str(&shader_content[copied_until..]);

    Ok(Cow::Owned(preprocessed))
}

/// Loads the included file at `path` through `read_file`, decodes it as UTF-8
/// and removes a leading `#version` line.
///
/// Only the text before the first line break is removed. The line break itself
/// is kept, so the included text keeps its line structure.
///
/// # Errors
///
/// * [`PreprocessingError::ReadIncludedShaderFailed`] if `read_file` fails.
/// * [`PreprocessingError::IncludedShaderInvalidUtf8`] if the bytes are not
///   valid UTF-8.
fn read_included_shader(
    path: &Path,
    read_file: &impl Fn(&Path) -> Result<Vec<u8>, std::io::Error>,
) -> Result<String, PreprocessingError>
{
    let bytes = read_file(path).map_err(|err| {
        PreprocessingError::ReadIncludedShaderFailed {
            source: err,
            path: path.to_path_buf(),
        }
    })?;

    let mut included = String::from_utf8(bytes).map_err(|err| {
        PreprocessingError::IncludedShaderInvalidUtf8 {
            source: err,
            path: path.to_path_buf(),
        }
    })?;

    if included.starts_with("#version") {
        let version_line_end = included.find('\n').unwrap_or(included.len());
        included.drain(..version_line_end);
    }

    Ok(included)
}

/// Parses the `#preinclude "<path>"` directive that starts at byte offset
/// `preinclude_start_index` of `shader_content`.
///
/// The directive keyword must be followed by exactly one space, an opening
/// double quote, a non-empty path and a closing double quote. The returned job
/// covers the byte range from the `#` up to and including the closing quote.
///
/// All indices are byte offsets, as produced by `str::find`. Error spans report
/// columns relative to the start of the line containing the directive.
/// `span_line_offset` is forwarded unchanged to [`Span::new`].
///
/// # Errors
///
/// * [`PreprocessingError::ExpectedToken`] – the content ends before the
///   space, the opening quote or the closing quote, or the path is empty.
/// * [`PreprocessingError::InvalidToken`] – a character other than the
///   required space / opening quote follows the directive keyword.
fn handle_preinclude(
    shader_content: &str,
    shader_path: &SpanPath<'_>,
    preinclude_start_index: usize,
    span_line_offset: usize,
) -> Result<ReplacementJob, PreprocessingError>
{
    // Measure columns from the start of the directive's line, so an indented
    // directive still gets the correct column.
    let line_start_index = shader_content
        .get(..preinclude_start_index)
        .and_then(|before| before.rfind('\n'))
        .map_or(0, |newline_index| newline_index + 1);

    let span_at = |index: usize| {
        Span::new(
            shader_content,
            shader_path.to_owned(),
            index,
            span_line_offset,
            line_start_index,
        )
    };

    // Look characters up by byte offset. `chars().nth()` counts characters and
    // would read the wrong one after any multi-byte character.
    let expect_token = |token: char, index: usize| {
        let token_found = shader_content
            .get(index..)
            .and_then(|rest| rest.chars().next())
            .ok_or_else(|| {
                PreprocessingError::ExpectedToken {
                    expected: token,
                    span: span_at(index),
                }
            })?;

        if token_found != token {
            return Err(PreprocessingError::InvalidToken {
                expected: token,
                found: token_found,
                span: span_at(index),
            });
        }

        Ok(())
    };

    let space_index = preinclude_start_index + PREINCLUDE_DIRECTIVE.len();
    let quote_open_index = space_index + 1;

    expect_token(' ', space_index)?;
    expect_token('"', quote_open_index)?;

    // The opening quote is a single byte, so this is a valid char boundary.
    let path_start_index = quote_open_index + 1;

    // Take the path as a substring so non-ASCII characters survive intact.
    let Some(path_len) = shader_content[path_start_index..].find('"') else {
        // No closing quote before the end of the content.
        return Err(PreprocessingError::ExpectedToken {
            expected: '"',
            span: span_at(shader_content.len()),
        });
    };

    if path_len == 0 {
        // Empty path (`""`). There is no dedicated error variant, so it is
        // reported as a missing quote at the position of the closing quote.
        return Err(PreprocessingError::ExpectedToken {
            expected: '"',
            span: span_at(path_start_index),
        });
    }

    let path_end_index = path_start_index + path_len;
    let path = PathBuf::from(&shader_content[path_start_index..path_end_index]);

    Ok(ReplacementJob {
        start_index: preinclude_start_index,
        // Skip past the closing quote.
        end_index: path_end_index + 1,
        path,
    })
}

/// A parsed `#preinclude` directive: the span of content it occupies and the
/// file that should replace it.
struct ReplacementJob
{
    /// Path written between the directive's quotes, used verbatim.
    path: PathBuf,
    /// Offset of the `#` that starts the directive (inclusive).
    start_index: usize,
    /// Offset just past the directive's closing quote (exclusive).
    end_index: usize,
}

#[derive(Debug, thiserror::Error)]
pub enum PreprocessingError
{
    #[error(
        "Invalid token at line {}, column {} of {}. Expected '{}', found '{}'",
        span.line,
        span.column,
        span.path,
        expected,
        found
    )]
    InvalidToken
    {
        expected: char,
        found: char,
        span: Span,
    },

    #[error(
        "Expected token '{}' at line {}, column {} of {}. Found eof",
        expected,
        span.line,
        span.column,
        span.path
    )]
    ExpectedToken
    {
        expected: char, span: Span
    },

    #[error(
        "Preinclude path at line {}, column {} of {} is invalid UTF-8",
        span.line,
        span.column,
        span.path
    )]
    PreincludePathInvalidUtf8
    {
        #[source]
        source: FromUtf8Error,
        span: Span,
    },

    #[error("Failed to read included shader")]
    ReadIncludedShaderFailed
    {
        #[source]
        source: std::io::Error,
        path: PathBuf,
    },

    #[error("Included shader is not valid UTF-8")]
    IncludedShaderInvalidUtf8
    {
        #[source]
        source: FromUtf8Error,
        path: PathBuf,
    },
}

/// Location a [`PreprocessingError`] refers to.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct Span
{
    /// 1-based line number.
    pub line: usize,
    /// 1-based column number.
    pub column: usize,
    /// Which source the location is in.
    pub path: SpanPath<'static>,
}

impl Span
{
    /// Builds the span for byte offset `char_index` within `file_content`.
    ///
    /// * `line_offset` – when greater than one, `line_offset - 1` is subtracted
    ///   from the line number computed from `file_content`. `0` and `1` leave it
    ///   untouched. The value is supplied by the caller.
    /// * `line_start_index` – byte offset that column `1` corresponds to.
    ///
    /// `char_index` may point at or past the end of `file_content`, as for
    /// "end of file" errors. Both `line` and `column` are 1-based.
    fn new(
        file_content: &str,
        path: SpanPath<'static>,
        char_index: usize,
        line_offset: usize,
        line_start_index: usize,
    ) -> Self
    {
        let unadjusted_line = find_line_of_index(file_content, char_index) + 1;

        // An offset larger than the line number must not underflow (a panic in
        // debug builds); clamp to the first line instead.
        let line = unadjusted_line
            .saturating_sub(line_offset.saturating_sub(1))
            .max(1);

        // Count characters, not bytes, so multi-byte UTF-8 text earlier on the
        // line does not inflate the column. Fall back to the byte distance when
        // the range is not a valid slice (e.g. `char_index` past the end).
        let column = file_content
            .get(line_start_index..char_index)
            .map_or_else(
                || char_index.saturating_sub(line_start_index),
                |prefix| prefix.chars().count(),
            )
            + 1;

        Self {
            line,
            column,
            path,
        }
    }
}

/// Identifies which source a [`Span`] points into.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SpanPath<'a>
{
    /// The shader content passed directly to `preprocess`.
    Original,
    /// An included file, identified by the path written in its `#preinclude`
    /// directive.
    Path(Cow<'a, Path>),
}

impl<'a> SpanPath<'a>
{
    /// Returns a copy that owns its path, so it no longer borrows for `'a`.
    fn to_owned(&self) -> SpanPath<'static>
    {
        match self {
            SpanPath::Path(path) => SpanPath::Path(Cow::Owned(path.clone().into_owned())),
            SpanPath::Original => SpanPath::Original,
        }
    }
}

impl<'a> Display for SpanPath<'a>
{
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result
    {
        match self {
            Self::Original => write!(formatter, "original file"),
            Self::Path(path) => write!(formatter, "file {}", path.display()),
        }
    }
}

impl<'a, PathLike> PartialEq<PathLike> for SpanPath<'a>
where
    PathLike: AsRef<Path>,
{
    /// A `SpanPath` equals a path-like value only if it is the `Path` variant
    /// holding an equal path; `Original` never equals any path.
    fn eq(&self, other: &PathLike) -> bool
    {
        matches!(self, Self::Path(path) if path == other.as_ref())
    }
}

/// Returns the zero-based line number of byte offset `index` in `text`.
///
/// This is the number of `'\n'` characters strictly before `index`. A newline
/// at `index` itself belongs to the line it terminates. Offsets past the end of
/// `text` report the last line.
///
/// Counting raw bytes is correct for any UTF-8 text: byte `0x0A` never occurs
/// inside a multi-byte sequence. It also matches the byte offsets produced by
/// `str::find`.
fn find_line_of_index(text: &str, index: usize) -> usize
{
    text.bytes()
        .take(index)
        .filter(|&byte| byte == b'\n')
        .count()
}

// Unit tests for `preprocess`. File reads are stubbed with closures, so no
// filesystem access happens.
#[cfg(test)]
mod tests
{
    use std::ffi::OsStr;
    use std::path::Path;

    use super::{preprocess, PreprocessingError};
    use crate::opengl::glsl::SpanPath;

    // Content without any `#preinclude` is returned unchanged; the reader must
    // never be called.
    #[test]
    fn preprocess_no_directives_is_same()
    {
        assert_eq!(
            preprocess("#version 330 core\n", &|_| { unreachable!() }).unwrap(),
            "#version 330 core\n"
        );
    }

    // A directive is replaced by the included file's content; everything else,
    // including the surrounding blank lines, is kept as-is.
    #[test]
    fn preprocess_with_directives_works()
    {
        assert_eq!(
            preprocess(
                concat!(
                    "#version 330 core\n",
                    "\n",
                    "#preinclude \"foo.glsl\"\n",
                    "\n",
                    "void main() {}",
                ),
                &|_| { Ok(b"out vec4 FragColor;".to_vec()) }
            )
            .unwrap(),
            concat!(
                "#version 330 core\n",
                "\n",
                "out vec4 FragColor;\n",
                "\n",
                "void main() {}",
            )
        );

        assert_eq!(
            preprocess(
                concat!(
                    "#version 330 core\n",
                    "\n",
                    "#preinclude \"bar.glsl\"\n",
                    "\n",
                    "in vec3 in_frag_color;\n",
                    "\n",
                    "void main() {}",
                ),
                &|_| { Ok(b"out vec4 FragColor;".to_vec()) }
            )
            .unwrap(),
            concat!(
                "#version 330 core\n",
                "\n",
                "out vec4 FragColor;\n",
                "\n",
                "in vec3 in_frag_color;\n",
                "\n",
                "void main() {}",
            )
        );

        // Two directives in one shader: each is replaced by its own file.
        assert_eq!(
            preprocess(
                concat!(
                    "#version 330 core\n",
                    "\n",
                    "#preinclude \"bar.glsl\"\n",
                    "\n",
                    "in vec3 in_frag_color;\n",
                    "\n",
                    "#preinclude \"foo.glsl\"\n",
                    "\n",
                    "void main() {}",
                ),
                &|path| {
                    if path == OsStr::new("bar.glsl") {
                        Ok(b"out vec4 FragColor;".to_vec())
                    } else {
                        Ok(concat!(
                            "uniform sampler2D input_texture;\n",
                            "in vec2 in_texture_coords;"
                        )
                        .as_bytes()
                        .to_vec())
                    }
                },
            )
            .unwrap(),
            concat!(
                "#version 330 core\n",
                "\n",
                "out vec4 FragColor;\n",
                "\n",
                "in vec3 in_frag_color;\n",
                "\n",
                "uniform sampler2D input_texture;\n",
                "in vec2 in_texture_coords;\n",
                "\n",
                "void main() {}",
            )
        );
    }

    // The opening quote is missing, so the character after the space ('f') is
    // reported. It is on line 3, column 13: the 12 characters of
    // `#preinclude ` precede it.
    #[test]
    fn preprocess_invalid_directive_does_not_work()
    {
        let res = preprocess(
            concat!(
                "#version 330 core\n",
                "\n",
                // Missing "
                "#preinclude foo.glsl\"\n",
                "\n",
                "void main() {}",
            ),
            &|_| Ok(b"out vec4 FragColor;".to_vec()),
        );

        let Err(PreprocessingError::InvalidToken { expected, found, span }) = res else {
            panic!(
                "Expected result to be Err(Error::InvalidToken {{ ... }}), is {res:?}"
            );
        };

        assert_eq!(expected, '"');
        assert_eq!(found, 'f');
        assert_eq!(span.line, 3);
        assert_eq!(span.column, 13);
        assert_eq!(span.path, SpanPath::Original);
    }

    // The malformed directive is inside foo.glsl, which is included after
    // bar.glsl. The span must refer to foo.glsl and to its own line number
    // (line 3), not to a line of the including shader. Column 12 is the
    // character right after `#preinclude`.
    #[test]
    fn preprocess_error_has_correct_span()
    {
        let res = preprocess(
            concat!(
                "#version 330 core\n",
                "\n",
                "#preinclude \"bar.glsl\"\n",
                "\n",
                "#preinclude \"foo.glsl\"\n",
                "\n",
                "in vec3 in_frag_color;\n",
                "\n",
                "void main() {}",
            ),
            &|path| {
                if path == OsStr::new("bar.glsl") {
                    Ok(concat!(
                        "out vec4 FragColor;\n",
                        "in vec2 in_texture_coords;\n",
                        "in float foo;"
                    )
                    .as_bytes()
                    .to_vec())
                } else if path == OsStr::new("foo.glsl") {
                    Ok(concat!(
                        "uniform sampler2D input_texture;\n",
                        "\n",
                        // Missing space before first "
                        "#preinclude\"shared_types.glsl\"\n",
                    )
                    .as_bytes()
                    .to_vec())
                } else {
                    panic!(concat!(
                        "Expected read function to be called with ",
                        "either path bar.glsl or foo.glsl"
                    ));
                }
            },
        );

        let Err(PreprocessingError::InvalidToken { expected, found, span }) = res else {
            panic!(
                "Expected result to be Err(Error::InvalidToken {{ ... }}), is {res:?}"
            );
        };

        assert_eq!(expected, ' ');
        assert_eq!(found, '"');
        assert_eq!(span.line, 3);
        assert_eq!(span.column, 12);
        assert_eq!(span.path, SpanPath::Path(Path::new("foo.glsl").into()));
    }

    // Directives inside included files are expanded recursively: bar.glsl pulls
    // in foo.glsl before its own content.
    #[test]
    fn preprocess_included_shader_with_include_works()
    {
        assert_eq!(
            preprocess(
                concat!(
                    "#version 330 core\n",
                    "\n",
                    "#preinclude \"bar.glsl\"\n",
                    "\n",
                    "in vec3 in_frag_color;\n",
                    "\n",
                    "void main() {}",
                ),
                &|path| {
                    if path == OsStr::new("bar.glsl") {
                        Ok(concat!(
                            "#preinclude \"foo.glsl\"\n",
                            "\n",
                            "out vec4 FragColor;"
                        )
                        .as_bytes()
                        .to_vec())
                    } else {
                        Ok(concat!(
                            "uniform sampler2D input_texture;\n",
                            "in vec2 in_texture_coords;"
                        )
                        .as_bytes()
                        .to_vec())
                    }
                }
            )
            .unwrap(),
            concat!(
                "#version 330 core\n",
                "\n",
                "uniform sampler2D input_texture;\n",
                "in vec2 in_texture_coords;\n",
                "\n",
                "out vec4 FragColor;\n",
                "\n",
                "in vec3 in_frag_color;\n",
                "\n",
                "void main() {}",
            )
        );
    }

    // An error in a directive inside an included file is reported against that
    // file (bar.glsl), at line 1, column 13: the wrong quote character.
    #[test]
    fn preprocess_included_shader_with_include_error_span_is_correct()
    {
        let res = preprocess(
            concat!(
                "#version 330 core\n",
                "\n",
                "#preinclude \"bar.glsl\"\n",
                "\n",
                "in vec3 in_frag_color;\n",
                "\n",
                "void main() {}",
            ),
            &|path| {
                if path == OsStr::new("bar.glsl") {
                    Ok(concat!(
                        // ' instead of "
                        "#preinclude 'foo.glsl\"\n",
                        "\n",
                        "out vec4 FragColor;"
                    )
                    .as_bytes()
                    .to_vec())
                } else {
                    Ok(concat!(
                        "uniform sampler2D input_texture;\n",
                        "in vec2 in_texture_coords;"
                    )
                    .as_bytes()
                    .to_vec())
                }
            },
        );

        let Err(PreprocessingError::InvalidToken { expected, found, span }) = res else {
            panic!(
                "Expected result to be Err(Error::InvalidToken {{ ... }}), is {res:?}"
            );
        };

        assert_eq!(expected, '"');
        assert_eq!(found, '\'');
        assert_eq!(span.line, 1);
        assert_eq!(span.column, 13);
        assert_eq!(span.path, Path::new("bar.glsl"));
    }
}