use std::{hash::Hash, path::PathBuf};
use ahash::HashSet;
use anyhow::Context as _;
use crate::{debug_label::DebugLabel, FileResolver, FileSystem, RenderContext};
use super::static_resource_pool::{StaticResourcePool, StaticResourcePoolReadLockAccessor};
// Handle identifying a shader module stored in a `GpuShaderModulePool`.
slotmap::new_key_type! { pub struct GpuShaderModuleHandle; }
/// Environment variable that, when set, names a directory into which the final WGSL text of each
/// shader module (after import resolution and text replacements) is written as it is created.
const RERUN_WGSL_SHADER_DUMP_PATH: &str = "RERUN_WGSL_SHADER_DUMP_PATH";
/// Builds a `ShaderModuleDesc` for a shader file embedded via `include_file!`.
///
/// The debug label is derived from the path expression: its surrounding quotes are removed,
/// followed by a leading `../../shader/`. If the path does not start with that prefix, `None` is
/// passed to `DebugLabel::from`.
#[macro_export]
macro_rules! include_shader_module {
    ($path:expr $(,)?) => {{
        $crate::wgpu_resources::ShaderModuleDesc {
            // `stringify!` on a string literal keeps the literal's `"` delimiters, so they must be
            // trimmed, otherwise `strip_prefix` can never match and the label is always `None`.
            label: $crate::DebugLabel::from(
                stringify!($path)
                    .trim_matches('"')
                    .strip_prefix("../../shader/"),
            ),
            source: $crate::include_file!($path),
            extra_workaround_replacements: Vec::new(),
        }
    }};
}
/// Describes a WGSL shader module: where its source lives and which extra text replacements to
/// apply before compilation.
///
/// `PartialEq` and `Hash` are implemented by hand below; `label` takes part in neither.
#[derive(Clone, Eq, Debug)]
pub struct ShaderModuleDesc {
    /// Label passed to wgpu when the module is created (also included in recompile log messages).
    pub label: DebugLabel,
    /// Path of the shader source, resolved (together with its imports) through a `FileResolver`.
    pub source: PathBuf,
    /// Extra `(from, to)` text replacements applied to the resolved source, after the pool-wide
    /// workaround replacements.
    pub extra_workaround_replacements: Vec<(String, String)>,
}
impl PartialEq for ShaderModuleDesc {
    /// Two descriptors are equal when they would produce the same shader module: same `source`
    /// path and same `extra_workaround_replacements`. `label` is not compared.
    ///
    /// This must compare exactly the fields that the `Hash` impl hashes. `Hash`/`Eq` require
    /// `a == b ⇒ hash(a) == hash(b)`. Comparing only `source` while hashing the replacements
    /// too would let a hash-map lookup treat two different descriptors as the same key.
    fn eq(&self, rhs: &Self) -> bool {
        self.source == rhs.source
            && self.extra_workaround_replacements == rhs.extra_workaround_replacements
    }
}
impl Hash for ShaderModuleDesc {
    /// Feeds `source`, then `extra_workaround_replacements`, into the hasher; `label` is skipped.
    fn hash<S: std::hash::Hasher>(&self, hasher: &mut S) {
        let Self {
            source,
            extra_workaround_replacements,
            ..
        } = self;
        source.hash(hasher);
        extra_workaround_replacements.hash(hasher);
    }
}
impl ShaderModuleDesc {
fn create_shader_module<Fs: FileSystem>(
&self,
device: &wgpu::Device,
resolver: &FileResolver<Fs>,
shader_text_workaround_replacements: &[(String, String)],
) -> wgpu::ShaderModule {
let mut source_interpolated = resolver
.populate(&self.source)
.context("couldn't resolve shader module's contents")
.map_err(|err| re_log::error!(err=%re_error::format(err)))
.unwrap_or_default();
for (from, to) in shader_text_workaround_replacements
.iter()
.chain(self.extra_workaround_replacements.iter())
{
source_interpolated.contents = source_interpolated.contents.replace(from, to);
}
if let Ok(wgsl_dump_dir) = std::env::var(RERUN_WGSL_SHADER_DUMP_PATH) {
let mut path = PathBuf::from(wgsl_dump_dir);
std::fs::create_dir_all(&path).unwrap();
let mut wgsl_filename = self.source.to_str().unwrap().replace(['/', '\\'], "_");
if let Some(position) = wgsl_filename.find("re_renderer_shader_") {
wgsl_filename = wgsl_filename[position + "re_renderer_shader_".len()..].to_owned();
}
path.push(&wgsl_filename);
std::fs::write(&path, &source_interpolated.contents).unwrap();
}
device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: self.label.get(),
source: wgpu::ShaderSource::Wgsl(source_interpolated.contents.into()),
})
}
}
/// Pool of compiled shader modules keyed by [`ShaderModuleDesc`].
///
/// Modules are created on request and recompiled in `begin_frame` when their source file, or a
/// file it imports, appears in the set of updated paths.
#[derive(Default)]
pub struct GpuShaderModulePool {
    /// `(from, to)` text replacements applied to every shader's source before compilation,
    /// ahead of each descriptor's own `extra_workaround_replacements`.
    pub shader_text_workaround_replacements: Vec<(String, String)>,

