summaryrefslogtreecommitdiff
path: root/engine/src/shader.rs
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2026-03-20 14:22:19 +0100
committerHampusM <hampus@hampusmat.com>2026-03-20 14:22:19 +0100
commitf285f82072b491b1f3cc92db8e08485f26779d5a (patch)
treebf6c6c61cdfb3a12550e55966c8552957ade9e71 /engine/src/shader.rs
parent0546d575c11d3668d0f95933697ae4f670fe2a55 (diff)
feat(engine): use slang for shadersHEADmaster
Diffstat (limited to 'engine/src/shader.rs')
-rw-r--r--engine/src/shader.rs818
1 files changed, 818 insertions, 0 deletions
diff --git a/engine/src/shader.rs b/engine/src/shader.rs
new file mode 100644
index 0000000..65872f1
--- /dev/null
+++ b/engine/src/shader.rs
@@ -0,0 +1,818 @@
+use std::any::type_name;
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::path::Path;
+use std::str::Utf8Error;
+
+use bitflags::bitflags;
+use ecs::pair::{ChildOf, Pair};
+use ecs::phase::{Phase, START as START_PHASE};
+use ecs::sole::Single;
+use ecs::{Component, Sole, declare_entity};
+use shader_slang::{
+ Blob as SlangBlob,
+ ComponentType as SlangComponentType,
+ DebugInfoLevel as SlangDebugInfoLevel,
+ EntryPoint as SlangEntryPoint,
+ GlobalSession as SlangGlobalSession,
+ Module as SlangModule,
+ ParameterCategory as SlangParameterCategory,
+ Session as SlangSession,
+ TypeKind as SlangTypeKind,
+};
+
+use crate::asset::{
+ Assets,
+ Event as AssetEvent,
+ HANDLE_ASSETS_PHASE,
+ Handle as AssetHandle,
+ Id as AssetId,
+ Submitter as AssetSubmitter,
+};
+use crate::builder;
+use crate::renderer::PRE_RENDER_PHASE;
+use crate::shader::default::{
+ ASSET_LABEL,
+ enqueue_set_shader_bindings as default_shader_enqueue_set_shader_bindings,
+};
+
+pub mod cursor;
+pub mod default;
+
+#[derive(Debug, Clone, Component)]
+pub struct Shader
+{
+ pub asset_handle: AssetHandle<ModuleSource>,
+}
+
+/// Shader module.
+#[derive(Debug)]
+pub struct ModuleSource
+{
+ pub name: Cow<'static, str>,
+ pub file_path: Cow<'static, Path>,
+ pub source: Cow<'static, str>,
+ pub link_entrypoints: EntrypointFlags,
+}
+
+bitflags! {
+ #[derive(Debug, Clone, Copy, Default)]
+ pub struct EntrypointFlags: usize
+ {
+ const FRAGMENT = 1 << 0;
+ const VERTEX = 1 << 1;
+ }
+}
+
+pub struct Module
+{
+ inner: SlangModule,
+}
+
+impl Module
+{
+ pub fn entry_points(&self) -> impl ExactSizeIterator<Item = EntryPoint>
+ {
+ self.inner
+ .entry_points()
+ .map(|entry_point| EntryPoint { inner: entry_point })
+ }
+
+ pub fn get_entry_point(&self, entry_point: &str) -> Option<EntryPoint>
+ {
+ let entry_point = self.inner.find_entry_point_by_name(entry_point)?;
+
+ Some(EntryPoint { inner: entry_point })
+ }
+}
+
+pub struct EntryPoint
+{
+ inner: SlangEntryPoint,
+}
+
+impl EntryPoint
+{
+ pub fn function_name(&self) -> Option<&str>
+ {
+ self.inner.function_reflection().name()
+ }
+}
+
+pub struct EntryPointReflection<'a>
+{
+ inner: &'a shader_slang::reflection::EntryPoint,
+}
+
+impl<'a> EntryPointReflection<'a>
+{
+ pub fn name(&self) -> Option<&str>
+ {
+ self.inner.name()
+ }
+
+ pub fn name_override(&self) -> Option<&str>
+ {
+ self.inner.name_override()
+ }
+
+ pub fn stage(&self) -> Stage
+ {
+ Stage::from_slang_stage(self.inner.stage())
+ }
+
+ pub fn parameters(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>>
+ {
+ self.inner
+ .parameters()
+ .map(|param| VariableLayout { inner: param })
+ }
+
+ pub fn var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ Some(VariableLayout { inner: self.inner.var_layout()? })
+ }
+}
+
+#[derive(Clone)]
+pub struct Program
+{
+ inner: SlangComponentType,
+}
+
+impl Program
+{
+ pub fn link(&self) -> Result<Program, Error>
+ {
+ let linked_program = self.inner.link()?;
+
+ Ok(Program { inner: linked_program })
+ }
+
+ pub fn get_entry_point_code(&self, entry_point_index: u32) -> Result<Blob, Error>
+ {
+ let blob = self.inner.entry_point_code(entry_point_index.into(), 0)?;
+
+ Ok(Blob { inner: blob })
+ }
+
+ pub fn reflection(&self, target: u32) -> Result<ProgramReflection<'_>, Error>
+ {
+ let reflection = self.inner.layout(target as i64)?;
+
+ Ok(ProgramReflection { inner: reflection })
+ }
+}
+
+impl Debug for Program
+{
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
+ {
+ formatter
+ .debug_struct(type_name::<Self>())
+ .finish_non_exhaustive()
+ }
+}
+
+pub struct ProgramReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Shader,
+}
+
+impl<'a> ProgramReflection<'a>
+{
+ pub fn get_entry_point_by_index(&self, index: u32)
+ -> Option<EntryPointReflection<'a>>
+ {
+ Some(EntryPointReflection {
+ inner: self.inner.entry_point_by_index(index)?,
+ })
+ }
+
+ pub fn get_entry_point_by_name(&self, name: &str)
+ -> Option<EntryPointReflection<'a>>
+ {
+ Some(EntryPointReflection {
+ inner: self.inner.find_entry_point_by_name(name)?,
+ })
+ }
+
+ pub fn entry_points(&self)
+ -> impl ExactSizeIterator<Item = EntryPointReflection<'a>>
+ {
+ self.inner
+ .entry_points()
+ .map(|entry_point| EntryPointReflection { inner: entry_point })
+ }
+
+ pub fn global_params_type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout {
+ inner: self.inner.global_params_type_layout()?,
+ })
+ }
+
+ pub fn global_params_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ Some(VariableLayout {
+ inner: self.inner.global_params_var_layout()?,
+ })
+ }
+
+ pub fn get_type(&self, name: &str) -> Option<TypeReflection<'a>>
+ {
+ Some(TypeReflection {
+ inner: self.inner.find_type_by_name(name)?,
+ })
+ }
+
+ pub fn get_type_layout(&self, ty: &TypeReflection<'a>) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout {
+ inner: self
+ .inner
+ .type_layout(&ty.inner, shader_slang::LayoutRules::Default)?,
+ })
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct VariableLayout<'a>
+{
+ inner: &'a shader_slang::reflection::VariableLayout,
+}
+
+impl<'a> VariableLayout<'a>
+{
+ pub fn name(&self) -> Option<&str>
+ {
+ self.inner.name()
+ }
+
+ pub fn semantic_name(&self) -> Option<&str>
+ {
+ self.inner.semantic_name()
+ }
+
+ pub fn binding_index(&self) -> u32
+ {
+ self.inner
+ .offset(shader_slang::ParameterCategory::DescriptorTableSlot) as u32
+
+ // self.inner.binding_index()
+ }
+
+ pub fn binding_space(&self) -> u32
+ {
+ self.inner.binding_space()
+ }
+
+ pub fn semantic_index(&self) -> usize
+ {
+ self.inner.semantic_index()
+ }
+
+ pub fn offset(&self) -> usize
+ {
+ self.inner.offset(shader_slang::ParameterCategory::Uniform)
+ }
+
+ pub fn ty(&self) -> Option<TypeReflection<'a>>
+ {
+ self.inner.ty().map(|ty| TypeReflection { inner: ty })
+ }
+
+ pub fn type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ Some(TypeLayout { inner: self.inner.type_layout()? })
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct TypeLayout<'a>
+{
+ inner: &'a shader_slang::reflection::TypeLayout,
+}
+
+impl<'a> TypeLayout<'a>
+{
+ pub fn get_field_by_name(&self, name: &str) -> Option<VariableLayout<'a>>
+ {
+ let index = self.inner.find_field_index_by_name(name);
+
+ if index < 0 {
+ return None;
+ }
+
+ let index = u32::try_from(index.cast_unsigned()).expect("Should not happend");
+
+ let field = self.inner.field_by_index(index)?;
+
+ Some(VariableLayout { inner: field })
+ }
+
+ pub fn binding_range_descriptor_set_index(&self, index: i64) -> i64
+ {
+ self.inner.binding_range_descriptor_set_index(index)
+ }
+
+ pub fn get_field_binding_range_offset_by_name(&self, name: &str) -> Option<u64>
+ {
+ let field_index = self.inner.find_field_index_by_name(name);
+
+ if field_index < 0 {
+ return None;
+ }
+
+ let field_binding_range_offset =
+ self.inner.field_binding_range_offset(field_index);
+
+ if field_binding_range_offset < 0 {
+ return None;
+ }
+
+ Some(field_binding_range_offset.cast_unsigned())
+ }
+
+ pub fn ty(&self) -> Option<TypeReflection<'a>>
+ {
+ self.inner.ty().map(|ty| TypeReflection { inner: ty })
+ }
+
+ pub fn fields(&self) -> impl ExactSizeIterator<Item = VariableLayout<'a>>
+ {
+ self.inner
+ .fields()
+ .map(|field| VariableLayout { inner: field })
+ }
+
+ pub fn field_cnt(&self) -> u32
+ {
+ self.inner.field_count()
+ }
+
+ pub fn element_type_layout(&self) -> Option<TypeLayout<'a>>
+ {
+ self.inner
+ .element_type_layout()
+ .map(|type_layout| TypeLayout { inner: type_layout })
+ }
+
+ pub fn element_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ self.inner
+ .element_var_layout()
+ .map(|var_layout| VariableLayout { inner: var_layout })
+ }
+
+ pub fn container_var_layout(&self) -> Option<VariableLayout<'a>>
+ {
+ self.inner
+ .container_var_layout()
+ .map(|var_layout| VariableLayout { inner: var_layout })
+ }
+
+ pub fn uniform_size(&self) -> Option<usize>
+ {
+ // tracing::debug!(
+ // "uniform_size: {:?} categories: {:?}",
+ // self.inner.name(),
+ // self.inner.categories().collect::<Vec<_>>(),
+ // );
+
+ if !self
+ .inner
+ .categories()
+ .any(|category| category == SlangParameterCategory::Uniform)
+ {
+ return None;
+ }
+
+ // let category = self.inner.categories().next().unwrap();
+
+ // println!(
+ // "AARGH size Category: {category:?} Category count: {}",
+ // self.inner.category_count()
+ // );
+
+ // Some(self.inner.size(category))
+
+ Some(self.inner.size(SlangParameterCategory::Uniform))
+ }
+
+ pub fn stride(&self) -> usize
+ {
+ self.inner.stride(self.inner.categories().next().unwrap())
+ }
+}
+
+pub struct TypeReflection<'a>
+{
+ inner: &'a shader_slang::reflection::Type,
+}
+
+impl TypeReflection<'_>
+{
+ pub fn kind(&self) -> TypeKind
+ {
+ TypeKind::from_slang_type_kind(self.inner.kind())
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[non_exhaustive]
+pub enum TypeKind
+{
+ None,
+ Struct,
+ Enum,
+ Array,
+ Matrix,
+ Vector,
+ Scalar,
+ ConstantBuffer,
+ Resource,
+ SamplerState,
+ TextureBuffer,
+ ShaderStorageBuffer,
+ ParameterBlock,
+ GenericTypeParameter,
+ Interface,
+ OutputStream,
+ MeshOutput,
+ Specialized,
+ Feedback,
+ Pointer,
+ DynamicResource,
+ Count,
+}
+
+impl TypeKind
+{
+ fn from_slang_type_kind(type_kind: SlangTypeKind) -> Self
+ {
+ match type_kind {
+ SlangTypeKind::None => Self::None,
+ SlangTypeKind::Struct => Self::Struct,
+ SlangTypeKind::Enum => Self::Enum,
+ SlangTypeKind::Array => Self::Array,
+ SlangTypeKind::Matrix => Self::Matrix,
+ SlangTypeKind::Vector => Self::Vector,
+ SlangTypeKind::Scalar => Self::Scalar,
+ SlangTypeKind::ConstantBuffer => Self::ConstantBuffer,
+ SlangTypeKind::Resource => Self::Resource,
+ SlangTypeKind::SamplerState => Self::SamplerState,
+ SlangTypeKind::TextureBuffer => Self::TextureBuffer,
+ SlangTypeKind::ShaderStorageBuffer => Self::ShaderStorageBuffer,
+ SlangTypeKind::ParameterBlock => Self::ParameterBlock,
+ SlangTypeKind::GenericTypeParameter => Self::GenericTypeParameter,
+ SlangTypeKind::Interface => Self::Interface,
+ SlangTypeKind::OutputStream => Self::OutputStream,
+ SlangTypeKind::MeshOutput => Self::MeshOutput,
+ SlangTypeKind::Specialized => Self::Specialized,
+ SlangTypeKind::Feedback => Self::Feedback,
+ SlangTypeKind::Pointer => Self::Pointer,
+ SlangTypeKind::DynamicResource => Self::DynamicResource,
+ SlangTypeKind::Count => Self::Count,
+ }
+ }
+}
+
+pub struct Blob
+{
+ inner: SlangBlob,
+}
+
+impl Blob
+{
+ pub fn as_bytes(&self) -> &[u8]
+ {
+ self.inner.as_slice()
+ }
+
+ pub fn as_str(&self) -> Result<&str, Utf8Error>
+ {
+ self.inner.as_str()
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[non_exhaustive]
+pub enum Stage
+{
+ None,
+ Vertex,
+ Hull,
+ Domain,
+ Geometry,
+ Fragment,
+ Compute,
+ RayGeneration,
+ Intersection,
+ AnyHit,
+ ClosestHit,
+ Miss,
+ Callable,
+ Mesh,
+ Amplification,
+ Dispatch,
+ Count,
+}
+
+impl Stage
+{
+ fn from_slang_stage(stage: shader_slang::Stage) -> Self
+ {
+ match stage {
+ shader_slang::Stage::None => Self::None,
+ shader_slang::Stage::Vertex => Self::Vertex,
+ shader_slang::Stage::Hull => Self::Hull,
+ shader_slang::Stage::Domain => Self::Domain,
+ shader_slang::Stage::Geometry => Self::Geometry,
+ shader_slang::Stage::Fragment => Self::Fragment,
+ shader_slang::Stage::Compute => Self::Compute,
+ shader_slang::Stage::RayGeneration => Self::RayGeneration,
+ shader_slang::Stage::Intersection => Self::Intersection,
+ shader_slang::Stage::AnyHit => Self::AnyHit,
+ shader_slang::Stage::ClosestHit => Self::ClosestHit,
+ shader_slang::Stage::Miss => Self::Miss,
+ shader_slang::Stage::Callable => Self::Callable,
+ shader_slang::Stage::Mesh => Self::Mesh,
+ shader_slang::Stage::Amplification => Self::Amplification,
+ shader_slang::Stage::Dispatch => Self::Dispatch,
+ shader_slang::Stage::Count => Self::Count,
+ }
+ }
+}
+
+builder! {
+#[builder(name = SettingsBuilder, derives=(Debug))]
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct Settings
+{
+ link_entrypoints: EntrypointFlags,
+}
+}
+
+#[derive(Sole)]
+pub struct Context
+{
+ _global_session: SlangGlobalSession,
+ session: SlangSession,
+ modules: HashMap<AssetId, Module>,
+ programs: HashMap<AssetId, Program>,
+}
+
+impl Context
+{
+ pub fn get_module(&self, asset_id: &AssetId) -> Option<&Module>
+ {
+ self.modules.get(asset_id)
+ }
+
+ pub fn get_program(&self, asset_id: &AssetId) -> Option<&Program>
+ {
+ self.programs.get(asset_id)
+ }
+
+ pub fn compose_into_program(
+ &self,
+ modules: &[&Module],
+ entry_points: &[&EntryPoint],
+ ) -> Result<Program, Error>
+ {
+ let components =
+ modules
+ .iter()
+ .map(|module| SlangComponentType::from(module.inner.clone()))
+ .chain(entry_points.iter().map(|entry_point| {
+ SlangComponentType::from(entry_point.inner.clone())
+ }))
+ .collect::<Vec<_>>();
+
+ let program = self.session.create_composite_component_type(&components)?;
+
+ Ok(Program { inner: program })
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub struct Error(#[from] shader_slang::Error);
+
+pub(crate) fn add_asset_importers(assets: &mut Assets)
+{
+ assets.set_importer::<_, _>(["slang"], import_slang_asset);
+}
+
+fn import_slang_asset(
+ asset_submitter: &mut AssetSubmitter<'_>,
+ file_path: &Path,
+ settings: Option<&'_ Settings>,
+) -> Result<(), ImportError>
+{
+ let file_name = file_path
+ .file_name()
+ .ok_or(ImportError::NoPathFileName)?
+ .to_str()
+ .ok_or(ImportError::PathFileNameNotUtf8)?;
+
+ let file_path_canonicalized = file_path
+ .canonicalize()
+ .map_err(ImportError::CanonicalizePathFailed)?;
+
+ asset_submitter.submit_store(ModuleSource {
+ name: file_name.to_owned().into(),
+ file_path: file_path_canonicalized.into(),
+ source: std::fs::read_to_string(file_path)
+ .map_err(ImportError::ReadFileFailed)?
+ .into(),
+ link_entrypoints: settings
+ .map(|settings| settings.link_entrypoints)
+ .unwrap_or_default(),
+ });
+
+ Ok(())
+}
+
+#[derive(Debug, thiserror::Error)]
+enum ImportError
+{
+ #[error("Failed to read file")]
+ ReadFileFailed(#[source] std::io::Error),
+
+ #[error("Asset path does not have a file name")]
+ NoPathFileName,
+
+ #[error("Asset path file name is not valid UTF8")]
+ PathFileNameNotUtf8,
+
+ #[error("Failed to canonicalize asset path")]
+ CanonicalizePathFailed(#[source] std::io::Error),
+}
+
+declare_entity!(
+ IMPORT_SHADERS_PHASE,
+ (
+ Phase,
+ Pair::builder()
+ .relation::<ChildOf>()
+ .target_id(*HANDLE_ASSETS_PHASE)
+ .build()
+ )
+);
+
+pub(crate) struct Extension;
+
+impl ecs::extension::Extension for Extension
+{
+ fn collect(self, mut collector: ecs::extension::Collector<'_>)
+ {
+ let Some(global_session) = SlangGlobalSession::new() else {
+ tracing::error!("Unable to create global shader-slang session");
+ return;
+ };
+
+ let session_options = shader_slang::CompilerOptions::default()
+ .optimization(shader_slang::OptimizationLevel::None)
+ .matrix_layout_column(true)
+ .debug_information(SlangDebugInfoLevel::Maximal)
+ .no_mangle(true);
+
+ let target_desc = shader_slang::TargetDesc::default()
+ .format(shader_slang::CompileTarget::Glsl)
+ // .format(shader_slang::CompileTarget::Spirv)
+ .profile(global_session.find_profile("glsl_330"));
+ // .profile(global_session.find_profile("spirv_1_5"));
+
+ let targets = [target_desc];
+
+ let session_desc = shader_slang::SessionDesc::default()
+ .targets(&targets)
+ .search_paths(&[""])
+ .options(&session_options);
+
+ let Some(session) = global_session.create_session(&session_desc) else {
+ tracing::error!("Failed to create shader-slang session");
+ return;
+ };
+
+ collector
+ .add_sole(Context {
+ _global_session: global_session,
+ session,
+ modules: HashMap::new(),
+ programs: HashMap::new(),
+ })
+ .ok();
+
+ collector.add_declared_entity(&IMPORT_SHADERS_PHASE);
+
+ collector.add_system(*START_PHASE, initialize);
+ collector.add_system(*IMPORT_SHADERS_PHASE, load_modules);
+
+ collector.add_system(
+ *PRE_RENDER_PHASE,
+ default_shader_enqueue_set_shader_bindings,
+ );
+ }
+}
+
+fn initialize(mut assets: Single<Assets>)
+{
+ assets.store_with_label(
+ ASSET_LABEL.clone(),
+ ModuleSource {
+ name: "default_shader.slang".into(),
+ file_path: Path::new("@engine/default_shader").into(),
+ source: include_str!("../res/default_shader.slang").into(),
+ link_entrypoints: EntrypointFlags::VERTEX | EntrypointFlags::FRAGMENT,
+ },
+ );
+}
+
+#[tracing::instrument(skip_all)]
+fn load_modules(mut context: Single<Context>, assets: Single<Assets>)
+{
+ for AssetEvent::Stored(asset_id, asset_label) in assets.events().last_tick_events() {
+ let asset_handle = AssetHandle::<ModuleSource>::from_id(*asset_id);
+
+ if !assets.is_loaded_and_has_type(&asset_handle) {
+ continue;
+ }
+
+ let Some(module_source) = assets.get(&asset_handle) else {
+ unreachable!();
+ };
+
+ tracing::debug!(asset_label=?asset_label, "Loading shader module");
+
+ let module = match load_module(&context.session, module_source) {
+ Ok(module) => module,
+ Err(err) => {
+ tracing::error!("Failed to load shader module: {err}");
+ continue;
+ }
+ };
+
+ if !module_source.link_entrypoints.is_empty() {
+ assert!(context.programs.get(asset_id).is_none());
+
+ let Some(vertex_shader_entry_point) = module.get_entry_point("vertex_main")
+ else {
+ tracing::error!(
+ "Shader module does not contain a vertex shader entry point"
+ );
+ continue;
+ };
+
+ let Some(fragment_shader_entry_point) =
+ module.get_entry_point("fragment_main")
+ else {
+ tracing::error!(
+ "Shader module don't contain a fragment_main entry point"
+ );
+ continue;
+ };
+
+ let shader_program = match context.compose_into_program(
+ &[&module],
+ &[&vertex_shader_entry_point, &fragment_shader_entry_point],
+ ) {
+ Ok(shader_program) => shader_program,
+ Err(err) => {
+ tracing::error!("Failed to compose shader into program: {err}");
+ continue;
+ }
+ };
+
+ let linked_shader_program = match shader_program.link() {
+ Ok(linked_shader_program) => linked_shader_program,
+ Err(err) => {
+ tracing::error!("Failed to link shader: {err}");
+ continue;
+ }
+ };
+
+ context.programs.insert(*asset_id, linked_shader_program);
+ }
+
+ context.modules.insert(*asset_id, module);
+ }
+}
+
+fn load_module(
+ session: &SlangSession,
+ module_source: &ModuleSource,
+) -> Result<Module, Error>
+{
+ let module = session.load_module_from_source_string(
+ &module_source.name,
+ &module_source.file_path.to_string_lossy(),
+ &module_source.source,
+ )?;
+
+ Ok(Module { inner: module })
+}