diff options
Diffstat (limited to 'engine/src/shader.rs')
| -rw-r--r-- | engine/src/shader.rs | 818 |
1 files changed, 818 insertions, 0 deletions
diff --git a/engine/src/shader.rs b/engine/src/shader.rs new file mode 100644 index 0000000..65872f1 --- /dev/null +++ b/engine/src/shader.rs @@ -0,0 +1,818 @@ +use std::any::type_name; +use std::borrow::Cow; +use std::collections::HashMap; +use std::fmt::Debug; +use std::path::Path; +use std::str::Utf8Error; + +use bitflags::bitflags; +use ecs::pair::{ChildOf, Pair}; +use ecs::phase::{Phase, START as START_PHASE}; +use ecs::sole::Single; +use ecs::{Component, Sole, declare_entity}; +use shader_slang::{ + Blob as SlangBlob, + ComponentType as SlangComponentType, + DebugInfoLevel as SlangDebugInfoLevel, + EntryPoint as SlangEntryPoint, + GlobalSession as SlangGlobalSession, + Module as SlangModule, + ParameterCategory as SlangParameterCategory, + Session as SlangSession, + TypeKind as SlangTypeKind, +}; + +use crate::asset::{ + Assets, + Event as AssetEvent, + HANDLE_ASSETS_PHASE, + Handle as AssetHandle, + Id as AssetId, + Submitter as AssetSubmitter, +}; +use crate::builder; +use crate::renderer::PRE_RENDER_PHASE; +use crate::shader::default::{ + ASSET_LABEL, + enqueue_set_shader_bindings as default_shader_enqueue_set_shader_bindings, +}; + +pub mod cursor; +pub mod default; + +#[derive(Debug, Clone, Component)] +pub struct Shader +{ + pub asset_handle: AssetHandle<ModuleSource>, +} + +/// Shader module. +#[derive(Debug)] +pub struct ModuleSource +{ + pub name: Cow<'static, str>, + pub file_path: Cow<'static, Path>, + pub source: Cow<'static, str>, + pub link_entrypoints: EntrypointFlags, +} + +bitflags! { + #[derive(Debug, Clone, Copy, Default)] + pub struct EntrypointFlags: usize + { + const FRAGMENT = 1 << 0; + const VERTEX = 1 << 1; + } +} + +pub struct Module +{ + inner: SlangModule, +} + +impl Module +{ + pub fn entry_points(&self) -> impl ExactSizeIterator<Item = EntryPoint> + { + self.inner + .entry_points() + .map(|entry_point| EntryPoint { inner: entry_point }) + } + + pub fn get_entry_point(&self, entry_point: &str) -> Option<EntryPoint> + { + let entry_point = self.inner.find_entry_point_by_name(entry_point)?; + + Some(EntryPoint { inner: entry_point }) + } +} + +pub struct EntryPoint +{ + inner: SlangEntryPoint, +} + +impl EntryPoint +{ + pub fn function_name(&self) -> Option<&str> + { + self.inner.function_reflection().name() + } +} + +pub struct EntryPointReflection<'a> +{ + inner: &'a shader_slang::reflection::EntryPoint, +} + +impl<'a> EntryPointReflection<'a> +{ + pub fn name(&self) -> Option<&str> + { + self.inner.name() + } + + pub fn name_override(&self) -> Option<&str> + { + self.inner.name_override() + } + + pub fn stage(&self) -> Stage + { + Stage::from_slang_stage(self.inner.stage()) + } + + pub fn parameters(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>> + { + self.inner + .parameters() + .map(|param| VariableLayout { inner: param }) + } + + pub fn var_layout(&self) -> Option<VariableLayout<'a>> + { + Some(VariableLayout { inner: self.inner.var_layout()? }) + } +} + +#[derive(Clone)] +pub struct Program +{ + inner: SlangComponentType, +} + +impl Program +{ + pub fn link(&self) -> Result<Program, Error> + { + let linked_program = self.inner.link()?; + + Ok(Program { inner: linked_program }) + } + + pub fn get_entry_point_code(&self, entry_point_index: u32) -> Result<Blob, Error> + { + let blob = self.inner.entry_point_code(entry_point_index.into(), 0)?; + + Ok(Blob { inner: blob }) + } + + pub fn reflection(&self, target: u32) -> Result<ProgramReflection<'_>, Error> + { + let reflection = self.inner.layout(target as i64)?; + + Ok(ProgramReflection { inner: reflection }) + } +} + +impl Debug for Program +{ + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + formatter + .debug_struct(type_name::<Self>()) + .finish_non_exhaustive() + } +} + +pub struct ProgramReflection<'a> +{ + inner: &'a shader_slang::reflection::Shader, +} + +impl<'a> ProgramReflection<'a> +{ + pub fn get_entry_point_by_index(&self, index: u32) + -> Option<EntryPointReflection<'a>> + { + Some(EntryPointReflection { + inner: self.inner.entry_point_by_index(index)?, + }) + } + + pub fn get_entry_point_by_name(&self, name: &str) + -> Option<EntryPointReflection<'a>> + { + Some(EntryPointReflection { + inner: self.inner.find_entry_point_by_name(name)?, + }) + } + + pub fn entry_points(&self) + -> impl ExactSizeIterator<Item = EntryPointReflection<'a>> + { + self.inner + .entry_points() + .map(|entry_point| EntryPointReflection { inner: entry_point }) + } + + pub fn global_params_type_layout(&self) -> Option<TypeLayout<'a>> + { + Some(TypeLayout { + inner: self.inner.global_params_type_layout()?, + }) + } + + pub fn global_params_var_layout(&self) -> Option<VariableLayout<'a>> + { + Some(VariableLayout { + inner: self.inner.global_params_var_layout()?, + }) + } + + pub fn get_type(&self, name: &str) -> Option<TypeReflection<'a>> + { + Some(TypeReflection { + inner: self.inner.find_type_by_name(name)?, + }) + } + + pub fn get_type_layout(&self, ty: &TypeReflection<'a>) -> Option<TypeLayout<'a>> + { + Some(TypeLayout { + inner: self + .inner + .type_layout(&ty.inner, shader_slang::LayoutRules::Default)?, + }) + } +} + +#[derive(Clone, Copy)] +pub struct VariableLayout<'a> +{ + inner: &'a shader_slang::reflection::VariableLayout, +} + +impl<'a> VariableLayout<'a> +{ + pub fn name(&self) -> Option<&str> + { + self.inner.name() + } + + pub fn semantic_name(&self) -> Option<&str> + { + self.inner.semantic_name() + } + + pub fn binding_index(&self) -> u32 + { + self.inner + .offset(shader_slang::ParameterCategory::DescriptorTableSlot) as u32 + + // self.inner.binding_index() + } + + pub fn binding_space(&self) -> u32 + { + self.inner.binding_space() + } + + pub fn semantic_index(&self) -> usize + { + self.inner.semantic_index() + } + + pub fn offset(&self) -> usize + { + self.inner.offset(shader_slang::ParameterCategory::Uniform) + } + + pub fn ty(&self) -> Option<TypeReflection<'a>> + { + self.inner.ty().map(|ty| TypeReflection { inner: ty }) + } + + pub fn type_layout(&self) -> Option<TypeLayout<'a>> + { + Some(TypeLayout { inner: self.inner.type_layout()? }) + } +} + +#[derive(Clone, Copy)] +pub struct TypeLayout<'a> +{ + inner: &'a shader_slang::reflection::TypeLayout, +} + +impl<'a> TypeLayout<'a> +{ + pub fn get_field_by_name(&self, name: &str) -> Option<VariableLayout<'a>> + { + let index = self.inner.find_field_index_by_name(name); + + if index < 0 { + return None; + } + + let index = u32::try_from(index.cast_unsigned()).expect("Should not happend"); + + let field = self.inner.field_by_index(index)?; + + Some(VariableLayout { inner: field }) + } + + pub fn binding_range_descriptor_set_index(&self, index: i64) -> i64 + { + self.inner.binding_range_descriptor_set_index(index) + } + + pub fn get_field_binding_range_offset_by_name(&self, name: &str) -> Option<u64> + { + let field_index = self.inner.find_field_index_by_name(name); + + if field_index < 0 { + return None; + } + + let field_binding_range_offset = + self.inner.field_binding_range_offset(field_index); + + if field_binding_range_offset < 0 { + return None; + } + + Some(field_binding_range_offset.cast_unsigned()) + } + + pub fn ty(&self) -> Option<TypeReflection<'a>> + { + self.inner.ty().map(|ty| TypeReflection { inner: ty }) + } + + pub fn fields(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>> + { + self.inner + .fields() + .map(|field| VariableLayout { inner: field }) + } + + pub fn field_cnt(&self) -> u32 + { + self.inner.field_count() + } + + pub fn element_type_layout(&self) -> Option<TypeLayout<'a>> + { + self.inner + .element_type_layout() + .map(|type_layout| TypeLayout { inner: type_layout }) + } + + pub fn element_var_layout(&self) -> Option<VariableLayout<'a>> + { + self.inner + .element_var_layout() + .map(|var_layout| VariableLayout { inner: var_layout }) + } + + pub fn container_var_layout(&self) -> Option<VariableLayout<'a>> + { + self.inner + .container_var_layout() + .map(|var_layout| VariableLayout { inner: var_layout }) + } + + pub fn uniform_size(&self) -> Option<usize> + { + // tracing::debug!( + // "uniform_size: {:?} categories: {:?}", + // self.inner.name(), + // self.inner.categories().collect::<Vec<_>>(), + // ); + + if !self + .inner + .categories() + .any(|category| category == SlangParameterCategory::Uniform) + { + return None; + } + + // let category = self.inner.categories().next().unwrap(); + + // println!( + // "AARGH size Category: {category:?} Category count: {}", + // self.inner.category_count() + // ); + + // Some(self.inner.size(category)) + + Some(self.inner.size(SlangParameterCategory::Uniform)) + } + + pub fn stride(&self) -> usize + { + self.inner.stride(self.inner.categories().next().unwrap()) + } +} + +pub struct TypeReflection<'a> +{ + inner: &'a shader_slang::reflection::Type, +} + +impl TypeReflection<'_> +{ + pub fn kind(&self) -> TypeKind + { + TypeKind::from_slang_type_kind(self.inner.kind()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[non_exhaustive] +pub enum TypeKind +{ + None, + Struct, + Enum, + Array, + Matrix, + Vector, + Scalar, + ConstantBuffer, + Resource, + SamplerState, + TextureBuffer, + ShaderStorageBuffer, + ParameterBlock, + GenericTypeParameter, + Interface, + OutputStream, + MeshOutput, + Specialized, + Feedback, + Pointer, + DynamicResource, + Count, +} + +impl TypeKind +{ + fn from_slang_type_kind(type_kind: SlangTypeKind) -> Self + { + match type_kind { + SlangTypeKind::None => Self::None, + SlangTypeKind::Struct => Self::Struct, + SlangTypeKind::Enum => Self::Enum, + SlangTypeKind::Array => Self::Array, + SlangTypeKind::Matrix => Self::Matrix, + SlangTypeKind::Vector => Self::Vector, + SlangTypeKind::Scalar => Self::Scalar, + SlangTypeKind::ConstantBuffer => Self::ConstantBuffer, + SlangTypeKind::Resource => Self::Resource, + SlangTypeKind::SamplerState => Self::SamplerState, + SlangTypeKind::TextureBuffer => Self::TextureBuffer, + SlangTypeKind::ShaderStorageBuffer => Self::ShaderStorageBuffer, + SlangTypeKind::ParameterBlock => Self::ParameterBlock, + SlangTypeKind::GenericTypeParameter => Self::GenericTypeParameter, + SlangTypeKind::Interface => Self::Interface, + SlangTypeKind::OutputStream => Self::OutputStream, + SlangTypeKind::MeshOutput => Self::MeshOutput, + SlangTypeKind::Specialized => Self::Specialized, + SlangTypeKind::Feedback => Self::Feedback, + SlangTypeKind::Pointer => Self::Pointer, + SlangTypeKind::DynamicResource => Self::DynamicResource, + SlangTypeKind::Count => Self::Count, + } + } +} + +pub struct Blob +{ + inner: SlangBlob, +} + +impl Blob +{ + pub fn as_bytes(&self) -> &[u8] + { + self.inner.as_slice() + } + + pub fn as_str(&self) -> Result<&str, Utf8Error> + { + self.inner.as_str() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[non_exhaustive] +pub enum Stage +{ + None, + Vertex, + Hull, + Domain, + Geometry, + Fragment, + Compute, + RayGeneration, + Intersection, + AnyHit, + ClosestHit, + Miss, + Callable, + Mesh, + Amplification, + Dispatch, + Count, +} + +impl Stage +{ + fn from_slang_stage(stage: shader_slang::Stage) -> Self + { + match stage { + shader_slang::Stage::None => Self::None, + shader_slang::Stage::Vertex => Self::Vertex, + shader_slang::Stage::Hull => Self::Hull, + shader_slang::Stage::Domain => Self::Domain, + shader_slang::Stage::Geometry => Self::Geometry, + shader_slang::Stage::Fragment => Self::Fragment, + shader_slang::Stage::Compute => Self::Compute, + shader_slang::Stage::RayGeneration => Self::RayGeneration, + shader_slang::Stage::Intersection => Self::Intersection, + shader_slang::Stage::AnyHit => Self::AnyHit, + shader_slang::Stage::ClosestHit => Self::ClosestHit, + shader_slang::Stage::Miss => Self::Miss, + shader_slang::Stage::Callable => Self::Callable, + shader_slang::Stage::Mesh => Self::Mesh, + shader_slang::Stage::Amplification => Self::Amplification, + shader_slang::Stage::Dispatch => Self::Dispatch, + shader_slang::Stage::Count => Self::Count, + } + } +} + +builder! { +#[builder(name = SettingsBuilder, derives=(Debug))] +#[derive(Debug)] +#[non_exhaustive] +pub struct Settings +{ + link_entrypoints: EntrypointFlags, +} +} + +#[derive(Sole)] +pub struct Context +{ + _global_session: SlangGlobalSession, + session: SlangSession, + modules: HashMap<AssetId, Module>, + programs: HashMap<AssetId, Program>, +} + +impl Context +{ + pub fn get_module(&self, asset_id: &AssetId) -> Option<&Module> + { + self.modules.get(asset_id) + } + + pub fn get_program(&self, asset_id: &AssetId) -> Option<&Program> + { + self.programs.get(asset_id) + } + + pub fn compose_into_program( + &self, + modules: &[&Module], + entry_points: &[&EntryPoint], + ) -> Result<Program, Error> + { + let components = + modules + .iter() + .map(|module| SlangComponentType::from(module.inner.clone())) + .chain(entry_points.iter().map(|entry_point| { + SlangComponentType::from(entry_point.inner.clone()) + })) + .collect::<Vec<_>>(); + + let program = self.session.create_composite_component_type(&components)?; + + Ok(Program { inner: program }) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct Error(#[from] shader_slang::Error); + +pub(crate) fn add_asset_importers(assets: &mut Assets) +{ + assets.set_importer::<_, _>(["slang"], import_slang_asset); +} + +fn import_slang_asset( + asset_submitter: &mut AssetSubmitter<'_>, + file_path: &Path, + settings: Option<&'_ Settings>, +) -> Result<(), ImportError> +{ + let file_name = file_path + .file_name() + .ok_or(ImportError::NoPathFileName)? + .to_str() + .ok_or(ImportError::PathFileNameNotUtf8)?; + + let file_path_canonicalized = file_path + .canonicalize() + .map_err(ImportError::CanonicalizePathFailed)?; + + asset_submitter.submit_store(ModuleSource { + name: file_name.to_owned().into(), + file_path: file_path_canonicalized.into(), + source: std::fs::read_to_string(file_path) + .map_err(ImportError::ReadFileFailed)? + .into(), + link_entrypoints: settings + .map(|settings| settings.link_entrypoints) + .unwrap_or_default(), + }); + + Ok(()) +} + +#[derive(Debug, thiserror::Error)] +enum ImportError +{ + #[error("Failed to read file")] + ReadFileFailed(#[source] std::io::Error), + + #[error("Asset path does not have a file name")] + NoPathFileName, + + #[error("Asset path file name is not valid UTF8")] + PathFileNameNotUtf8, + + #[error("Failed to canonicalize asset path")] + CanonicalizePathFailed(#[source] std::io::Error), +} + +declare_entity!( + IMPORT_SHADERS_PHASE, + ( + Phase, + Pair::builder() + .relation::<ChildOf>() + .target_id(*HANDLE_ASSETS_PHASE) + .build() + ) +); + +pub(crate) struct Extension; + +impl ecs::extension::Extension for Extension +{ + fn collect(self, mut collector: ecs::extension::Collector<'_>) + { + let Some(global_session) = SlangGlobalSession::new() else { + tracing::error!("Unable to create global shader-slang session"); + return; + }; + + let session_options = shader_slang::CompilerOptions::default() + .optimization(shader_slang::OptimizationLevel::None) + .matrix_layout_column(true) + .debug_information(SlangDebugInfoLevel::Maximal) + .no_mangle(true); + + let target_desc = shader_slang::TargetDesc::default() + .format(shader_slang::CompileTarget::Glsl) + // .format(shader_slang::CompileTarget::Spirv) + .profile(global_session.find_profile("glsl_330")); + // .profile(global_session.find_profile("spirv_1_5")); + + let targets = [target_desc]; + + let session_desc = shader_slang::SessionDesc::default() + .targets(&targets) + .search_paths(&[""]) + .options(&session_options); + + let Some(session) = global_session.create_session(&session_desc) else { + tracing::error!("Failed to create shader-slang session"); + return; + }; + + collector + .add_sole(Context { + _global_session: global_session, + session, + modules: HashMap::new(), + programs: HashMap::new(), + }) + .ok(); + + collector.add_declared_entity(&IMPORT_SHADERS_PHASE); + + collector.add_system(*START_PHASE, initialize); + collector.add_system(*IMPORT_SHADERS_PHASE, load_modules); + + collector.add_system( + *PRE_RENDER_PHASE, + default_shader_enqueue_set_shader_bindings, + ); + } +} + +fn initialize(mut assets: Single<Assets>) +{ + assets.store_with_label( + ASSET_LABEL.clone(), + ModuleSource { + name: "default_shader.slang".into(), + file_path: Path::new("@engine/default_shader").into(), + source: include_str!("../res/default_shader.slang").into(), + link_entrypoints: EntrypointFlags::VERTEX | EntrypointFlags::FRAGMENT, + }, + ); +} + +#[tracing::instrument(skip_all)] +fn load_modules(mut context: Single<Context>, assets: Single<Assets>) +{ + for AssetEvent::Stored(asset_id, asset_label) in assets.events().last_tick_events() { + let asset_handle = AssetHandle::<ModuleSource>::from_id(*asset_id); + + if !assets.is_loaded_and_has_type(&asset_handle) { + continue; + } + + let Some(module_source) = assets.get(&asset_handle) else { + unreachable!(); + }; + + tracing::debug!(asset_label=?asset_label, "Loading shader module"); + + let module = match load_module(&context.session, module_source) { + Ok(module) => module, + Err(err) => { + tracing::error!("Failed to load shader module: {err}"); + continue; + } + }; + + if !module_source.link_entrypoints.is_empty() { + assert!(context.programs.get(asset_id).is_none()); + + let Some(vertex_shader_entry_point) = module.get_entry_point("vertex_main") + else { + tracing::error!( + "Shader module does not contain a vertex shader entry point" + ); + continue; + }; + + let Some(fragment_shader_entry_point) = + module.get_entry_point("fragment_main") + else { + tracing::error!( + "Shader module don't contain a fragment_main entry point" + ); + continue; + }; + + let shader_program = match context.compose_into_program( + &[&module], + &[&vertex_shader_entry_point, &fragment_shader_entry_point], + ) { + Ok(shader_program) => shader_program, + Err(err) => { + tracing::error!("Failed to compose shader into program: {err}"); + continue; + } + }; + + let linked_shader_program = match shader_program.link() { + Ok(linked_shader_program) => linked_shader_program, + Err(err) => { + tracing::error!("Failed to link shader: {err}"); + continue; + } + }; + + context.programs.insert(*asset_id, linked_shader_program); + } + + context.modules.insert(*asset_id, module); + } +} + +fn load_module( + session: &SlangSession, + module_source: &ModuleSource, +) -> Result<Module, Error> +{ + let module = session.load_module_from_source_string( + &module_source.name, + &module_source.file_path.to_string_lossy(), + &module_source.source, + )?; + + Ok(Module { inner: module }) +} |