    /// Underlying storage mapping descriptors to handles and compiled modules.
    pool: StaticResourcePool<GpuShaderModuleHandle, ShaderModuleDesc, wgpu::ShaderModule>,
}
impl GpuShaderModulePool {
    /// Returns a handle for the shader module described by `desc`.
    ///
    /// A creation closure is handed to the underlying pool. It compiles the module with the
    /// pool-wide `shader_text_workaround_replacements`. Presumably it only runs when no equal
    /// descriptor is cached yet — see `StaticResourcePool::get_or_create`.
    pub fn get_or_create(
        &self,
        ctx: &RenderContext,
        desc: &ShaderModuleDesc,
    ) -> GpuShaderModuleHandle {
        let replacements = &self.shader_text_workaround_replacements;
        self.pool.get_or_create(desc, |new_desc| {
            new_desc.create_shader_module(&ctx.device, &ctx.resolver, replacements)
        })
    }

    /// Records `frame_index` on the underlying pool, then recompiles every module that is stale.
    ///
    /// A module is stale when its source file, or any import reported by `resolver.populate`, is
    /// contained in `updated_paths`. With an empty `updated_paths`, nothing is resolved or
    /// recompiled.
    pub fn begin_frame<Fs: FileSystem>(
        &mut self,
        device: &wgpu::Device,
        resolver: &FileResolver<Fs>,
        frame_index: u64,
        updated_paths: &HashSet<PathBuf>,
    ) {
        self.pool.current_frame_index = frame_index;

        // No file changed on disk, so no module can be stale.
        if updated_paths.is_empty() {
            return;
        }

        let replacements = &self.shader_text_workaround_replacements;
        self.pool.recreate_resources(|desc| {
            let is_stale = {
                // Resolve the source again to learn which files it currently imports.
                let mut imports = resolver
                    .populate(&desc.source)
                    .into_iter()
                    .flat_map(|interpolated| interpolated.imports);
                updated_paths.contains(&desc.source)
                    || imports.any(|import| updated_paths.contains(&import))
            };
            if !is_stale {
                return None;
            }

            let recompiled = desc.create_shader_module(device, resolver, replacements);
            re_log::debug!(?desc.source, label = desc.label.get(), "recompiled shader module");
            Some(recompiled)
        });
    }

    /// Returns the pool's `StaticResourcePoolReadLockAccessor` over all stored shader modules.
    pub fn resources(
        &self,
    ) -> StaticResourcePoolReadLockAccessor<'_, GpuShaderModuleHandle, wgpu::ShaderModule> {
        self.pool.resources()
    }

    /// Number of shader modules currently held by the underlying pool.
    pub fn num_resources(&self) -> usize {
        self.pool.num_resources()
    }
}