summaryrefslogtreecommitdiff
path: root/engine/src/shader.rs
diff options
context:
space:
mode:
Diffstat (limited to 'engine/src/shader.rs')
-rw-r--r--engine/src/shader.rs1157
1 files changed, 1157 insertions, 0 deletions
diff --git a/engine/src/shader.rs b/engine/src/shader.rs
new file mode 100644
index 0000000..c4bd709
--- /dev/null
+++ b/engine/src/shader.rs
@@ -0,0 +1,1157 @@
+use std::any::type_name;
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::path::Path;
+use std::str::Utf8Error;
+
+use bitflags::{bitflags, bitflags_match};
+use ecs::pair::{ChildOf, Pair};
+use ecs::phase::{POST_UPDATE as POST_UPDATE_PHASE, Phase, START as START_PHASE};
+use ecs::sole::Single;
+use ecs::{Component, Sole, declare_entity};
+use shader_slang::{
+ Blob as SlangBlob,
+ ComponentType as SlangComponentType,
+ DebugInfoLevel as SlangDebugInfoLevel,
+ EntryPoint as SlangEntryPoint,
+ GlobalSession as SlangGlobalSession,
+ Module as SlangModule,
+ ParameterCategory as SlangParameterCategory,
+ ScalarType as SlangScalarType,
+ Session as SlangSession,
+ TypeKind as SlangTypeKind,
+};
+
+use crate::asset::{
+ Assets,
+ Event as AssetEvent,
+ HANDLE_ASSETS_PHASE,
+ Handle as AssetHandle,
+ Id as AssetId,
+ Submitter as AssetSubmitter,
+};
+use crate::builder;
+use crate::shader::default::{
+ ASSET_LABEL,
+ enqueue_set_shader_bindings as default_shader_enqueue_set_shader_bindings,
+};
+
+pub mod cursor;
+pub mod default;
+
+/// The vertex parameter of a vertex entrypoint function in a shader should have this
+/// semantic name.
+pub const VERTEX_PARAM_SEMANTIC_NAME: &str = "VERTEX";
+
+#[derive(Debug, Clone, Component)]
+pub struct Shader
+{
+ pub asset_handle: AssetHandle<ModuleSource>,
+}
+
+/// Shader module.
+#[derive(Debug)]
+pub struct ModuleSource
+{
+ pub name: Cow<'static, str>,
+ pub file_path: Cow<'static, Path>,
+ pub source: Cow<'static, str>,
+ pub link_entrypoints: EntrypointFlags,
+}
+
+bitflags! {
+ #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
+ pub struct EntrypointFlags: usize
+ {
+ const FRAGMENT = 1 << 0;
+ const VERTEX = 1 << 1;
+ }
+}
+
+#[derive(Clone)]
+pub struct Module
+{
+ inner: SlangModule,
+}
+
+impl Module
+{
+ pub fn entry_points(&self) -> impl ExactSizeIterator<Item = EntryPoint>
+ {
+ self.inner
+ .entry_points()
+ .map(|entry_point| EntryPoint { inner: entry_point })
+ }
+
+ pub fn get_entry_point(&self, entry_point: &str) -> Option<EntryPoint>
+ {
+ let entry_point = self.inner.find_entry_point_by_name(entry_point)?;
+
+ Some(EntryPoint { inner: entry_point })
+ }
+}
+
+pub struct EntryPoint
+{
+ inner: SlangEntryPoint,
+}
+
+impl EntryPoint
+{
+ pub fn function(&self) -> FunctionReflection<'_>
+ {
+ FunctionReflection {
+ inner: self.inner.function_reflection(),
+ }
+ }
+}
+
+pub struct FunctionReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Function,
+}
+
+impl<'a> FunctionReflection<'a>
+{
+ pub fn name(&self) -> Option<&str>
+ {
+ self.inner.name()
+ }
+}
+
+pub struct EntryPointReflection<'a>
+{
+ inner: &'a shader_slang::reflection::EntryPoint,
+}
+
+impl<'a> EntryPointReflection<'a>
+{
+ pub fn name(&self) -> Option<&str>
+ {
+ self.inner.name()
+ }
+
+ pub fn name_override(&self) -> Option<&str>
+ {
+ self.inner.name_override()
+ }
+
+ pub fn stage(&self) -> Stage
+ {
+ Stage::from_slang_stage(self.inner.stage())
+ }
+
+ pub fn parameters(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>>
+ {
+ self.inner
+ .parameters()
+ .map(|param| VariableLayout { inner: param })
+ }
+
+ pub fn var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ Some(VariableLayout { inner: self.inner.var_layout()? })
+ }
+}
+
+#[derive(Clone)]
+pub struct Program
+{
+ inner: SlangComponentType,
+}
+
+impl Program
+{
+ pub fn link(&self) -> Result<Program, Error>
+ {
+ let linked_program = self.inner.link()?;
+
+ Ok(Program { inner: linked_program })
+ }
+
+ pub fn get_entry_point_code(&self, entry_point_index: u32) -> Result<Blob, Error>
+ {
+ let blob = self.inner.entry_point_code(entry_point_index.into(), 0)?;
+
+ Ok(Blob { inner: blob })
+ }
+
+ pub fn reflection(&self, target: u32) -> Result<ProgramReflection<'_>, Error>
+ {
+ let reflection = self.inner.layout(target as i64)?;
+
+ Ok(ProgramReflection { inner: reflection })
+ }
+}
+
+impl Debug for Program
+{
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ formatter
+ .debug_struct(type_name::<Self>())
+ .finish_non_exhaustive()
+ }
+}
+
+pub struct ProgramReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Shader,
+}
+
+impl<'a> ProgramReflection<'a>
+{
+ pub fn get_entry_point_by_index(&self, index: u32)
+ -> Option<EntryPointReflection<'a>>
+ {
+ Some(EntryPointReflection {
+ inner: self.inner.entry_point_by_index(index)?,
+ })
+ }
+
+ pub fn get_entry_point_by_name(&self, name: &str)
+ -> Option<EntryPointReflection<'a>>
+ {
+ Some(EntryPointReflection {
+ inner: self.inner.find_entry_point_by_name(name)?,
+ })
+ }
+
+ pub fn entry_points(&self)
+ -> impl ExactSizeIterator<Item = EntryPointReflection<'a>>
+ {
+ self.inner
+ .entry_points()
+ .map(|entry_point| EntryPointReflection { inner: entry_point })
+ }
+
+ pub fn global_params_type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout {
+ inner: self.inner.global_params_type_layout()?,
+ })
+ }
+
+ pub fn global_params_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ Some(VariableLayout {
+ inner: self.inner.global_params_var_layout()?,
+ })
+ }
+
+ pub fn get_type(&self, name: &str) -> Option<TypeReflection<'a>>
+ {
+ Some(TypeReflection {
+ inner: self.inner.find_type_by_name(name)?,
+ })
+ }
+
+ pub fn get_type_layout(&self, ty: &TypeReflection<'a>) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout {
+ inner: self
+ .inner
+ .type_layout(&ty.inner, shader_slang::LayoutRules::Default)?,
+ })
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct VariableLayout<'a>
+{
+ inner: &'a shader_slang::reflection::VariableLayout,
+}
+
+impl<'a> VariableLayout<'a>
+{
+ pub fn name(&self) -> Option<&'a str>
+ {
+ self.inner.name()
+ }
+
+ pub fn semantic_name(&self) -> Option<&str>
+ {
+ self.inner.semantic_name()
+ }
+
+ pub fn binding_index(&self) -> u32
+ {
+ self.inner
+ .offset(shader_slang::ParameterCategory::DescriptorTableSlot) as u32
+
+ // self.inner.binding_index()
+ }
+
+ pub fn varying_input_offset(&self) -> Option<usize>
+ {
+ if !self
+ .inner
+ .categories()
+ .any(|category| category == SlangParameterCategory::VaryingInput)
+ {
+ return None;
+ }
+
+ Some(self.inner.offset(SlangParameterCategory::VaryingInput))
+ }
+
+ pub fn binding_space(&self) -> u32
+ {
+ self.inner.binding_space()
+ }
+
+ pub fn semantic_index(&self) -> usize
+ {
+ self.inner.semantic_index()
+ }
+
+ pub fn offset(&self) -> usize
+ {
+ self.inner.offset(shader_slang::ParameterCategory::Uniform)
+ }
+
+ pub fn ty(&self) -> Option<TypeReflection<'a>>
+ {
+ self.inner.ty().map(|ty| TypeReflection { inner: ty })
+ }
+
+ pub fn type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout { inner: self.inner.type_layout()? })
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct TypeLayout<'a>
+{
+ inner: &'a shader_slang::reflection::TypeLayout,
+}
+
+impl<'a> TypeLayout<'a>
+{
+ pub fn kind(&self) -> TypeKind
+ {
+ TypeKind::from_slang_type_kind(self.inner.kind())
+ }
+
+ pub fn scalar_type(&self) -> Option<ScalarType>
+ {
+ Some(ScalarType::from_slang_scalar_type(
+ self.inner.scalar_type()?,
+ ))
+ }
+
+ pub fn get_field_by_name(&self, name: &str) -> Option<VariableLayout<'a>>
+ {
+ let index = self.inner.find_field_index_by_name(name);
+
+ if index < 0 {
+ return None;
+ }
+
+ let index = u32::try_from(index.cast_unsigned()).expect("Should not happend");
+
+ let field = self.inner.field_by_index(index)?;
+
+ Some(VariableLayout { inner: field })
+ }
+
+ pub fn parameter_category(&self) -> ParameterCategory
+ {
+ ParameterCategory::from_slang_parameter_category(self.inner.parameter_category())
+ }
+
+ pub fn binding_range_descriptor_set_index(&self, index: i64) -> i64
+ {
+ self.inner.binding_range_descriptor_set_index(index)
+ }
+
+ pub fn get_field_binding_range_offset_by_name(&self, name: &str) -> Option<u64>
+ {
+ let field_index = self.inner.find_field_index_by_name(name);
+
+ if field_index < 0 {
+ return None;
+ }
+
+ let field_binding_range_offset =
+ self.inner.field_binding_range_offset(field_index);
+
+ if field_binding_range_offset < 0 {
+ return None;
+ }
+
+ Some(field_binding_range_offset.cast_unsigned())
+ }
+
+ pub fn ty(&self) -> Option<TypeReflection<'a>>
+ {
+ self.inner.ty().map(|ty| TypeReflection { inner: ty })
+ }
+
+ pub fn fields(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>>
+ {
+ self.inner
+ .fields()
+ .map(|field| VariableLayout { inner: field })
+ }
+
+ pub fn field_cnt(&self) -> u32
+ {
+ self.inner.field_count()
+ }
+
+ pub fn element_type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ self.inner
+ .element_type_layout()
+ .map(|type_layout| TypeLayout { inner: type_layout })
+ }
+
+ pub fn element_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ self.inner
+ .element_var_layout()
+ .map(|var_layout| VariableLayout { inner: var_layout })
+ }
+
+ pub fn container_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ self.inner
+ .container_var_layout()
+ .map(|var_layout| VariableLayout { inner: var_layout })
+ }
+
+ pub fn uniform_size(&self) -> Option<usize>
+ {
+ // tracing::debug!(
+ // "uniform_size: {:?} categories: {:?}",
+ // self.inner.name(),
+ // self.inner.categories().collect::<Vec<_>>(),
+ // );
+
+ if !self
+ .inner
+ .categories()
+ .any(|category| category == SlangParameterCategory::Uniform)
+ {
+ return None;
+ }
+
+ // let category = self.inner.categories().next().unwrap();
+
+ // println!(
+ // "AARGH size Category: {category:?} Category count: {}",
+ // self.inner.category_count()
+ // );
+
+ // Some(self.inner.size(category))
+
+ Some(self.inner.size(SlangParameterCategory::Uniform))
+ }
+
+ pub fn stride(&self) -> usize
+ {
+ self.inner.stride(self.inner.categories().next().unwrap())
+ }
+}
+
+pub struct TypeReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Type,
+}
+
+impl TypeReflection<'_>
+{
+ pub fn kind(&self) -> TypeKind
+ {
+ TypeKind::from_slang_type_kind(self.inner.kind())
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[non_exhaustive]
+pub enum TypeKind
+{
+ None,
+ Struct,
+ Enum,
+ Array,
+ Matrix,
+ Vector,
+ Scalar,
+ ConstantBuffer,
+ Resource,
+ SamplerState,
+ TextureBuffer,
+ ShaderStorageBuffer,
+ ParameterBlock,
+ GenericTypeParameter,
+ Interface,
+ OutputStream,
+ MeshOutput,
+ Specialized,
+ Feedback,
+ Pointer,
+ DynamicResource,
+ Count,
+}
+
+impl TypeKind
+{
+ fn from_slang_type_kind(type_kind: SlangTypeKind) -> Self
+ {
+ match type_kind {
+ SlangTypeKind::None => Self::None,
+ SlangTypeKind::Struct => Self::Struct,
+ SlangTypeKind::Enum => Self::Enum,
+ SlangTypeKind::Array => Self::Array,
+ SlangTypeKind::Matrix => Self::Matrix,
+ SlangTypeKind::Vector => Self::Vector,
+ SlangTypeKind::Scalar => Self::Scalar,
+ SlangTypeKind::ConstantBuffer => Self::ConstantBuffer,
+ SlangTypeKind::Resource => Self::Resource,
+ SlangTypeKind::SamplerState => Self::SamplerState,
+ SlangTypeKind::TextureBuffer => Self::TextureBuffer,
+ SlangTypeKind::ShaderStorageBuffer => Self::ShaderStorageBuffer,
+ SlangTypeKind::ParameterBlock => Self::ParameterBlock,
+ SlangTypeKind::GenericTypeParameter => Self::GenericTypeParameter,
+ SlangTypeKind::Interface => Self::Interface,
+ SlangTypeKind::OutputStream => Self::OutputStream,
+ SlangTypeKind::MeshOutput => Self::MeshOutput,
+ SlangTypeKind::Specialized => Self::Specialized,
+ SlangTypeKind::Feedback => Self::Feedback,
+ SlangTypeKind::Pointer => Self::Pointer,
+ SlangTypeKind::DynamicResource => Self::DynamicResource,
+ SlangTypeKind::Count => Self::Count,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+#[non_exhaustive]
+pub enum ScalarType
+{
+ None,
+ Void,
+ Bool,
+ Int32,
+ Uint32,
+ Int64,
+ Uint64,
+ Float16,
+ Float32,
+ Float64,
+ Int8,
+ Uint8,
+ Int16,
+ Uint16,
+ Intptr,
+ Uintptr,
+}
+
+impl ScalarType
+{
+ fn from_slang_scalar_type(scalar_type: SlangScalarType) -> Self
+ {
+ match scalar_type {
+ SlangScalarType::None => Self::None,
+ SlangScalarType::Void => Self::Void,
+ SlangScalarType::Bool => Self::Bool,
+ SlangScalarType::Int32 => Self::Int32,
+ SlangScalarType::Uint32 => Self::Uint32,
+ SlangScalarType::Int64 => Self::Int64,
+ SlangScalarType::Uint64 => Self::Uint64,
+ SlangScalarType::Float16 => Self::Float16,
+ SlangScalarType::Float32 => Self::Float32,
+ SlangScalarType::Float64 => Self::Float64,
+ SlangScalarType::Int8 => Self::Int8,
+ SlangScalarType::Uint8 => Self::Uint8,
+ SlangScalarType::Int16 => Self::Int16,
+ SlangScalarType::Uint16 => Self::Uint16,
+ SlangScalarType::Intptr => Self::Intptr,
+ SlangScalarType::Uintptr => Self::Uintptr,
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
+pub enum ParameterCategory
+{
+ None,
+ Mixed,
+ ConstantBuffer,
+ ShaderResource,
+ UnorderedAccess,
+ VaryingInput,
+ VaryingOutput,
+ SamplerState,
+ Uniform,
+ DescriptorTableSlot,
+ SpecializationConstant,
+ PushConstantBuffer,
+ RegisterSpace,
+ Generic,
+ RayPayload,
+ HitAttributes,
+ CallablePayload,
+ ShaderRecord,
+ ExistentialTypeParam,
+ ExistentialObjectParam,
+ SubElementRegisterSpace,
+ Subpass,
+ MetalArgumentBufferElement,
+ MetalAttribute,
+ MetalPayload,
+ Count,
+}
+
+impl ParameterCategory
+{
+ fn from_slang_parameter_category(parameter_category: SlangParameterCategory) -> Self
+ {
+ match parameter_category {
+ SlangParameterCategory::None => Self::None,
+ SlangParameterCategory::Mixed => Self::Mixed,
+ SlangParameterCategory::ConstantBuffer => Self::ConstantBuffer,
+ SlangParameterCategory::ShaderResource => Self::ShaderResource,
+ SlangParameterCategory::UnorderedAccess => Self::UnorderedAccess,
+ SlangParameterCategory::VaryingInput => Self::VaryingInput,
+ SlangParameterCategory::VaryingOutput => Self::VaryingOutput,
+ SlangParameterCategory::SamplerState => Self::SamplerState,
+ SlangParameterCategory::Uniform => Self::Uniform,
+ SlangParameterCategory::DescriptorTableSlot => Self::DescriptorTableSlot,
+ SlangParameterCategory::SpecializationConstant => {
+ Self::SpecializationConstant
+ }
+ SlangParameterCategory::PushConstantBuffer => Self::PushConstantBuffer,
+ SlangParameterCategory::RegisterSpace => Self::RegisterSpace,
+ SlangParameterCategory::Generic => Self::Generic,
+ SlangParameterCategory::RayPayload => Self::RayPayload,
+ SlangParameterCategory::HitAttributes => Self::HitAttributes,
+ SlangParameterCategory::CallablePayload => Self::CallablePayload,
+ SlangParameterCategory::ShaderRecord => Self::ShaderRecord,
+ SlangParameterCategory::ExistentialTypeParam => Self::ExistentialTypeParam,
+ SlangParameterCategory::ExistentialObjectParam => {
+ Self::ExistentialObjectParam
+ }
+ SlangParameterCategory::SubElementRegisterSpace => {
+ Self::SubElementRegisterSpace
+ }
+ SlangParameterCategory::Subpass => Self::Subpass,
+ SlangParameterCategory::MetalArgumentBufferElement => {
+ Self::MetalArgumentBufferElement
+ }
+ SlangParameterCategory::MetalAttribute => Self::MetalAttribute,
+ SlangParameterCategory::MetalPayload => Self::MetalPayload,
+ SlangParameterCategory::Count => Self::Count,
+ }
+ }
+}
+
+pub struct Blob
+{
+ inner: SlangBlob,
+}
+
+impl Blob
+{
+ pub fn as_bytes(&self) -> &[u8]
+ {
+ self.inner.as_slice()
+ }
+
+ pub fn as_str(&self) -> Result<&str, Utf8Error>
+ {
+ self.inner.as_str()
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[non_exhaustive]
+pub enum Stage
+{
+ None,
+ Vertex,
+ Hull,
+ Domain,
+ Geometry,
+ Fragment,
+ Compute,
+ RayGeneration,
+ Intersection,
+ AnyHit,
+ ClosestHit,
+ Miss,
+ Callable,
+ Mesh,
+ Amplification,
+ Dispatch,
+ Count,
+}
+
+impl Stage
+{
+ fn from_slang_stage(stage: shader_slang::Stage) -> Self
+ {
+ match stage {
+ shader_slang::Stage::None => Self::None,
+ shader_slang::Stage::Vertex => Self::Vertex,
+ shader_slang::Stage::Hull => Self::Hull,
+ shader_slang::Stage::Domain => Self::Domain,
+ shader_slang::Stage::Geometry => Self::Geometry,
+ shader_slang::Stage::Fragment => Self::Fragment,
+ shader_slang::Stage::Compute => Self::Compute,
+ shader_slang::Stage::RayGeneration => Self::RayGeneration,
+ shader_slang::Stage::Intersection => Self::Intersection,
+ shader_slang::Stage::AnyHit => Self::AnyHit,
+ shader_slang::Stage::ClosestHit => Self::ClosestHit,
+ shader_slang::Stage::Miss => Self::Miss,
+ shader_slang::Stage::Callable => Self::Callable,
+ shader_slang::Stage::Mesh => Self::Mesh,
+ shader_slang::Stage::Amplification => Self::Amplification,
+ shader_slang::Stage::Dispatch => Self::Dispatch,
+ shader_slang::Stage::Count => Self::Count,
+ }
+ }
+}
+
+builder! {
+#[builder(name = SettingsBuilder, derives=(Debug))]
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct Settings
+{
+ link_entrypoints: EntrypointFlags,
+}
+}
+
+#[derive(Sole)]
+pub struct Context
+{
+ _global_session: SlangGlobalSession,
+ session: SlangSession,
+ modules: HashMap<AssetId, Module>,
+ programs: HashMap<AssetId, (Program, ProgramMetadata)>,
+}
+
+impl Context
+{
+ pub fn get_module(&self, asset_id: &AssetId) -> Option<&Module>
+ {
+ self.modules.get(asset_id)
+ }
+
+ pub fn get_program(&self, asset_id: &AssetId) -> Option<&Program>
+ {
+ self.programs.get(asset_id).map(|(program, _)| program)
+ }
+
+ pub fn get_program_metadata(&self, asset_id: &AssetId) -> Option<&ProgramMetadata>
+ {
+ self.programs
+ .get(asset_id)
+ .map(|(_, program_metadata)| program_metadata)
+ }
+
+ pub fn compose_into_program(
+ &self,
+ modules: impl IntoIterator<Item = Module>,
+ entry_points: impl IntoIterator<Item = EntryPoint>,
+ ) -> Result<Program, Error>
+ {
+ let components =
+ modules
+ .into_iter()
+ .map(|module| SlangComponentType::from(module.inner.clone()))
+ .chain(entry_points.into_iter().map(|entry_point| {
+ SlangComponentType::from(entry_point.inner.clone())
+ }))
+ .collect::<Vec<_>>();
+
+ let program = self.session.create_composite_component_type(&components)?;
+
+ Ok(Program { inner: program })
+ }
+}
+
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct ProgramMetadata
+{
+ /// If the program has a entry point in the vertex stage, this field will contain a
+ /// description of the vertex type passed to the entry point.
+ pub vertex_desc: Option<VertexDescription>,
+}
+
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct VertexDescription
+{
+ pub fields: Box<[VertexFieldDescription]>,
+}
+
+impl VertexDescription
+{
+ pub fn new(
+ vs_entrypoint: &EntryPointReflection<'_>,
+ ) -> Result<Self, VertexDescriptionError>
+ {
+ if vs_entrypoint.stage() != Stage::Vertex {
+ return Err(VertexDescriptionError::EntrypointNotInVertexStage);
+ }
+
+ let vs_entrypoint_vertex_param = vs_entrypoint
+ .parameters()
+ .find(|param| param.semantic_name() == Some(VERTEX_PARAM_SEMANTIC_NAME))
+ .ok_or(VertexDescriptionError::EntrypointMissingVertexParam)?;
+
+ let vs_entrypoint_vertex_param = vs_entrypoint_vertex_param
+ .type_layout()
+ .expect("Not possible");
+
+ if vs_entrypoint_vertex_param.parameter_category()
+ != ParameterCategory::VaryingInput
+ {
+ return Err(VertexDescriptionError::EntryPointVertexParamNotVaryingInput);
+ }
+
+ if vs_entrypoint_vertex_param.kind() != TypeKind::Struct {
+ return Err(VertexDescriptionError::EntrypointVertexTypeNotStruct);
+ }
+
+ let fields = vs_entrypoint_vertex_param
+ .fields()
+ .map(|field| {
+ let varying_input_offset =
+ field.varying_input_offset().expect("Not possible");
+
+ let field_ty = field.type_layout().expect("Maybe not possible");
+
+ let scalar_type = match field_ty.kind() {
+ TypeKind::Scalar => field_ty.scalar_type().expect("Not possible"),
+ TypeKind::Vector => {
+ let Some(scalar_type) = field_ty.scalar_type() else {
+ return Err(
+ VertexDescriptionError::UnsupportedVertexFieldType {
+ field_name: field.name().unwrap_or("").to_string(),
+ },
+ );
+ };
+
+ scalar_type
+ }
+ _ => {
+ return Err(VertexDescriptionError::UnsupportedVertexFieldType {
+ field_name: field.name().unwrap_or("").to_string(),
+ });
+ }
+ };
+
+ Ok(VertexFieldDescription {
+ name: field.name().unwrap_or("").to_string().into_boxed_str(),
+ varying_input_offset,
+ type_kind: field_ty.kind(),
+ scalar_type,
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ Ok(Self { fields: fields.into_boxed_slice() })
+ }
+}
+
+#[derive(Debug)]
+pub struct VertexFieldDescription
+{
+ pub name: Box<str>,
+ pub varying_input_offset: usize,
+ pub type_kind: TypeKind,
+ pub scalar_type: ScalarType,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum VertexDescriptionError
+{
+ #[error("Entrypoint is not in vertex stage")]
+ EntrypointNotInVertexStage,
+
+ #[error(
+ "Entrypoint does not have a vertex parameter (parameter with semantic name {})",
+ VERTEX_PARAM_SEMANTIC_NAME
+ )]
+ EntrypointMissingVertexParam,
+
+ #[error("Entrypoint vertex parameter is not a varying input")]
+ EntryPointVertexParamNotVaryingInput,
+
+ #[error("Entrypoint vertex type is not a struct")]
+ EntrypointVertexTypeNotStruct,
+
+ #[error("Type of field '{field_name}' of vertex type is not supported")]
+ UnsupportedVertexFieldType
+ {
+ field_name: String
+ },
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub struct Error(#[from] shader_slang::Error);
+
+pub(crate) fn add_asset_importers(assets: &mut Assets)
+{
+ assets.set_importer::<_, _>(["slang"], import_slang_asset);
+}
+
+fn import_slang_asset(
+ asset_submitter: &mut AssetSubmitter<'_>,
+ file_path: &Path,
+ settings: Option<&'_ Settings>,
+) -> Result<(), ImportError>
+{
+ let file_name = file_path
+ .file_name()
+ .ok_or(ImportError::NoPathFileName)?
+ .to_str()
+ .ok_or(ImportError::PathFileNameNotUtf8)?;
+
+ let file_path_canonicalized = file_path
+ .canonicalize()
+ .map_err(ImportError::CanonicalizePathFailed)?;
+
+ asset_submitter.submit_store(ModuleSource {
+ name: file_name.to_owned().into(),
+ file_path: file_path_canonicalized.into(),
+ source: std::fs::read_to_string(file_path)
+ .map_err(ImportError::ReadFileFailed)?
+ .into(),
+ link_entrypoints: settings
+ .map(|settings| settings.link_entrypoints)
+ .unwrap_or_default(),
+ });
+
+ Ok(())
+}
+
+#[derive(Debug, thiserror::Error)]
+enum ImportError
+{
+ #[error("Failed to read file")]
+ ReadFileFailed(#[source] std::io::Error),
+
+ #[error("Asset path does not have a file name")]
+ NoPathFileName,
+
+ #[error("Asset path file name is not valid UTF8")]
+ PathFileNameNotUtf8,
+
+ #[error("Failed to canonicalize asset path")]
+ CanonicalizePathFailed(#[source] std::io::Error),
+}
+
+declare_entity!(
+ IMPORT_SHADERS_PHASE,
+ (
+ Phase,
+ Pair::builder()
+ .relation::<ChildOf>()
+ .target_id(*HANDLE_ASSETS_PHASE)
+ .build()
+ )
+);
+
+pub(crate) struct Extension;
+
+impl ecs::extension::Extension for Extension
+{
+ fn collect(self, mut collector: ecs::extension::Collector<'_>)
+ {
+ let Some(global_session) = SlangGlobalSession::new() else {
+ tracing::error!("Unable to create global shader-slang session");
+ return;
+ };
+
+ let session_options = shader_slang::CompilerOptions::default()
+ .optimization(shader_slang::OptimizationLevel::None)
+ .matrix_layout_column(true)
+ .debug_information(SlangDebugInfoLevel::Maximal)
+ .no_mangle(true);
+
+ let target_desc = shader_slang::TargetDesc::default()
+ .format(shader_slang::CompileTarget::Glsl)
+ // .format(shader_slang::CompileTarget::Spirv)
+ .profile(global_session.find_profile("glsl_330"));
+ // .profile(global_session.find_profile("spirv_1_5"));
+
+ let targets = [target_desc];
+
+ let session_desc = shader_slang::SessionDesc::default()
+ .targets(&targets)
+ .search_paths(&[""])
+ .options(&session_options);
+
+ let Some(session) = global_session.create_session(&session_desc) else {
+ tracing::error!("Failed to create shader-slang session");
+ return;
+ };
+
+ collector
+ .add_sole(Context {
+ _global_session: global_session,
+ session,
+ modules: HashMap::new(),
+ programs: HashMap::new(),
+ })
+ .ok();
+
+ collector.add_declared_entity(&IMPORT_SHADERS_PHASE);
+
+ collector.add_system(*START_PHASE, initialize);
+ collector.add_system(*IMPORT_SHADERS_PHASE, load_modules);
+
+ collector.add_system(
+ *POST_UPDATE_PHASE,
+ default_shader_enqueue_set_shader_bindings,
+ );
+ }
+}
+
+fn initialize(mut assets: Single<Assets>)
+{
+ assets.store_with_label(
+ ASSET_LABEL.clone(),
+ ModuleSource {
+ name: "default_shader.slang".into(),
+ file_path: Path::new("@engine/default_shader").into(),
+ source: include_str!("../res/default_shader.slang").into(),
+ link_entrypoints: EntrypointFlags::VERTEX | EntrypointFlags::FRAGMENT,
+ },
+ );
+}
+
+#[tracing::instrument(skip_all)]
+fn load_modules(mut context: Single<Context>, assets: Single<Assets>)
+{
+ for AssetEvent::Stored(asset_id, asset_label) in assets.events().last_tick_events() {
+ let asset_handle = AssetHandle::<ModuleSource>::from_id(*asset_id);
+
+ if !assets.is_loaded_and_has_type(&asset_handle) {
+ continue;
+ }
+
+ let Some(module_source) = assets.get(&asset_handle) else {
+ unreachable!();
+ };
+
+ tracing::debug!(asset_label=?asset_label, "Loading shader module");
+
+ let module = match load_module(&context.session, module_source) {
+ Ok(module) => module,
+ Err(err) => {
+ tracing::error!("Failed to load shader module: {err}");
+ continue;
+ }
+ };
+
+ context.modules.insert(*asset_id, module.clone());
+
+ if !module_source.link_entrypoints.is_empty() {
+ assert!(context.programs.get(asset_id).is_none());
+
+ let entry_points = match module_source
+ .link_entrypoints
+ .iter()
+ .filter_map(|entrypoint_flag| {
+ let entrypoint_name = bitflags_match!(entrypoint_flag, {
+ EntrypointFlags::VERTEX => Some("vertex_main"),
+ EntrypointFlags::FRAGMENT => Some("fragment_main"),
+ _ => None
+ })?;
+
+ let Some(entry_point) = module.get_entry_point(entrypoint_name)
+ else {
+ return Some(Err(EntrypointNotFoundError { entrypoint_name }));
+ };
+
+ Some(Ok(entry_point))
+ })
+ .collect::<Result<Vec<_>, EntrypointNotFoundError>>()
+ {
+ Ok(entry_points) => entry_points,
+ Err(EntrypointNotFoundError { entrypoint_name }) => {
+ tracing::error!(
+ "Shader module does not have a '{entrypoint_name}' entry point"
+ );
+ continue;
+ }
+ };
+
+ let shader_program =
+ match context.compose_into_program([module], entry_points) {
+ Ok(shader_program) => shader_program,
+ Err(err) => {
+ tracing::error!("Failed to compose shader into program: {err}");
+ continue;
+ }
+ };
+
+ let linked_shader_program = match shader_program.link() {
+ Ok(linked_shader_program) => linked_shader_program,
+ Err(err) => {
+ tracing::error!("Failed to link shader: {err}");
+ continue;
+ }
+ };
+
+ let vertex_desc = if module_source
+ .link_entrypoints
+ .contains(EntrypointFlags::VERTEX)
+ {
+ VertexDescription::new(
+ &shader_program
+ .reflection(0)
+ .expect("Not possible")
+ .get_entry_point_by_name("vertex_main")
+ .expect("Not possible"),
+ )
+ .inspect_err(|err| {
+ tracing::error!(
+ "Failed to create a vertex description for shader {}: {}",
+ asset_label,
+ err
+ );
+ })
+ .ok()
+ } else {
+ None
+ };
+
+ context.programs.insert(
+ *asset_id,
+ (linked_shader_program, ProgramMetadata { vertex_desc }),
+ );
+ }
+ }
+}
+
+fn load_module(
+ session: &SlangSession,
+ module_source: &ModuleSource,
+) -> Result<Module, Error>
+{
+ let module = session.load_module_from_source_string(
+ &module_source.name,
+ &module_source.file_path.to_string_lossy(),
+ &module_source.source,
+ )?;
+
+ Ok(Module { inner: module })
+}
+
+#[derive(Debug)]
+struct EntrypointNotFoundError
+{
+ entrypoint_name: &'static str,
+}