summaryrefslogtreecommitdiff
path: root/engine/src/shader.rs
diff options
context:
space:
mode:
Diffstat (limited to 'engine/src/shader.rs')
-rw-r--r--engine/src/shader.rs378
1 files changed, 340 insertions, 38 deletions
diff --git a/engine/src/shader.rs b/engine/src/shader.rs
index 65872f1..ae0be1d 100644
--- a/engine/src/shader.rs
+++ b/engine/src/shader.rs
@@ -1,3 +1,4 @@
+use std::alloc::Layout;
use std::any::type_name;
use std::borrow::Cow;
use std::collections::HashMap;
@@ -5,7 +6,7 @@ use std::fmt::Debug;
use std::path::Path;
use std::str::Utf8Error;
-use bitflags::bitflags;
+use bitflags::{bitflags, bitflags_match};
use ecs::pair::{ChildOf, Pair};
use ecs::phase::{Phase, START as START_PHASE};
use ecs::sole::Single;
@@ -31,6 +32,12 @@ use crate::asset::{
Submitter as AssetSubmitter,
};
use crate::builder;
+use crate::mesh::Vertex;
+use crate::reflection::{
+ Struct as StructReflection,
+ StructField as StructFieldReflection,
+ With,
+};
use crate::renderer::PRE_RENDER_PHASE;
use crate::shader::default::{
ASSET_LABEL,
@@ -40,6 +47,10 @@ use crate::shader::default::{
pub mod cursor;
pub mod default;
+/// The vertex parameter of a vertex entrypoint function in a shader should have this
+/// semantic name.
+pub const VERTEX_PARAM_SEMANTIC_NAME: &str = "VERTEX";
+
#[derive(Debug, Clone, Component)]
pub struct Shader
{
@@ -57,7 +68,7 @@ pub struct ModuleSource
}
bitflags! {
- #[derive(Debug, Clone, Copy, Default)]
+ #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EntrypointFlags: usize
{
const FRAGMENT = 1 << 0;
@@ -65,6 +76,7 @@ bitflags! {
}
}
+#[derive(Clone)]
pub struct Module
{
inner: SlangModule,
@@ -94,9 +106,24 @@ pub struct EntryPoint
impl EntryPoint
{
- pub fn function_name(&self) -> Option<&str>
+ pub fn function(&self) -> FunctionReflection<'_>
+ {
+ FunctionReflection {
+ inner: self.inner.function_reflection(),
+ }
+ }
+}
+
+pub struct FunctionReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Function,
+}
+
+impl<'a> FunctionReflection<'a>
+{
+ pub fn name(&self) -> Option<&str>
{
- self.inner.function_reflection().name()
+ self.inner.name()
}
}
@@ -245,7 +272,7 @@ pub struct VariableLayout<'a>
impl<'a> VariableLayout<'a>
{
- pub fn name(&self) -> Option<&str>
+ pub fn name(&self) -> Option<&'a str>
{
self.inner.name()
}
@@ -263,6 +290,19 @@ impl<'a> VariableLayout<'a>
// self.inner.binding_index()
}
+ pub fn varying_input_offset(&self) -> Option<usize>
+ {
+ if !self
+ .inner
+ .categories()
+ .any(|category| category == SlangParameterCategory::VaryingInput)
+ {
+ return None;
+ }
+
+ Some(self.inner.offset(SlangParameterCategory::VaryingInput))
+ }
+
pub fn binding_space(&self) -> u32
{
self.inner.binding_space()
@@ -297,6 +337,11 @@ pub struct TypeLayout<'a>
impl<'a> TypeLayout<'a>
{
+ pub fn kind(&self) -> TypeKind
+ {
+ TypeKind::from_slang_type_kind(self.inner.kind())
+ }
+
pub fn get_field_by_name(&self, name: &str) -> Option<VariableLayout<'a>>
{
let index = self.inner.find_field_index_by_name(name);
@@ -312,6 +357,11 @@ impl<'a> TypeLayout<'a>
Some(VariableLayout { inner: field })
}
+ pub fn parameter_category(&self) -> ParameterCategory
+ {
+ ParameterCategory::from_slang_parameter_category(self.inner.parameter_category())
+ }
+
pub fn binding_range_descriptor_set_index(&self, index: i64) -> i64
{
self.inner.binding_range_descriptor_set_index(index)
@@ -479,6 +529,80 @@ impl TypeKind
}
}
+#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
+pub enum ParameterCategory
+{
+ None,
+ Mixed,
+ ConstantBuffer,
+ ShaderResource,
+ UnorderedAccess,
+ VaryingInput,
+ VaryingOutput,
+ SamplerState,
+ Uniform,
+ DescriptorTableSlot,
+ SpecializationConstant,
+ PushConstantBuffer,
+ RegisterSpace,
+ Generic,
+ RayPayload,
+ HitAttributes,
+ CallablePayload,
+ ShaderRecord,
+ ExistentialTypeParam,
+ ExistentialObjectParam,
+ SubElementRegisterSpace,
+ Subpass,
+ MetalArgumentBufferElement,
+ MetalAttribute,
+ MetalPayload,
+ Count,
+}
+
+impl ParameterCategory
+{
+ fn from_slang_parameter_category(parameter_category: SlangParameterCategory) -> Self
+ {
+ match parameter_category {
+ SlangParameterCategory::None => Self::None,
+ SlangParameterCategory::Mixed => Self::Mixed,
+ SlangParameterCategory::ConstantBuffer => Self::ConstantBuffer,
+ SlangParameterCategory::ShaderResource => Self::ShaderResource,
+ SlangParameterCategory::UnorderedAccess => Self::UnorderedAccess,
+ SlangParameterCategory::VaryingInput => Self::VaryingInput,
+ SlangParameterCategory::VaryingOutput => Self::VaryingOutput,
+ SlangParameterCategory::SamplerState => Self::SamplerState,
+ SlangParameterCategory::Uniform => Self::Uniform,
+ SlangParameterCategory::DescriptorTableSlot => Self::DescriptorTableSlot,
+ SlangParameterCategory::SpecializationConstant => {
+ Self::SpecializationConstant
+ }
+ SlangParameterCategory::PushConstantBuffer => Self::PushConstantBuffer,
+ SlangParameterCategory::RegisterSpace => Self::RegisterSpace,
+ SlangParameterCategory::Generic => Self::Generic,
+ SlangParameterCategory::RayPayload => Self::RayPayload,
+ SlangParameterCategory::HitAttributes => Self::HitAttributes,
+ SlangParameterCategory::CallablePayload => Self::CallablePayload,
+ SlangParameterCategory::ShaderRecord => Self::ShaderRecord,
+ SlangParameterCategory::ExistentialTypeParam => Self::ExistentialTypeParam,
+ SlangParameterCategory::ExistentialObjectParam => {
+ Self::ExistentialObjectParam
+ }
+ SlangParameterCategory::SubElementRegisterSpace => {
+ Self::SubElementRegisterSpace
+ }
+ SlangParameterCategory::Subpass => Self::Subpass,
+ SlangParameterCategory::MetalArgumentBufferElement => {
+ Self::MetalArgumentBufferElement
+ }
+ SlangParameterCategory::MetalAttribute => Self::MetalAttribute,
+ SlangParameterCategory::MetalPayload => Self::MetalPayload,
+ SlangParameterCategory::Count => Self::Count,
+ }
+ }
+}
+
pub struct Blob
{
inner: SlangBlob,
@@ -562,7 +686,7 @@ pub struct Context
_global_session: SlangGlobalSession,
session: SlangSession,
modules: HashMap<AssetId, Module>,
- programs: HashMap<AssetId, Program>,
+ programs: HashMap<AssetId, (Program, ProgramMetadata)>,
}
impl Context
@@ -574,20 +698,27 @@ impl Context
pub fn get_program(&self, asset_id: &AssetId) -> Option<&Program>
{
- self.programs.get(asset_id)
+ self.programs.get(asset_id).map(|(program, _)| program)
+ }
+
+ pub fn get_program_metadata(&self, asset_id: &AssetId) -> Option<&ProgramMetadata>
+ {
+ self.programs
+ .get(asset_id)
+ .map(|(_, program_metadata)| program_metadata)
}
pub fn compose_into_program(
&self,
- modules: &[&Module],
- entry_points: &[&EntryPoint],
+ modules: impl IntoIterator<Item = Module>,
+ entry_points: impl IntoIterator<Item = EntryPoint>,
) -> Result<Program, Error>
{
let components =
modules
- .iter()
+ .into_iter()
.map(|module| SlangComponentType::from(module.inner.clone()))
- .chain(entry_points.iter().map(|entry_point| {
+ .chain(entry_points.into_iter().map(|entry_point| {
SlangComponentType::from(entry_point.inner.clone())
}))
.collect::<Vec<_>>();
@@ -598,6 +729,138 @@ impl Context
}
}
+pub struct ProgramMetadata
+{
+ pub vertex_subset: Option<VertexSubset>,
+}
+
+#[derive(Debug)]
+pub struct VertexSubset
+{
+ pub layout: Layout,
+ pub fields: [Option<VertexSubsetField>; const {
+ Vertex::REFLECTION.as_struct().unwrap().fields.len()
+ }],
+}
+
+impl VertexSubset
+{
+ pub fn new(
+ vs_entrypoint: &EntryPointReflection<'_>,
+ ) -> Result<Self, VertexSubsetError>
+ {
+ const VERTEX_REFLECTION: &StructReflection =
+ const { Vertex::REFLECTION.as_struct().unwrap() };
+
+ if vs_entrypoint.stage() != Stage::Vertex {
+ return Err(VertexSubsetError::EntrypointNotInVertexStage);
+ }
+
+ let vs_entrypoint_vertex_param = vs_entrypoint
+ .parameters()
+ .find(|param| param.semantic_name() == Some(VERTEX_PARAM_SEMANTIC_NAME))
+ .ok_or(VertexSubsetError::EntrypointMissingVertexParam)?;
+
+ let vs_entrypoint_vertex_param = vs_entrypoint_vertex_param
+ .type_layout()
+ .expect("Not possible");
+
+ if vs_entrypoint_vertex_param.parameter_category()
+ != ParameterCategory::VaryingInput
+ {
+ return Err(VertexSubsetError::EntryPointVertexParamNotVaryingInput);
+ }
+
+ if vs_entrypoint_vertex_param.kind() != TypeKind::Struct {
+ return Err(VertexSubsetError::EntrypointVertexTypeNotStruct);
+ }
+
+ if let Some(unknown_vertex_field_name) = vs_entrypoint_vertex_param
+ .fields()
+ .find_map(|vertex_param_field| {
+ let vertex_param_field_name =
+ vertex_param_field.name().expect("Not possible");
+
+ if VERTEX_REFLECTION
+ .fields
+ .iter()
+ .all(|vertex_field| vertex_field.name != vertex_param_field_name)
+ {
+ return Some(vertex_param_field_name);
+ }
+
+ None
+ })
+ {
+ return Err(VertexSubsetError::EntrypointVertexTypeHasUnknownField {
+ field_name: unknown_vertex_field_name.to_string(),
+ });
+ }
+
+ let mut layout = Layout::new::<()>();
+
+ let mut fields = [const { None }; const { VERTEX_REFLECTION.fields.len() }];
+
+ for vertex_field in const { VERTEX_REFLECTION.fields } {
+ let Some(vertex_field_var_layout) =
+ vs_entrypoint_vertex_param.get_field_by_name(vertex_field.name)
+ else {
+ continue;
+ };
+
+ let (new_layout, vertex_field_offset) =
+ layout.extend(vertex_field.layout).expect("Not possible");
+
+ layout = new_layout;
+
+ fields[vertex_field.index] = Some(VertexSubsetField {
+ offset: vertex_field_offset,
+ reflection: vertex_field,
+ varying_input_offset: vertex_field_var_layout
+ .varying_input_offset()
+ .expect("Not possible"),
+ });
+ }
+
+ layout = layout.pad_to_align();
+
+ Ok(Self { layout, fields })
+ }
+}
+
+#[derive(Debug)]
+pub struct VertexSubsetField
+{
+ pub offset: usize,
+ pub reflection: &'static StructFieldReflection,
+ pub varying_input_offset: usize,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum VertexSubsetError
+{
+ #[error("Entrypoint is not in vertex stage")]
+ EntrypointNotInVertexStage,
+
+ #[error(
+ "Entrypoint does not have a vertex parameter (parameter with semantic name {})",
+ VERTEX_PARAM_SEMANTIC_NAME
+ )]
+ EntrypointMissingVertexParam,
+
+ #[error("Entrypoint vertex parameter is not a varying input")]
+ EntryPointVertexParamNotVaryingInput,
+
+ #[error("Entrypoint vertex type is not a struct")]
+ EntrypointVertexTypeNotStruct,
+
+ #[error("Entrypoint vertex type has unknown field {field_name}")]
+ EntrypointVertexTypeHasUnknownField
+ {
+ field_name: String
+ },
+}
+
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct Error(#[from] shader_slang::Error);
@@ -757,37 +1020,48 @@ fn load_modules(mut context: Single<Context>, assets: Single<Assets>)
}
};
+ context.modules.insert(*asset_id, module.clone());
+
if !module_source.link_entrypoints.is_empty() {
assert!(context.programs.get(asset_id).is_none());
- let Some(vertex_shader_entry_point) = module.get_entry_point("vertex_main")
- else {
- tracing::error!(
- "Shader module does not contain a vertex shader entry point"
- );
- continue;
- };
-
- let Some(fragment_shader_entry_point) =
- module.get_entry_point("fragment_main")
- else {
- tracing::error!(
- "Shader module don't contain a fragment_main entry point"
- );
- continue;
- };
-
- let shader_program = match context.compose_into_program(
- &[&module],
- &[&vertex_shader_entry_point, &fragment_shader_entry_point],
- ) {
- Ok(shader_program) => shader_program,
- Err(err) => {
- tracing::error!("Failed to compose shader into program: {err}");
+ let entry_points = match module_source
+ .link_entrypoints
+ .iter()
+ .filter_map(|entrypoint_flag| {
+ let entrypoint_name = bitflags_match!(entrypoint_flag, {
+ EntrypointFlags::VERTEX => Some("vertex_main"),
+ EntrypointFlags::FRAGMENT => Some("fragment_main"),
+ _ => None
+ })?;
+
+ let Some(entry_point) = module.get_entry_point(entrypoint_name)
+ else {
+ return Some(Err(EntrypointNotFoundError { entrypoint_name }));
+ };
+
+ Some(Ok(entry_point))
+ })
+ .collect::<Result<Vec<_>, EntrypointNotFoundError>>()
+ {
+ Ok(entry_points) => entry_points,
+ Err(EntrypointNotFoundError { entrypoint_name }) => {
+ tracing::error!(
+ "Shader module does not have a '{entrypoint_name}' entry point"
+ );
continue;
}
};
+ let shader_program =
+ match context.compose_into_program([module], entry_points) {
+ Ok(shader_program) => shader_program,
+ Err(err) => {
+ tracing::error!("Failed to compose shader into program: {err}");
+ continue;
+ }
+ };
+
let linked_shader_program = match shader_program.link() {
Ok(linked_shader_program) => linked_shader_program,
Err(err) => {
@@ -796,10 +1070,32 @@ fn load_modules(mut context: Single<Context>, assets: Single<Assets>)
}
};
- context.programs.insert(*asset_id, linked_shader_program);
- }
+ let vertex_subset = if module_source
+ .link_entrypoints
+ .contains(EntrypointFlags::VERTEX)
+ {
+ VertexSubset::new(
+ &shader_program
+ .reflection(0)
+ .expect("Not possible")
+ .get_entry_point_by_name("vertex_main")
+ .expect("Not possible"),
+ )
+ .inspect_err(|err| {
+ tracing::error!(
+ "Failed to create vertex subset for shader {asset_label:?}: {err}"
+ );
+ })
+ .ok()
+ } else {
+ None
+ };
- context.modules.insert(*asset_id, module);
+ context.programs.insert(
+ *asset_id,
+ (linked_shader_program, ProgramMetadata { vertex_subset }),
+ );
+ }
}
}
@@ -816,3 +1112,9 @@ fn load_module(
Ok(Module { inner: module })
}
+
+#[derive(Debug)]
+struct EntrypointNotFoundError
+{
+ entrypoint_name: &'static str,
+}