diff options
Diffstat (limited to 'engine/src/shader.rs')
| -rw-r--r-- | engine/src/shader.rs | 378 |
1 files changed, 340 insertions, 38 deletions
diff --git a/engine/src/shader.rs b/engine/src/shader.rs index 65872f1..ae0be1d 100644 --- a/engine/src/shader.rs +++ b/engine/src/shader.rs @@ -1,3 +1,4 @@ +use std::alloc::Layout; use std::any::type_name; use std::borrow::Cow; use std::collections::HashMap; @@ -5,7 +6,7 @@ use std::fmt::Debug; use std::path::Path; use std::str::Utf8Error; -use bitflags::bitflags; +use bitflags::{bitflags, bitflags_match}; use ecs::pair::{ChildOf, Pair}; use ecs::phase::{Phase, START as START_PHASE}; use ecs::sole::Single; @@ -31,6 +32,12 @@ use crate::asset::{ Submitter as AssetSubmitter, }; use crate::builder; +use crate::mesh::Vertex; +use crate::reflection::{ + Struct as StructReflection, + StructField as StructFieldReflection, + With, +}; use crate::renderer::PRE_RENDER_PHASE; use crate::shader::default::{ ASSET_LABEL, @@ -40,6 +47,10 @@ use crate::shader::default::{ pub mod cursor; pub mod default; +/// The vertex parameter of a vertex entrypoint function in a shader should have this +/// semantic name. +pub const VERTEX_PARAM_SEMANTIC_NAME: &str = "VERTEX"; + #[derive(Debug, Clone, Component)] pub struct Shader { @@ -57,7 +68,7 @@ pub struct ModuleSource } bitflags! { - #[derive(Debug, Clone, Copy, Default)] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct EntrypointFlags: usize { const FRAGMENT = 1 << 0; @@ -65,6 +76,7 @@ bitflags! { } } +#[derive(Clone)] pub struct Module { inner: SlangModule, @@ -94,9 +106,24 @@ pub struct EntryPoint impl EntryPoint { - pub fn function_name(&self) -> Option<&str> + pub fn function(&self) -> FunctionReflection<'_> + { + FunctionReflection { + inner: self.inner.function_reflection(), + } + } +} + +pub struct FunctionReflection<'a> +{ + inner: &'a shader_slang::reflection::Function, +} + +impl<'a> FunctionReflection<'a> +{ + pub fn name(&self) -> Option<&str> { - self.inner.function_reflection().name() + self.inner.name() } } @@ -245,7 +272,7 @@ pub struct VariableLayout<'a> impl<'a> VariableLayout<'a> { - pub fn name(&self) -> Option<&str> + pub fn name(&self) -> Option<&'a str> { self.inner.name() } @@ -263,6 +290,19 @@ impl<'a> VariableLayout<'a> // self.inner.binding_index() } + pub fn varying_input_offset(&self) -> Option<usize> + { + if !self + .inner + .categories() + .any(|category| category == SlangParameterCategory::VaryingInput) + { + return None; + } + + Some(self.inner.offset(SlangParameterCategory::VaryingInput)) + } + pub fn binding_space(&self) -> u32 { self.inner.binding_space() @@ -297,6 +337,11 @@ pub struct TypeLayout<'a> impl<'a> TypeLayout<'a> { + pub fn kind(&self) -> TypeKind + { + TypeKind::from_slang_type_kind(self.inner.kind()) + } + pub fn get_field_by_name(&self, name: &str) -> Option<VariableLayout<'a>> { let index = self.inner.find_field_index_by_name(name); @@ -312,6 +357,11 @@ impl<'a> TypeLayout<'a> Some(VariableLayout { inner: field }) } + pub fn parameter_category(&self) -> ParameterCategory + { + ParameterCategory::from_slang_parameter_category(self.inner.parameter_category()) + } + pub fn binding_range_descriptor_set_index(&self, index: i64) -> i64 { self.inner.binding_range_descriptor_set_index(index) @@ -479,6 +529,80 @@ impl TypeKind } } +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum ParameterCategory +{ + None, + Mixed, + ConstantBuffer, + ShaderResource, + UnorderedAccess, + VaryingInput, + VaryingOutput, + SamplerState, + Uniform, + DescriptorTableSlot, + SpecializationConstant, + PushConstantBuffer, + RegisterSpace, + Generic, + RayPayload, + HitAttributes, + CallablePayload, + ShaderRecord, + ExistentialTypeParam, + ExistentialObjectParam, + SubElementRegisterSpace, + Subpass, + MetalArgumentBufferElement, + MetalAttribute, + MetalPayload, + Count, +} + +impl ParameterCategory +{ + fn from_slang_parameter_category(parameter_category: SlangParameterCategory) -> Self + { + match parameter_category { + SlangParameterCategory::None => Self::None, + SlangParameterCategory::Mixed => Self::Mixed, + SlangParameterCategory::ConstantBuffer => Self::ConstantBuffer, + SlangParameterCategory::ShaderResource => Self::ShaderResource, + SlangParameterCategory::UnorderedAccess => Self::UnorderedAccess, + SlangParameterCategory::VaryingInput => Self::VaryingInput, + SlangParameterCategory::VaryingOutput => Self::VaryingOutput, + SlangParameterCategory::SamplerState => Self::SamplerState, + SlangParameterCategory::Uniform => Self::Uniform, + SlangParameterCategory::DescriptorTableSlot => Self::DescriptorTableSlot, + SlangParameterCategory::SpecializationConstant => { + Self::SpecializationConstant + } + SlangParameterCategory::PushConstantBuffer => Self::PushConstantBuffer, + SlangParameterCategory::RegisterSpace => Self::RegisterSpace, + SlangParameterCategory::Generic => Self::Generic, + SlangParameterCategory::RayPayload => Self::RayPayload, + SlangParameterCategory::HitAttributes => Self::HitAttributes, + SlangParameterCategory::CallablePayload => Self::CallablePayload, + SlangParameterCategory::ShaderRecord => Self::ShaderRecord, + SlangParameterCategory::ExistentialTypeParam => Self::ExistentialTypeParam, + SlangParameterCategory::ExistentialObjectParam => { + Self::ExistentialObjectParam + } + SlangParameterCategory::SubElementRegisterSpace => { + Self::SubElementRegisterSpace + } + SlangParameterCategory::Subpass => Self::Subpass, + SlangParameterCategory::MetalArgumentBufferElement => { + Self::MetalArgumentBufferElement + } + SlangParameterCategory::MetalAttribute => Self::MetalAttribute, + SlangParameterCategory::MetalPayload => Self::MetalPayload, + SlangParameterCategory::Count => Self::Count, + } + } +} + pub struct Blob { inner: SlangBlob, @@ -562,7 +686,7 @@ pub struct Context _global_session: SlangGlobalSession, session: SlangSession, modules: HashMap<AssetId, Module>, - programs: HashMap<AssetId, Program>, + programs: HashMap<AssetId, (Program, ProgramMetadata)>, } impl Context @@ -574,20 +698,27 @@ impl Context pub fn get_program(&self, asset_id: &AssetId) -> Option<&Program> { - self.programs.get(asset_id) + self.programs.get(asset_id).map(|(program, _)| program) + } + + pub fn get_program_metadata(&self, asset_id: &AssetId) -> Option<&ProgramMetadata> + { + self.programs + .get(asset_id) + .map(|(_, program_metadata)| program_metadata) } pub fn compose_into_program( &self, - modules: &[&Module], - entry_points: &[&EntryPoint], + modules: impl IntoIterator<Item = Module>, + entry_points: impl IntoIterator<Item = EntryPoint>, ) -> Result<Program, Error> { let components = modules - .iter() + .into_iter() .map(|module| SlangComponentType::from(module.inner.clone())) - .chain(entry_points.iter().map(|entry_point| { + .chain(entry_points.into_iter().map(|entry_point| { SlangComponentType::from(entry_point.inner.clone()) })) .collect::<Vec<_>>(); @@ -598,6 +729,138 @@ impl Context } } +pub struct ProgramMetadata +{ + pub vertex_subset: Option<VertexSubset>, +} + +#[derive(Debug)] +pub struct VertexSubset +{ + pub layout: Layout, + pub fields: [Option<VertexSubsetField>; const { + Vertex::REFLECTION.as_struct().unwrap().fields.len() + }], +} + +impl VertexSubset +{ + pub fn new( + vs_entrypoint: &EntryPointReflection<'_>, + ) -> Result<Self, VertexSubsetError> + { + const VERTEX_REFLECTION: &StructReflection = + const { Vertex::REFLECTION.as_struct().unwrap() }; + + if vs_entrypoint.stage() != Stage::Vertex { + return Err(VertexSubsetError::EntrypointNotInVertexStage); + } + + let vs_entrypoint_vertex_param = vs_entrypoint + .parameters() + .find(|param| param.semantic_name() == Some(VERTEX_PARAM_SEMANTIC_NAME)) + .ok_or(VertexSubsetError::EntrypointMissingVertexParam)?; + + let vs_entrypoint_vertex_param = vs_entrypoint_vertex_param + .type_layout() + .expect("Not possible"); + + if vs_entrypoint_vertex_param.parameter_category() + != ParameterCategory::VaryingInput + { + return Err(VertexSubsetError::EntryPointVertexParamNotVaryingInput); + } + + if vs_entrypoint_vertex_param.kind() != TypeKind::Struct { + return Err(VertexSubsetError::EntrypointVertexTypeNotStruct); + } + + if let Some(unknown_vertex_field_name) = vs_entrypoint_vertex_param + .fields() + .find_map(|vertex_param_field| { + let vertex_param_field_name = + vertex_param_field.name().expect("Not possible"); + + if VERTEX_REFLECTION + .fields + .iter() + .all(|vertex_field| vertex_field.name != vertex_param_field_name) + { + return Some(vertex_param_field_name); + } + + None + }) + { + return Err(VertexSubsetError::EntrypointVertexTypeHasUnknownField { + field_name: unknown_vertex_field_name.to_string(), + }); + } + + let mut layout = Layout::new::<()>(); + + let mut fields = [const { None }; const { VERTEX_REFLECTION.fields.len() }]; + + for vertex_field in const { VERTEX_REFLECTION.fields } { + let Some(vertex_field_var_layout) = + vs_entrypoint_vertex_param.get_field_by_name(vertex_field.name) + else { + continue; + }; + + let (new_layout, vertex_field_offset) = + layout.extend(vertex_field.layout).expect("Not possible"); + + layout = new_layout; + + fields[vertex_field.index] = Some(VertexSubsetField { + offset: vertex_field_offset, + reflection: vertex_field, + varying_input_offset: vertex_field_var_layout + .varying_input_offset() + .expect("Not possible"), + }); + } + + layout = layout.pad_to_align(); + + Ok(Self { layout, fields }) + } +} + +#[derive(Debug)] +pub struct VertexSubsetField +{ + pub offset: usize, + pub reflection: &'static StructFieldReflection, + pub varying_input_offset: usize, +} + +#[derive(Debug, thiserror::Error)] +pub enum VertexSubsetError +{ + #[error("Entrypoint is not in vertex stage")] + EntrypointNotInVertexStage, + + #[error( + "Entrypoint does not have a vertex parameter (parameter with semantic name {})", + VERTEX_PARAM_SEMANTIC_NAME + )] + EntrypointMissingVertexParam, + + #[error("Entrypoint vertex parameter is not a varying input")] + EntryPointVertexParamNotVaryingInput, + + #[error("Entrypoint vertex type is not a struct")] + EntrypointVertexTypeNotStruct, + + #[error("Entrypoint vertex type has unknown field {field_name}")] + EntrypointVertexTypeHasUnknownField + { + field_name: String + }, +} + #[derive(Debug, thiserror::Error)] #[error(transparent)] pub struct Error(#[from] shader_slang::Error); @@ -757,37 +1020,48 @@ fn load_modules(mut context: Single<Context>, assets: Single<Assets>) } }; + context.modules.insert(*asset_id, module.clone()); + if !module_source.link_entrypoints.is_empty() { assert!(context.programs.get(asset_id).is_none()); - let Some(vertex_shader_entry_point) = module.get_entry_point("vertex_main") - else { - tracing::error!( - "Shader module does not contain a vertex shader entry point" - ); - continue; - }; - - let Some(fragment_shader_entry_point) = - module.get_entry_point("fragment_main") - else { - tracing::error!( - "Shader module don't contain a fragment_main entry point" - ); - continue; - }; - - let shader_program = match context.compose_into_program( - &[&module], - &[&vertex_shader_entry_point, &fragment_shader_entry_point], - ) { - Ok(shader_program) => shader_program, - Err(err) => { - tracing::error!("Failed to compose shader into program: {err}"); + let entry_points = match module_source + .link_entrypoints + .iter() + .filter_map(|entrypoint_flag| { + let entrypoint_name = bitflags_match!(entrypoint_flag, { + EntrypointFlags::VERTEX => Some("vertex_main"), + EntrypointFlags::FRAGMENT => Some("fragment_main"), + _ => None + })?; + + let Some(entry_point) = module.get_entry_point(entrypoint_name) + else { + return Some(Err(EntrypointNotFoundError { entrypoint_name })); + }; + + Some(Ok(entry_point)) + }) + .collect::<Result<Vec<_>, EntrypointNotFoundError>>() + { + Ok(entry_points) => entry_points, + Err(EntrypointNotFoundError { entrypoint_name }) => { + tracing::error!( + "Shader module does not have a '{entrypoint_name}' entry point" + ); continue; } }; + let shader_program = + match context.compose_into_program([module], entry_points) { + Ok(shader_program) => shader_program, + Err(err) => { + tracing::error!("Failed to compose shader into program: {err}"); + continue; + } + }; + let linked_shader_program = match shader_program.link() { Ok(linked_shader_program) => linked_shader_program, Err(err) => { @@ -796,10 +1070,32 @@ fn load_modules(mut context: Single<Context>, assets: Single<Assets>) } }; - context.programs.insert(*asset_id, linked_shader_program); - } + let vertex_subset = if module_source + .link_entrypoints + .contains(EntrypointFlags::VERTEX) + { + VertexSubset::new( + &shader_program + .reflection(0) + .expect("Not possible") + .get_entry_point_by_name("vertex_main") + .expect("Not possible"), + ) + .inspect_err(|err| { + tracing::error!( + "Failed to create vertex subset for shader {asset_label:?}: {err}" + ); + }) + .ok() + } else { + None + }; - context.modules.insert(*asset_id, module); + context.programs.insert( + *asset_id, + (linked_shader_program, ProgramMetadata { vertex_subset }), + ); + } } } @@ -816,3 +1112,9 @@ fn load_module( Ok(Module { inner: module }) } + +#[derive(Debug)] +struct EntrypointNotFoundError +{ + entrypoint_name: &'static str, +} |
