summaryrefslogtreecommitdiff
path: root/engine/src/shader.rs
diff options
context:
space:
mode:
Diffstat (limited to 'engine/src/shader.rs')
-rw-r--r--engine/src/shader.rs1120
1 files changed, 1120 insertions, 0 deletions
diff --git a/engine/src/shader.rs b/engine/src/shader.rs
new file mode 100644
index 0000000..ae0be1d
--- /dev/null
+++ b/engine/src/shader.rs
@@ -0,0 +1,1120 @@
+use std::alloc::Layout;
+use std::any::type_name;
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::path::Path;
+use std::str::Utf8Error;
+
+use bitflags::{bitflags, bitflags_match};
+use ecs::pair::{ChildOf, Pair};
+use ecs::phase::{Phase, START as START_PHASE};
+use ecs::sole::Single;
+use ecs::{Component, Sole, declare_entity};
+use shader_slang::{
+ Blob as SlangBlob,
+ ComponentType as SlangComponentType,
+ DebugInfoLevel as SlangDebugInfoLevel,
+ EntryPoint as SlangEntryPoint,
+ GlobalSession as SlangGlobalSession,
+ Module as SlangModule,
+ ParameterCategory as SlangParameterCategory,
+ Session as SlangSession,
+ TypeKind as SlangTypeKind,
+};
+
+use crate::asset::{
+ Assets,
+ Event as AssetEvent,
+ HANDLE_ASSETS_PHASE,
+ Handle as AssetHandle,
+ Id as AssetId,
+ Submitter as AssetSubmitter,
+};
+use crate::builder;
+use crate::mesh::Vertex;
+use crate::reflection::{
+ Struct as StructReflection,
+ StructField as StructFieldReflection,
+ With,
+};
+use crate::renderer::PRE_RENDER_PHASE;
+use crate::shader::default::{
+ ASSET_LABEL,
+ enqueue_set_shader_bindings as default_shader_enqueue_set_shader_bindings,
+};
+
+pub mod cursor;
+pub mod default;
+
+/// The vertex parameter of a vertex entrypoint function in a shader should have this
+/// semantic name.
+pub const VERTEX_PARAM_SEMANTIC_NAME: &str = "VERTEX";
+
+#[derive(Debug, Clone, Component)]
+pub struct Shader
+{
+ pub asset_handle: AssetHandle<ModuleSource>,
+}
+
+/// Shader module.
+#[derive(Debug)]
+pub struct ModuleSource
+{
+ pub name: Cow<'static, str>,
+ pub file_path: Cow<'static, Path>,
+ pub source: Cow<'static, str>,
+ pub link_entrypoints: EntrypointFlags,
+}
+
+bitflags! {
+ #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
+ pub struct EntrypointFlags: usize
+ {
+ const FRAGMENT = 1 << 0;
+ const VERTEX = 1 << 1;
+ }
+}
+
+#[derive(Clone)]
+pub struct Module
+{
+ inner: SlangModule,
+}
+
+impl Module
+{
+ pub fn entry_points(&self) -> impl ExactSizeIterator<Item = EntryPoint>
+ {
+ self.inner
+ .entry_points()
+ .map(|entry_point| EntryPoint { inner: entry_point })
+ }
+
+ pub fn get_entry_point(&self, entry_point: &str) -> Option<EntryPoint>
+ {
+ let entry_point = self.inner.find_entry_point_by_name(entry_point)?;
+
+ Some(EntryPoint { inner: entry_point })
+ }
+}
+
+pub struct EntryPoint
+{
+ inner: SlangEntryPoint,
+}
+
+impl EntryPoint
+{
+ pub fn function(&self) -> FunctionReflection<'_>
+ {
+ FunctionReflection {
+ inner: self.inner.function_reflection(),
+ }
+ }
+}
+
+pub struct FunctionReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Function,
+}
+
+impl<'a> FunctionReflection<'a>
+{
+ pub fn name(&self) -> Option<&str>
+ {
+ self.inner.name()
+ }
+}
+
+pub struct EntryPointReflection<'a>
+{
+ inner: &'a shader_slang::reflection::EntryPoint,
+}
+
+impl<'a> EntryPointReflection<'a>
+{
+ pub fn name(&self) -> Option<&str>
+ {
+ self.inner.name()
+ }
+
+ pub fn name_override(&self) -> Option<&str>
+ {
+ self.inner.name_override()
+ }
+
+ pub fn stage(&self) -> Stage
+ {
+ Stage::from_slang_stage(self.inner.stage())
+ }
+
+ pub fn parameters(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>>
+ {
+ self.inner
+ .parameters()
+ .map(|param| VariableLayout { inner: param })
+ }
+
+ pub fn var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ Some(VariableLayout { inner: self.inner.var_layout()? })
+ }
+}
+
+#[derive(Clone)]
+pub struct Program
+{
+ inner: SlangComponentType,
+}
+
+impl Program
+{
+ pub fn link(&self) -> Result<Program, Error>
+ {
+ let linked_program = self.inner.link()?;
+
+ Ok(Program { inner: linked_program })
+ }
+
+ pub fn get_entry_point_code(&self, entry_point_index: u32) -> Result<Blob, Error>
+ {
+ let blob = self.inner.entry_point_code(entry_point_index.into(), 0)?;
+
+ Ok(Blob { inner: blob })
+ }
+
+ pub fn reflection(&self, target: u32) -> Result<ProgramReflection<'_>, Error>
+ {
+ let reflection = self.inner.layout(target as i64)?;
+
+ Ok(ProgramReflection { inner: reflection })
+ }
+}
+
+impl Debug for Program
+{
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ formatter
+ .debug_struct(type_name::<Self>())
+ .finish_non_exhaustive()
+ }
+}
+
+pub struct ProgramReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Shader,
+}
+
+impl<'a> ProgramReflection<'a>
+{
+ pub fn get_entry_point_by_index(&self, index: u32)
+ -> Option<EntryPointReflection<'a>>
+ {
+ Some(EntryPointReflection {
+ inner: self.inner.entry_point_by_index(index)?,
+ })
+ }
+
+ pub fn get_entry_point_by_name(&self, name: &str)
+ -> Option<EntryPointReflection<'a>>
+ {
+ Some(EntryPointReflection {
+ inner: self.inner.find_entry_point_by_name(name)?,
+ })
+ }
+
+ pub fn entry_points(&self)
+ -> impl ExactSizeIterator<Item = EntryPointReflection<'a>>
+ {
+ self.inner
+ .entry_points()
+ .map(|entry_point| EntryPointReflection { inner: entry_point })
+ }
+
+ pub fn global_params_type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout {
+ inner: self.inner.global_params_type_layout()?,
+ })
+ }
+
+ pub fn global_params_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ Some(VariableLayout {
+ inner: self.inner.global_params_var_layout()?,
+ })
+ }
+
+ pub fn get_type(&self, name: &str) -> Option<TypeReflection<'a>>
+ {
+ Some(TypeReflection {
+ inner: self.inner.find_type_by_name(name)?,
+ })
+ }
+
+ pub fn get_type_layout(&self, ty: &TypeReflection<'a>) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout {
+ inner: self
+ .inner
+ .type_layout(&ty.inner, shader_slang::LayoutRules::Default)?,
+ })
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct VariableLayout<'a>
+{
+ inner: &'a shader_slang::reflection::VariableLayout,
+}
+
+impl<'a> VariableLayout<'a>
+{
+ pub fn name(&self) -> Option<&'a str>
+ {
+ self.inner.name()
+ }
+
+ pub fn semantic_name(&self) -> Option<&str>
+ {
+ self.inner.semantic_name()
+ }
+
+ pub fn binding_index(&self) -> u32
+ {
+ self.inner
+ .offset(shader_slang::ParameterCategory::DescriptorTableSlot) as u32
+
+ // self.inner.binding_index()
+ }
+
+ pub fn varying_input_offset(&self) -> Option<usize>
+ {
+ if !self
+ .inner
+ .categories()
+ .any(|category| category == SlangParameterCategory::VaryingInput)
+ {
+ return None;
+ }
+
+ Some(self.inner.offset(SlangParameterCategory::VaryingInput))
+ }
+
+ pub fn binding_space(&self) -> u32
+ {
+ self.inner.binding_space()
+ }
+
+ pub fn semantic_index(&self) -> usize
+ {
+ self.inner.semantic_index()
+ }
+
+ pub fn offset(&self) -> usize
+ {
+ self.inner.offset(shader_slang::ParameterCategory::Uniform)
+ }
+
+ pub fn ty(&self) -> Option<TypeReflection<'a>>
+ {
+ self.inner.ty().map(|ty| TypeReflection { inner: ty })
+ }
+
+ pub fn type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout { inner: self.inner.type_layout()? })
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct TypeLayout<'a>
+{
+ inner: &'a shader_slang::reflection::TypeLayout,
+}
+
+impl<'a> TypeLayout<'a>
+{
+ pub fn kind(&self) -> TypeKind
+ {
+ TypeKind::from_slang_type_kind(self.inner.kind())
+ }
+
+ pub fn get_field_by_name(&self, name: &str) -> Option<VariableLayout<'a>>
+ {
+ let index = self.inner.find_field_index_by_name(name);
+
+ if index < 0 {
+ return None;
+ }
+
+ let index = u32::try_from(index.cast_unsigned()).expect("Should not happend");
+
+ let field = self.inner.field_by_index(index)?;
+
+ Some(VariableLayout { inner: field })
+ }
+
+ pub fn parameter_category(&self) -> ParameterCategory
+ {
+ ParameterCategory::from_slang_parameter_category(self.inner.parameter_category())
+ }
+
+ pub fn binding_range_descriptor_set_index(&self, index: i64) -> i64
+ {
+ self.inner.binding_range_descriptor_set_index(index)
+ }
+
+ pub fn get_field_binding_range_offset_by_name(&self, name: &str) -> Option<u64>
+ {
+ let field_index = self.inner.find_field_index_by_name(name);
+
+ if field_index < 0 {
+ return None;
+ }
+
+ let field_binding_range_offset =
+ self.inner.field_binding_range_offset(field_index);
+
+ if field_binding_range_offset < 0 {
+ return None;
+ }
+
+ Some(field_binding_range_offset.cast_unsigned())
+ }
+
+ pub fn ty(&self) -> Option<TypeReflection<'a>>
+ {
+ self.inner.ty().map(|ty| TypeReflection { inner: ty })
+ }
+
+ pub fn fields(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>>
+ {
+ self.inner
+ .fields()
+ .map(|field| VariableLayout { inner: field })
+ }
+
+ pub fn field_cnt(&self) -> u32
+ {
+ self.inner.field_count()
+ }
+
+ pub fn element_type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ self.inner
+ .element_type_layout()
+ .map(|type_layout| TypeLayout { inner: type_layout })
+ }
+
+ pub fn element_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ self.inner
+ .element_var_layout()
+ .map(|var_layout| VariableLayout { inner: var_layout })
+ }
+
+ pub fn container_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ self.inner
+ .container_var_layout()
+ .map(|var_layout| VariableLayout { inner: var_layout })
+ }
+
+ pub fn uniform_size(&self) -> Option<usize>
+ {
+ // tracing::debug!(
+ // "uniform_size: {:?} categories: {:?}",
+ // self.inner.name(),
+ // self.inner.categories().collect::<Vec<_>>(),
+ // );
+
+ if !self
+ .inner
+ .categories()
+ .any(|category| category == SlangParameterCategory::Uniform)
+ {
+ return None;
+ }
+
+ // let category = self.inner.categories().next().unwrap();
+
+ // println!(
+ // "AARGH size Category: {category:?} Category count: {}",
+ // self.inner.category_count()
+ // );
+
+ // Some(self.inner.size(category))
+
+ Some(self.inner.size(SlangParameterCategory::Uniform))
+ }
+
+ pub fn stride(&self) -> usize
+ {
+ self.inner.stride(self.inner.categories().next().unwrap())
+ }
+}
+
+pub struct TypeReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Type,
+}
+
+impl TypeReflection<'_>
+{
+ pub fn kind(&self) -> TypeKind
+ {
+ TypeKind::from_slang_type_kind(self.inner.kind())
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[non_exhaustive]
+pub enum TypeKind
+{
+ None,
+ Struct,
+ Enum,
+ Array,
+ Matrix,
+ Vector,
+ Scalar,
+ ConstantBuffer,
+ Resource,
+ SamplerState,
+ TextureBuffer,
+ ShaderStorageBuffer,
+ ParameterBlock,
+ GenericTypeParameter,
+ Interface,
+ OutputStream,
+ MeshOutput,
+ Specialized,
+ Feedback,
+ Pointer,
+ DynamicResource,
+ Count,
+}
+
+impl TypeKind
+{
+ fn from_slang_type_kind(type_kind: SlangTypeKind) -> Self
+ {
+ match type_kind {
+ SlangTypeKind::None => Self::None,
+ SlangTypeKind::Struct => Self::Struct,
+ SlangTypeKind::Enum => Self::Enum,
+ SlangTypeKind::Array => Self::Array,
+ SlangTypeKind::Matrix => Self::Matrix,
+ SlangTypeKind::Vector => Self::Vector,
+ SlangTypeKind::Scalar => Self::Scalar,
+ SlangTypeKind::ConstantBuffer => Self::ConstantBuffer,
+ SlangTypeKind::Resource => Self::Resource,
+ SlangTypeKind::SamplerState => Self::SamplerState,
+ SlangTypeKind::TextureBuffer => Self::TextureBuffer,
+ SlangTypeKind::ShaderStorageBuffer => Self::ShaderStorageBuffer,
+ SlangTypeKind::ParameterBlock => Self::ParameterBlock,
+ SlangTypeKind::GenericTypeParameter => Self::GenericTypeParameter,
+ SlangTypeKind::Interface => Self::Interface,
+ SlangTypeKind::OutputStream => Self::OutputStream,
+ SlangTypeKind::MeshOutput => Self::MeshOutput,
+ SlangTypeKind::Specialized => Self::Specialized,
+ SlangTypeKind::Feedback => Self::Feedback,
+ SlangTypeKind::Pointer => Self::Pointer,
+ SlangTypeKind::DynamicResource => Self::DynamicResource,
+ SlangTypeKind::Count => Self::Count,
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
+pub enum ParameterCategory
+{
+ None,
+ Mixed,
+ ConstantBuffer,
+ ShaderResource,
+ UnorderedAccess,
+ VaryingInput,
+ VaryingOutput,
+ SamplerState,
+ Uniform,
+ DescriptorTableSlot,
+ SpecializationConstant,
+ PushConstantBuffer,
+ RegisterSpace,
+ Generic,
+ RayPayload,
+ HitAttributes,
+ CallablePayload,
+ ShaderRecord,
+ ExistentialTypeParam,
+ ExistentialObjectParam,
+ SubElementRegisterSpace,
+ Subpass,
+ MetalArgumentBufferElement,
+ MetalAttribute,
+ MetalPayload,
+ Count,
+}
+
+impl ParameterCategory
+{
+ fn from_slang_parameter_category(parameter_category: SlangParameterCategory) -> Self
+ {
+ match parameter_category {
+ SlangParameterCategory::None => Self::None,
+ SlangParameterCategory::Mixed => Self::Mixed,
+ SlangParameterCategory::ConstantBuffer => Self::ConstantBuffer,
+ SlangParameterCategory::ShaderResource => Self::ShaderResource,
+ SlangParameterCategory::UnorderedAccess => Self::UnorderedAccess,
+ SlangParameterCategory::VaryingInput => Self::VaryingInput,
+ SlangParameterCategory::VaryingOutput => Self::VaryingOutput,
+ SlangParameterCategory::SamplerState => Self::SamplerState,
+ SlangParameterCategory::Uniform => Self::Uniform,
+ SlangParameterCategory::DescriptorTableSlot => Self::DescriptorTableSlot,
+ SlangParameterCategory::SpecializationConstant => {
+ Self::SpecializationConstant
+ }
+ SlangParameterCategory::PushConstantBuffer => Self::PushConstantBuffer,
+ SlangParameterCategory::RegisterSpace => Self::RegisterSpace,
+ SlangParameterCategory::Generic => Self::Generic,
+ SlangParameterCategory::RayPayload => Self::RayPayload,
+ SlangParameterCategory::HitAttributes => Self::HitAttributes,
+ SlangParameterCategory::CallablePayload => Self::CallablePayload,
+ SlangParameterCategory::ShaderRecord => Self::ShaderRecord,
+ SlangParameterCategory::ExistentialTypeParam => Self::ExistentialTypeParam,
+ SlangParameterCategory::ExistentialObjectParam => {
+ Self::ExistentialObjectParam
+ }
+ SlangParameterCategory::SubElementRegisterSpace => {
+ Self::SubElementRegisterSpace
+ }
+ SlangParameterCategory::Subpass => Self::Subpass,
+ SlangParameterCategory::MetalArgumentBufferElement => {
+ Self::MetalArgumentBufferElement
+ }
+ SlangParameterCategory::MetalAttribute => Self::MetalAttribute,
+ SlangParameterCategory::MetalPayload => Self::MetalPayload,
+ SlangParameterCategory::Count => Self::Count,
+ }
+ }
+}
+
+pub struct Blob
+{
+ inner: SlangBlob,
+}
+
+impl Blob
+{
+ pub fn as_bytes(&self) -> &[u8]
+ {
+ self.inner.as_slice()
+ }
+
+ pub fn as_str(&self) -> Result<&str, Utf8Error>
+ {
+ self.inner.as_str()
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[non_exhaustive]
+pub enum Stage
+{
+ None,
+ Vertex,
+ Hull,
+ Domain,
+ Geometry,
+ Fragment,
+ Compute,
+ RayGeneration,
+ Intersection,
+ AnyHit,
+ ClosestHit,
+ Miss,
+ Callable,
+ Mesh,
+ Amplification,
+ Dispatch,
+ Count,
+}
+
+impl Stage
+{
+ fn from_slang_stage(stage: shader_slang::Stage) -> Self
+ {
+ match stage {
+ shader_slang::Stage::None => Self::None,
+ shader_slang::Stage::Vertex => Self::Vertex,
+ shader_slang::Stage::Hull => Self::Hull,
+ shader_slang::Stage::Domain => Self::Domain,
+ shader_slang::Stage::Geometry => Self::Geometry,
+ shader_slang::Stage::Fragment => Self::Fragment,
+ shader_slang::Stage::Compute => Self::Compute,
+ shader_slang::Stage::RayGeneration => Self::RayGeneration,
+ shader_slang::Stage::Intersection => Self::Intersection,
+ shader_slang::Stage::AnyHit => Self::AnyHit,
+ shader_slang::Stage::ClosestHit => Self::ClosestHit,
+ shader_slang::Stage::Miss => Self::Miss,
+ shader_slang::Stage::Callable => Self::Callable,
+ shader_slang::Stage::Mesh => Self::Mesh,
+ shader_slang::Stage::Amplification => Self::Amplification,
+ shader_slang::Stage::Dispatch => Self::Dispatch,
+ shader_slang::Stage::Count => Self::Count,
+ }
+ }
+}
+
+builder! {
+#[builder(name = SettingsBuilder, derives=(Debug))]
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct Settings
+{
+ link_entrypoints: EntrypointFlags,
+}
+}
+
+#[derive(Sole)]
+pub struct Context
+{
+ _global_session: SlangGlobalSession,
+ session: SlangSession,
+ modules: HashMap<AssetId, Module>,
+ programs: HashMap<AssetId, (Program, ProgramMetadata)>,
+}
+
+impl Context
+{
+ pub fn get_module(&self, asset_id: &AssetId) -> Option<&Module>
+ {
+ self.modules.get(asset_id)
+ }
+
+ pub fn get_program(&self, asset_id: &AssetId) -> Option<&Program>
+ {
+ self.programs.get(asset_id).map(|(program, _)| program)
+ }
+
+ pub fn get_program_metadata(&self, asset_id: &AssetId) -> Option<&ProgramMetadata>
+ {
+ self.programs
+ .get(asset_id)
+ .map(|(_, program_metadata)| program_metadata)
+ }
+
+ pub fn compose_into_program(
+ &self,
+ modules: impl IntoIterator<Item = Module>,
+ entry_points: impl IntoIterator<Item = EntryPoint>,
+ ) -> Result<Program, Error>
+ {
+ let components =
+ modules
+ .into_iter()
+ .map(|module| SlangComponentType::from(module.inner.clone()))
+ .chain(entry_points.into_iter().map(|entry_point| {
+ SlangComponentType::from(entry_point.inner.clone())
+ }))
+ .collect::<Vec<_>>();
+
+ let program = self.session.create_composite_component_type(&components)?;
+
+ Ok(Program { inner: program })
+ }
+}
+
+pub struct ProgramMetadata
+{
+ pub vertex_subset: Option<VertexSubset>,
+}
+
+#[derive(Debug)]
+pub struct VertexSubset
+{
+ pub layout: Layout,
+ pub fields: [Option<VertexSubsetField>; const {
+ Vertex::REFLECTION.as_struct().unwrap().fields.len()
+ }],
+}
+
+impl VertexSubset
+{
+ pub fn new(
+ vs_entrypoint: &EntryPointReflection<'_>,
+ ) -> Result<Self, VertexSubsetError>
+ {
+ const VERTEX_REFLECTION: &StructReflection =
+ const { Vertex::REFLECTION.as_struct().unwrap() };
+
+ if vs_entrypoint.stage() != Stage::Vertex {
+ return Err(VertexSubsetError::EntrypointNotInVertexStage);
+ }
+
+ let vs_entrypoint_vertex_param = vs_entrypoint
+ .parameters()
+ .find(|param| param.semantic_name() == Some(VERTEX_PARAM_SEMANTIC_NAME))
+ .ok_or(VertexSubsetError::EntrypointMissingVertexParam)?;
+
+ let vs_entrypoint_vertex_param = vs_entrypoint_vertex_param
+ .type_layout()
+ .expect("Not possible");
+
+ if vs_entrypoint_vertex_param.parameter_category()
+ != ParameterCategory::VaryingInput
+ {
+ return Err(VertexSubsetError::EntryPointVertexParamNotVaryingInput);
+ }
+
+ if vs_entrypoint_vertex_param.kind() != TypeKind::Struct {
+ return Err(VertexSubsetError::EntrypointVertexTypeNotStruct);
+ }
+
+ if let Some(unknown_vertex_field_name) = vs_entrypoint_vertex_param
+ .fields()
+ .find_map(|vertex_param_field| {
+ let vertex_param_field_name =
+ vertex_param_field.name().expect("Not possible");
+
+ if VERTEX_REFLECTION
+ .fields
+ .iter()
+ .all(|vertex_field| vertex_field.name != vertex_param_field_name)
+ {
+ return Some(vertex_param_field_name);
+ }
+
+ None
+ })
+ {
+ return Err(VertexSubsetError::EntrypointVertexTypeHasUnknownField {
+ field_name: unknown_vertex_field_name.to_string(),
+ });
+ }
+
+ let mut layout = Layout::new::<()>();
+
+ let mut fields = [const { None }; const { VERTEX_REFLECTION.fields.len() }];
+
+ for vertex_field in const { VERTEX_REFLECTION.fields } {
+ let Some(vertex_field_var_layout) =
+ vs_entrypoint_vertex_param.get_field_by_name(vertex_field.name)
+ else {
+ continue;
+ };
+
+ let (new_layout, vertex_field_offset) =
+ layout.extend(vertex_field.layout).expect("Not possible");
+
+ layout = new_layout;
+
+ fields[vertex_field.index] = Some(VertexSubsetField {
+ offset: vertex_field_offset,
+ reflection: vertex_field,
+ varying_input_offset: vertex_field_var_layout
+ .varying_input_offset()
+ .expect("Not possible"),
+ });
+ }
+
+ layout = layout.pad_to_align();
+
+ Ok(Self { layout, fields })
+ }
+}
+
+#[derive(Debug)]
+pub struct VertexSubsetField
+{
+ pub offset: usize,
+ pub reflection: &'static StructFieldReflection,
+ pub varying_input_offset: usize,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum VertexSubsetError
+{
+ #[error("Entrypoint is not in vertex stage")]
+ EntrypointNotInVertexStage,
+
+ #[error(
+ "Entrypoint does not have a vertex parameter (parameter with semantic name {})",
+ VERTEX_PARAM_SEMANTIC_NAME
+ )]
+ EntrypointMissingVertexParam,
+
+ #[error("Entrypoint vertex parameter is not a varying input")]
+ EntryPointVertexParamNotVaryingInput,
+
+ #[error("Entrypoint vertex type is not a struct")]
+ EntrypointVertexTypeNotStruct,
+
+ #[error("Entrypoint vertex type has unknown field {field_name}")]
+ EntrypointVertexTypeHasUnknownField
+ {
+ field_name: String
+ },
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub struct Error(#[from] shader_slang::Error);
+
+pub(crate) fn add_asset_importers(assets: &mut Assets)
+{
+ assets.set_importer::<_, _>(["slang"], import_slang_asset);
+}
+
+fn import_slang_asset(
+ asset_submitter: &mut AssetSubmitter<'_>,
+ file_path: &Path,
+ settings: Option<&'_ Settings>,
+) -> Result<(), ImportError>
+{
+ let file_name = file_path
+ .file_name()
+ .ok_or(ImportError::NoPathFileName)?
+ .to_str()
+ .ok_or(ImportError::PathFileNameNotUtf8)?;
+
+ let file_path_canonicalized = file_path
+ .canonicalize()
+ .map_err(ImportError::CanonicalizePathFailed)?;
+
+ asset_submitter.submit_store(ModuleSource {
+ name: file_name.to_owned().into(),
+ file_path: file_path_canonicalized.into(),
+ source: std::fs::read_to_string(file_path)
+ .map_err(ImportError::ReadFileFailed)?
+ .into(),
+ link_entrypoints: settings
+ .map(|settings| settings.link_entrypoints)
+ .unwrap_or_default(),
+ });
+
+ Ok(())
+}
+
+#[derive(Debug, thiserror::Error)]
+enum ImportError
+{
+ #[error("Failed to read file")]
+ ReadFileFailed(#[source] std::io::Error),
+
+ #[error("Asset path does not have a file name")]
+ NoPathFileName,
+
+ #[error("Asset path file name is not valid UTF8")]
+ PathFileNameNotUtf8,
+
+ #[error("Failed to canonicalize asset path")]
+ CanonicalizePathFailed(#[source] std::io::Error),
+}
+
+declare_entity!(
+ IMPORT_SHADERS_PHASE,
+ (
+ Phase,
+ Pair::builder()
+ .relation::<ChildOf>()
+ .target_id(*HANDLE_ASSETS_PHASE)
+ .build()
+ )
+);
+
+pub(crate) struct Extension;
+
+impl ecs::extension::Extension for Extension
+{
+ fn collect(self, mut collector: ecs::extension::Collector<'_>)
+ {
+ let Some(global_session) = SlangGlobalSession::new() else {
+ tracing::error!("Unable to create global shader-slang session");
+ return;
+ };
+
+ let session_options = shader_slang::CompilerOptions::default()
+ .optimization(shader_slang::OptimizationLevel::None)
+ .matrix_layout_column(true)
+ .debug_information(SlangDebugInfoLevel::Maximal)
+ .no_mangle(true);
+
+ let target_desc = shader_slang::TargetDesc::default()
+ .format(shader_slang::CompileTarget::Glsl)
+ // .format(shader_slang::CompileTarget::Spirv)
+ .profile(global_session.find_profile("glsl_330"));
+ // .profile(global_session.find_profile("spirv_1_5"));
+
+ let targets = [target_desc];
+
+ let session_desc = shader_slang::SessionDesc::default()
+ .targets(&targets)
+ .search_paths(&[""])
+ .options(&session_options);
+
+ let Some(session) = global_session.create_session(&session_desc) else {
+ tracing::error!("Failed to create shader-slang session");
+ return;
+ };
+
+ collector
+ .add_sole(Context {
+ _global_session: global_session,
+ session,
+ modules: HashMap::new(),
+ programs: HashMap::new(),
+ })
+ .ok();
+
+ collector.add_declared_entity(&IMPORT_SHADERS_PHASE);
+
+ collector.add_system(*START_PHASE, initialize);
+ collector.add_system(*IMPORT_SHADERS_PHASE, load_modules);
+
+ collector.add_system(
+ *PRE_RENDER_PHASE,
+ default_shader_enqueue_set_shader_bindings,
+ );
+ }
+}
+
+fn initialize(mut assets: Single<Assets>)
+{
+ assets.store_with_label(
+ ASSET_LABEL.clone(),
+ ModuleSource {
+ name: "default_shader.slang".into(),
+ file_path: Path::new("@engine/default_shader").into(),
+ source: include_str!("../res/default_shader.slang").into(),
+ link_entrypoints: EntrypointFlags::VERTEX | EntrypointFlags::FRAGMENT,
+ },
+ );
+}
+
+#[tracing::instrument(skip_all)]
+fn load_modules(mut context: Single<Context>, assets: Single<Assets>)
+{
+ for AssetEvent::Stored(asset_id, asset_label) in assets.events().last_tick_events() {
+ let asset_handle = AssetHandle::<ModuleSource>::from_id(*asset_id);
+
+ if !assets.is_loaded_and_has_type(&asset_handle) {
+ continue;
+ }
+
+ let Some(module_source) = assets.get(&asset_handle) else {
+ unreachable!();
+ };
+
+ tracing::debug!(asset_label=?asset_label, "Loading shader module");
+
+ let module = match load_module(&context.session, module_source) {
+ Ok(module) => module,
+ Err(err) => {
+ tracing::error!("Failed to load shader module: {err}");
+ continue;
+ }
+ };
+
+ context.modules.insert(*asset_id, module.clone());
+
+ if !module_source.link_entrypoints.is_empty() {
+ assert!(context.programs.get(asset_id).is_none());
+
+ let entry_points = match module_source
+ .link_entrypoints
+ .iter()
+ .filter_map(|entrypoint_flag| {
+ let entrypoint_name = bitflags_match!(entrypoint_flag, {
+ EntrypointFlags::VERTEX => Some("vertex_main"),
+ EntrypointFlags::FRAGMENT => Some("fragment_main"),
+ _ => None
+ })?;
+
+ let Some(entry_point) = module.get_entry_point(entrypoint_name)
+ else {
+ return Some(Err(EntrypointNotFoundError { entrypoint_name }));
+ };
+
+ Some(Ok(entry_point))
+ })
+ .collect::<Result<Vec<_>, EntrypointNotFoundError>>()
+ {
+ Ok(entry_points) => entry_points,
+ Err(EntrypointNotFoundError { entrypoint_name }) => {
+ tracing::error!(
+ "Shader module does not have a '{entrypoint_name}' entry point"
+ );
+ continue;
+ }
+ };
+
+ let shader_program =
+ match context.compose_into_program([module], entry_points) {
+ Ok(shader_program) => shader_program,
+ Err(err) => {
+ tracing::error!("Failed to compose shader into program: {err}");
+ continue;
+ }
+ };
+
+ let linked_shader_program = match shader_program.link() {
+ Ok(linked_shader_program) => linked_shader_program,
+ Err(err) => {
+ tracing::error!("Failed to link shader: {err}");
+ continue;
+ }
+ };
+
+ let vertex_subset = if module_source
+ .link_entrypoints
+ .contains(EntrypointFlags::VERTEX)
+ {
+ VertexSubset::new(
+ &shader_program
+ .reflection(0)
+ .expect("Not possible")
+ .get_entry_point_by_name("vertex_main")
+ .expect("Not possible"),
+ )
+ .inspect_err(|err| {
+ tracing::error!(
+ "Failed to create vertex subset for shader {asset_label:?}: {err}"
+ );
+ })
+ .ok()
+ } else {
+ None
+ };
+
+ context.programs.insert(
+ *asset_id,
+ (linked_shader_program, ProgramMetadata { vertex_subset }),
+ );
+ }
+ }
+}
+
+fn load_module(
+ session: &SlangSession,
+ module_source: &ModuleSource,
+) -> Result<Module, Error>
+{
+ let module = session.load_module_from_source_string(
+ &module_source.name,
+ &module_source.file_path.to_string_lossy(),
+ &module_source.source,
+ )?;
+
+ Ok(Module { inner: module })
+}
+
+#[derive(Debug)]
+struct EntrypointNotFoundError
+{
+ entrypoint_name: &'static str,
+}