use std::any::type_name; use std::borrow::Cow; use std::collections::HashMap; use std::fmt::Debug; use std::path::Path; use std::str::Utf8Error; use bitflags::bitflags; use ecs::pair::{ChildOf, Pair}; use ecs::phase::{Phase, START as START_PHASE}; use ecs::sole::Single; use ecs::{Component, Sole, declare_entity}; use shader_slang::{ Blob as SlangBlob, ComponentType as SlangComponentType, DebugInfoLevel as SlangDebugInfoLevel, EntryPoint as SlangEntryPoint, GlobalSession as SlangGlobalSession, Module as SlangModule, ParameterCategory as SlangParameterCategory, Session as SlangSession, TypeKind as SlangTypeKind, }; use crate::asset::{ Assets, Event as AssetEvent, HANDLE_ASSETS_PHASE, Handle as AssetHandle, Id as AssetId, Submitter as AssetSubmitter, }; use crate::builder; use crate::renderer::PRE_RENDER_PHASE; use crate::shader::default::{ ASSET_LABEL, enqueue_set_shader_bindings as default_shader_enqueue_set_shader_bindings, }; pub mod cursor; pub mod default; #[derive(Debug, Clone, Component)] pub struct Shader { pub asset_handle: AssetHandle, } /// Shader module. #[derive(Debug)] pub struct ModuleSource { pub name: Cow<'static, str>, pub file_path: Cow<'static, Path>, pub source: Cow<'static, str>, pub link_entrypoints: EntrypointFlags, } bitflags! { #[derive(Debug, Clone, Copy, Default)] pub struct EntrypointFlags: usize { const FRAGMENT = 1 << 0; const VERTEX = 1 << 1; } } pub struct Module { inner: SlangModule, } impl Module { pub fn entry_points(&self) -> impl ExactSizeIterator { self.inner .entry_points() .map(|entry_point| EntryPoint { inner: entry_point }) } pub fn get_entry_point(&self, entry_point: &str) -> Option { let entry_point = self.inner.find_entry_point_by_name(entry_point)?; Some(EntryPoint { inner: entry_point }) } } pub struct EntryPoint { inner: SlangEntryPoint, } impl EntryPoint { pub fn function_name(&self) -> Option<&str> { self.inner.function_reflection().name() } } pub struct EntryPointReflection<'a> { inner: &'a shader_slang::reflection::EntryPoint, } impl<'a> EntryPointReflection<'a> { pub fn name(&self) -> Option<&str> { self.inner.name() } pub fn name_override(&self) -> Option<&str> { self.inner.name_override() } pub fn stage(&self) -> Stage { Stage::from_slang_stage(self.inner.stage()) } pub fn parameters(&self) -> impl ExactSizeIterator> { self.inner .parameters() .map(|param| VariableLayout { inner: param }) } pub fn var_layout(&self) -> Option> { Some(VariableLayout { inner: self.inner.var_layout()? }) } } #[derive(Clone)] pub struct Program { inner: SlangComponentType, } impl Program { pub fn link(&self) -> Result { let linked_program = self.inner.link()?; Ok(Program { inner: linked_program }) } pub fn get_entry_point_code(&self, entry_point_index: u32) -> Result { let blob = self.inner.entry_point_code(entry_point_index.into(), 0)?; Ok(Blob { inner: blob }) } pub fn reflection(&self, target: u32) -> Result, Error> { let reflection = self.inner.layout(target as i64)?; Ok(ProgramReflection { inner: reflection }) } } impl Debug for Program { fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { formatter .debug_struct(type_name::()) .finish_non_exhaustive() } } pub struct ProgramReflection<'a> { inner: &'a shader_slang::reflection::Shader, } impl<'a> ProgramReflection<'a> { pub fn get_entry_point_by_index(&self, index: u32) -> Option> { Some(EntryPointReflection { inner: self.inner.entry_point_by_index(index)?, }) } pub fn get_entry_point_by_name(&self, name: &str) -> Option> { Some(EntryPointReflection { inner: self.inner.find_entry_point_by_name(name)?, }) } pub fn entry_points(&self) -> impl ExactSizeIterator> { self.inner .entry_points() .map(|entry_point| EntryPointReflection { inner: entry_point }) } pub fn global_params_type_layout(&self) -> Option> { Some(TypeLayout { inner: self.inner.global_params_type_layout()?, }) } pub fn global_params_var_layout(&self) -> Option> { Some(VariableLayout { inner: self.inner.global_params_var_layout()?, }) } pub fn get_type(&self, name: &str) -> Option> { Some(TypeReflection { inner: self.inner.find_type_by_name(name)?, }) } pub fn get_type_layout(&self, ty: &TypeReflection<'a>) -> Option> { Some(TypeLayout { inner: self .inner .type_layout(&ty.inner, shader_slang::LayoutRules::Default)?, }) } } #[derive(Clone, Copy)] pub struct VariableLayout<'a> { inner: &'a shader_slang::reflection::VariableLayout, } impl<'a> VariableLayout<'a> { pub fn name(&self) -> Option<&str> { self.inner.name() } pub fn semantic_name(&self) -> Option<&str> { self.inner.semantic_name() } pub fn binding_index(&self) -> u32 { self.inner .offset(shader_slang::ParameterCategory::DescriptorTableSlot) as u32 // self.inner.binding_index() } pub fn binding_space(&self) -> u32 { self.inner.binding_space() } pub fn semantic_index(&self) -> usize { self.inner.semantic_index() } pub fn offset(&self) -> usize { self.inner.offset(shader_slang::ParameterCategory::Uniform) } pub fn ty(&self) -> Option> { self.inner.ty().map(|ty| TypeReflection { inner: ty }) } pub fn type_layout(&self) -> Option> { Some(TypeLayout { inner: self.inner.type_layout()? }) } } #[derive(Clone, Copy)] pub struct TypeLayout<'a> { inner: &'a shader_slang::reflection::TypeLayout, } impl<'a> TypeLayout<'a> { pub fn get_field_by_name(&self, name: &str) -> Option> { let index = self.inner.find_field_index_by_name(name); if index < 0 { return None; } let index = u32::try_from(index.cast_unsigned()).expect("Should not happend"); let field = self.inner.field_by_index(index)?; Some(VariableLayout { inner: field }) } pub fn binding_range_descriptor_set_index(&self, index: i64) -> i64 { self.inner.binding_range_descriptor_set_index(index) } pub fn get_field_binding_range_offset_by_name(&self, name: &str) -> Option { let field_index = self.inner.find_field_index_by_name(name); if field_index < 0 { return None; } let field_binding_range_offset = self.inner.field_binding_range_offset(field_index); if field_binding_range_offset < 0 { return None; } Some(field_binding_range_offset.cast_unsigned()) } pub fn ty(&self) -> Option> { self.inner.ty().map(|ty| TypeReflection { inner: ty }) } pub fn fields(&self) -> impl ExactSizeIterator> { self.inner .fields() .map(|field| VariableLayout { inner: field }) } pub fn field_cnt(&self) -> u32 { self.inner.field_count() } pub fn element_type_layout(&self) -> Option> { self.inner .element_type_layout() .map(|type_layout| TypeLayout { inner: type_layout }) } pub fn element_var_layout(&self) -> Option> { self.inner .element_var_layout() .map(|var_layout| VariableLayout { inner: var_layout }) } pub fn container_var_layout(&self) -> Option> { self.inner .container_var_layout() .map(|var_layout| VariableLayout { inner: var_layout }) } pub fn uniform_size(&self) -> Option { // tracing::debug!( // "uniform_size: {:?} categories: {:?}", // self.inner.name(), // self.inner.categories().collect::>(), // ); if !self .inner .categories() .any(|category| category == SlangParameterCategory::Uniform) { return None; } // let category = self.inner.categories().next().unwrap(); // println!( // "AARGH size Category: {category:?} Category count: {}", // self.inner.category_count() // ); // Some(self.inner.size(category)) Some(self.inner.size(SlangParameterCategory::Uniform)) } pub fn stride(&self) -> usize { self.inner.stride(self.inner.categories().next().unwrap()) } } pub struct TypeReflection<'a> { inner: &'a shader_slang::reflection::Type, } impl TypeReflection<'_> { pub fn kind(&self) -> TypeKind { TypeKind::from_slang_type_kind(self.inner.kind()) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[non_exhaustive] pub enum TypeKind { None, Struct, Enum, Array, Matrix, Vector, Scalar, ConstantBuffer, Resource, SamplerState, TextureBuffer, ShaderStorageBuffer, ParameterBlock, GenericTypeParameter, Interface, OutputStream, MeshOutput, Specialized, Feedback, Pointer, DynamicResource, Count, } impl TypeKind { fn from_slang_type_kind(type_kind: SlangTypeKind) -> Self { match type_kind { SlangTypeKind::None => Self::None, SlangTypeKind::Struct => Self::Struct, SlangTypeKind::Enum => Self::Enum, SlangTypeKind::Array => Self::Array, SlangTypeKind::Matrix => Self::Matrix, SlangTypeKind::Vector => Self::Vector, SlangTypeKind::Scalar => Self::Scalar, SlangTypeKind::ConstantBuffer => Self::ConstantBuffer, SlangTypeKind::Resource => Self::Resource, SlangTypeKind::SamplerState => Self::SamplerState, SlangTypeKind::TextureBuffer => Self::TextureBuffer, SlangTypeKind::ShaderStorageBuffer => Self::ShaderStorageBuffer, SlangTypeKind::ParameterBlock => Self::ParameterBlock, SlangTypeKind::GenericTypeParameter => Self::GenericTypeParameter, SlangTypeKind::Interface => Self::Interface, SlangTypeKind::OutputStream => Self::OutputStream, SlangTypeKind::MeshOutput => Self::MeshOutput, SlangTypeKind::Specialized => Self::Specialized, SlangTypeKind::Feedback => Self::Feedback, SlangTypeKind::Pointer => Self::Pointer, SlangTypeKind::DynamicResource => Self::DynamicResource, SlangTypeKind::Count => Self::Count, } } } pub struct Blob { inner: SlangBlob, } impl Blob { pub fn as_bytes(&self) -> &[u8] { self.inner.as_slice() } pub fn as_str(&self) -> Result<&str, Utf8Error> { self.inner.as_str() } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[non_exhaustive] pub enum Stage { None, Vertex, Hull, Domain, Geometry, Fragment, Compute, RayGeneration, Intersection, AnyHit, ClosestHit, Miss, Callable, Mesh, Amplification, Dispatch, Count, } impl Stage { fn from_slang_stage(stage: shader_slang::Stage) -> Self { match stage { shader_slang::Stage::None => Self::None, shader_slang::Stage::Vertex => Self::Vertex, shader_slang::Stage::Hull => Self::Hull, shader_slang::Stage::Domain => Self::Domain, shader_slang::Stage::Geometry => Self::Geometry, shader_slang::Stage::Fragment => Self::Fragment, shader_slang::Stage::Compute => Self::Compute, shader_slang::Stage::RayGeneration => Self::RayGeneration, shader_slang::Stage::Intersection => Self::Intersection, shader_slang::Stage::AnyHit => Self::AnyHit, shader_slang::Stage::ClosestHit => Self::ClosestHit, shader_slang::Stage::Miss => Self::Miss, shader_slang::Stage::Callable => Self::Callable, shader_slang::Stage::Mesh => Self::Mesh, shader_slang::Stage::Amplification => Self::Amplification, shader_slang::Stage::Dispatch => Self::Dispatch, shader_slang::Stage::Count => Self::Count, } } } builder! { #[builder(name = SettingsBuilder, derives=(Debug))] #[derive(Debug)] #[non_exhaustive] pub struct Settings { link_entrypoints: EntrypointFlags, } } #[derive(Sole)] pub struct Context { _global_session: SlangGlobalSession, session: SlangSession, modules: HashMap, programs: HashMap, } impl Context { pub fn get_module(&self, asset_id: &AssetId) -> Option<&Module> { self.modules.get(asset_id) } pub fn get_program(&self, asset_id: &AssetId) -> Option<&Program> { self.programs.get(asset_id) } pub fn compose_into_program( &self, modules: &[&Module], entry_points: &[&EntryPoint], ) -> Result { let components = modules .iter() .map(|module| SlangComponentType::from(module.inner.clone())) .chain(entry_points.iter().map(|entry_point| { SlangComponentType::from(entry_point.inner.clone()) })) .collect::>(); let program = self.session.create_composite_component_type(&components)?; Ok(Program { inner: program }) } } #[derive(Debug, thiserror::Error)] #[error(transparent)] pub struct Error(#[from] shader_slang::Error); pub(crate) fn add_asset_importers(assets: &mut Assets) { assets.set_importer::<_, _>(["slang"], import_slang_asset); } fn import_slang_asset( asset_submitter: &mut AssetSubmitter<'_>, file_path: &Path, settings: Option<&'_ Settings>, ) -> Result<(), ImportError> { let file_name = file_path .file_name() .ok_or(ImportError::NoPathFileName)? .to_str() .ok_or(ImportError::PathFileNameNotUtf8)?; let file_path_canonicalized = file_path .canonicalize() .map_err(ImportError::CanonicalizePathFailed)?; asset_submitter.submit_store(ModuleSource { name: file_name.to_owned().into(), file_path: file_path_canonicalized.into(), source: std::fs::read_to_string(file_path) .map_err(ImportError::ReadFileFailed)? .into(), link_entrypoints: settings .map(|settings| settings.link_entrypoints) .unwrap_or_default(), }); Ok(()) } #[derive(Debug, thiserror::Error)] enum ImportError { #[error("Failed to read file")] ReadFileFailed(#[source] std::io::Error), #[error("Asset path does not have a file name")] NoPathFileName, #[error("Asset path file name is not valid UTF8")] PathFileNameNotUtf8, #[error("Failed to canonicalize asset path")] CanonicalizePathFailed(#[source] std::io::Error), } declare_entity!( IMPORT_SHADERS_PHASE, ( Phase, Pair::builder() .relation::() .target_id(*HANDLE_ASSETS_PHASE) .build() ) ); pub(crate) struct Extension; impl ecs::extension::Extension for Extension { fn collect(self, mut collector: ecs::extension::Collector<'_>) { let Some(global_session) = SlangGlobalSession::new() else { tracing::error!("Unable to create global shader-slang session"); return; }; let session_options = shader_slang::CompilerOptions::default() .optimization(shader_slang::OptimizationLevel::None) .matrix_layout_column(true) .debug_information(SlangDebugInfoLevel::Maximal) .no_mangle(true); let target_desc = shader_slang::TargetDesc::default() .format(shader_slang::CompileTarget::Glsl) // .format(shader_slang::CompileTarget::Spirv) .profile(global_session.find_profile("glsl_330")); // .profile(global_session.find_profile("spirv_1_5")); let targets = [target_desc]; let session_desc = shader_slang::SessionDesc::default() .targets(&targets) .search_paths(&[""]) .options(&session_options); let Some(session) = global_session.create_session(&session_desc) else { tracing::error!("Failed to create shader-slang session"); return; }; collector .add_sole(Context { _global_session: global_session, session, modules: HashMap::new(), programs: HashMap::new(), }) .ok(); collector.add_declared_entity(&IMPORT_SHADERS_PHASE); collector.add_system(*START_PHASE, initialize); collector.add_system(*IMPORT_SHADERS_PHASE, load_modules); collector.add_system( *PRE_RENDER_PHASE, default_shader_enqueue_set_shader_bindings, ); } } fn initialize(mut assets: Single) { assets.store_with_label( ASSET_LABEL.clone(), ModuleSource { name: "default_shader.slang".into(), file_path: Path::new("@engine/default_shader").into(), source: include_str!("../res/default_shader.slang").into(), link_entrypoints: EntrypointFlags::VERTEX | EntrypointFlags::FRAGMENT, }, ); } #[tracing::instrument(skip_all)] fn load_modules(mut context: Single, assets: Single) { for AssetEvent::Stored(asset_id, asset_label) in assets.events().last_tick_events() { let asset_handle = AssetHandle::::from_id(*asset_id); if !assets.is_loaded_and_has_type(&asset_handle) { continue; } let Some(module_source) = assets.get(&asset_handle) else { unreachable!(); }; tracing::debug!(asset_label=?asset_label, "Loading shader module"); let module = match load_module(&context.session, module_source) { Ok(module) => module, Err(err) => { tracing::error!("Failed to load shader module: {err}"); continue; } }; if !module_source.link_entrypoints.is_empty() { assert!(context.programs.get(asset_id).is_none()); let Some(vertex_shader_entry_point) = module.get_entry_point("vertex_main") else { tracing::error!( "Shader module does not contain a vertex shader entry point" ); continue; }; let Some(fragment_shader_entry_point) = module.get_entry_point("fragment_main") else { tracing::error!( "Shader module don't contain a fragment_main entry point" ); continue; }; let shader_program = match context.compose_into_program( &[&module], &[&vertex_shader_entry_point, &fragment_shader_entry_point], ) { Ok(shader_program) => shader_program, Err(err) => { tracing::error!("Failed to compose shader into program: {err}"); continue; } }; let linked_shader_program = match shader_program.link() { Ok(linked_shader_program) => linked_shader_program, Err(err) => { tracing::error!("Failed to link shader: {err}"); continue; } }; context.programs.insert(*asset_id, linked_shader_program); } context.modules.insert(*asset_id, module); } } fn load_module( session: &SlangSession, module_source: &ModuleSource, ) -> Result { let module = session.load_module_from_source_string( &module_source.name, &module_source.file_path.to_string_lossy(), &module_source.source, )?; Ok(Module { inner: module }) }