use smallvec::SmallVec;
use crate::{debug_label::DebugLabel, RenderContext};
use super::{
pipeline_layout_pool::{GpuPipelineLayoutHandle, GpuPipelineLayoutPool},
resource::PoolError,
shader_module_pool::{GpuShaderModuleHandle, GpuShaderModulePool},
static_resource_pool::{StaticResourcePool, StaticResourcePoolReadLockAccessor},
};
// Slotmap key type identifying a render pipeline stored in `GpuRenderPipelinePool`;
// this is the handle type returned by `GpuRenderPipelinePool::get_or_create`.
slotmap::new_key_type! { pub struct GpuRenderPipelineHandle; }
/// Owned, hashable description of a single vertex buffer's layout.
///
/// Mirrors `wgpu::VertexBufferLayout`, but owns its attribute list so it can be stored in a
/// [`RenderPipelineDesc`] (which derives `Hash`/`Eq`). It is converted back to the borrowed
/// `wgpu` form right before pipeline creation.
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct VertexBufferLayout {
    /// Byte stride between consecutive elements in this buffer.
    pub array_stride: wgpu::BufferAddress,

    /// How the buffer is stepped through (e.g. `wgpu::VertexStepMode::Vertex`).
    pub step_mode: wgpu::VertexStepMode,

    /// Attributes read from this buffer, each with its own shader location and byte offset.
    pub attributes: SmallVec<[wgpu::VertexAttribute; 8]>,
}
impl VertexBufferLayout {
pub fn from_formats(formats: impl Iterator<Item = wgpu::VertexFormat>) -> SmallVec<[Self; 4]> {
formats
.enumerate()
.map(move |(location, format)| Self {
array_stride: format.size(),
step_mode: wgpu::VertexStepMode::Vertex,
attributes: smallvec::smallvec![wgpu::VertexAttribute {
format,
offset: 0,
shader_location: location as u32,
}],
})
.collect()
}
pub fn attributes_from_formats(
start_location: u32,
formats: impl Iterator<Item = wgpu::VertexFormat>,
) -> SmallVec<[wgpu::VertexAttribute; 8]> {
let mut offset = 0;
formats
.enumerate()
.map(move |(location, format)| {
let attribute = wgpu::VertexAttribute {
format,
offset,
shader_location: start_location + location as u32,
};
offset += format.size();
attribute
})
.collect()
}
}
impl VertexBufferLayout {
    /// Borrows this layout as the `wgpu` descriptor type consumed by pipeline creation.
    ///
    /// The returned descriptor references `self.attributes` directly; nothing is copied.
    fn to_wgpu_desc(&self) -> wgpu::VertexBufferLayout<'_> {
        let Self {
            array_stride,
            step_mode,
            attributes,
        } = self;

        wgpu::VertexBufferLayout {
            array_stride: *array_stride,
            step_mode: *step_mode,
            attributes: attributes.as_slice(),
        }
    }
}
/// Everything needed to (re)create a `wgpu::RenderPipeline`.
///
/// Shader modules and the pipeline layout are referenced by pool handles rather than by value.
/// `Hash`/`Eq` are derived so a description can be handed to
/// `GpuRenderPipelinePool::get_or_create` (NOTE(review): presumably used there as the
/// deduplication key — confirm in `StaticResourcePool`).
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct RenderPipelineDesc {
    /// Debug label forwarded to the created `wgpu` pipeline.
    pub label: DebugLabel,

    /// Handle of the pipeline layout, resolved through `GpuPipelineLayoutPool` at creation.
    pub pipeline_layout: GpuPipelineLayoutHandle,

    /// Name of the vertex entry point inside the module referenced by `vertex_handle`.
    pub vertex_entrypoint: String,

    /// Handle of the shader module containing the vertex entry point.
    pub vertex_handle: GpuShaderModuleHandle,

    /// Name of the fragment entry point inside the module referenced by `fragment_handle`.
    pub fragment_entrypoint: String,

    /// Handle of the shader module containing the fragment entry point.
    pub fragment_handle: GpuShaderModuleHandle,

    /// Vertex buffer layouts; passed as `VertexState::buffers` after conversion.
    pub vertex_buffers: SmallVec<[VertexBufferLayout; 4]>,

    /// Color targets; passed unchanged as `FragmentState::targets`.
    pub render_targets: SmallVec<[Option<wgpu::ColorTargetState>; 4]>,

    /// Primitive assembly state, forwarded unchanged.
    pub primitive: wgpu::PrimitiveState,

    /// Optional depth/stencil state, forwarded unchanged.
    pub depth_stencil: Option<wgpu::DepthStencilState>,

    /// Multisample state, forwarded unchanged.
    pub multisample: wgpu::MultisampleState,
}
/// Reasons `RenderPipelineDesc::create_render_pipeline` can fail before reaching `wgpu`.
///
/// Every variant wraps the `PoolError` returned when a handle lookup failed.
// NOTE(review): `PipelineLayout` lacks the `NotFound` suffix of its siblings; renaming would
// break existing matches, so it is left as is.
#[derive(thiserror::Error, Debug)]
pub enum RenderPipelineCreationError {
    /// `RenderPipelineDesc::pipeline_layout` could not be resolved.
    #[error("Referenced pipeline layout not found: {0}")]
    PipelineLayout(PoolError),

    /// `RenderPipelineDesc::vertex_handle` could not be resolved.
    #[error("Referenced vertex shader not found: {0}")]
    VertexShaderNotFound(PoolError),

    /// `RenderPipelineDesc::fragment_handle` could not be resolved.
    #[error("Referenced fragment shader not found: {0}")]
    FragmentShaderNotFound(PoolError),
}
impl RenderPipelineDesc {
    /// Resolves all handles in `self` and creates the corresponding `wgpu::RenderPipeline`.
    ///
    /// The pipeline layout is looked up first, then the vertex shader module, then the fragment
    /// shader module. The first handle that cannot be resolved is reported as the matching
    /// [`RenderPipelineCreationError`] variant.
    ///
    /// `device.create_render_pipeline` itself does not return a `Result`, so `wgpu` validation
    /// problems are not reported through this function's error type.
    fn create_render_pipeline(
        &self,
        device: &wgpu::Device,
        pipeline_layouts: &GpuPipelineLayoutPool,
        shader_modules: &GpuShaderModulePool,
    ) -> Result<wgpu::RenderPipeline, RenderPipelineCreationError> {
        // Both accessors stay alive until the pipeline has been created, because the descriptor
        // borrows the layout and modules out of them.
        let layout_resources = pipeline_layouts.resources();
        let layout = layout_resources
            .get(self.pipeline_layout)
            .map_err(RenderPipelineCreationError::PipelineLayout)?;

        let shader_resources = shader_modules.resources();
        let vertex_module = shader_resources
            .get(self.vertex_handle)
            .map_err(RenderPipelineCreationError::VertexShaderNotFound)?;
        let fragment_module = shader_resources
            .get(self.fragment_handle)
            .map_err(RenderPipelineCreationError::FragmentShaderNotFound)?;

        // `wgpu` wants a slice of its own borrowed layout type.
        let wgpu_vertex_buffers: Vec<_> = self
            .vertex_buffers
            .iter()
            .map(|layout| layout.to_wgpu_desc())
            .collect();

        let vertex = wgpu::VertexState {
            module: vertex_module,
            entry_point: &self.vertex_entrypoint,
            buffers: &wgpu_vertex_buffers,
            compilation_options: wgpu::PipelineCompilationOptions::default(),
        };
        let fragment = wgpu::FragmentState {
            module: fragment_module,
            entry_point: &self.fragment_entrypoint,
            targets: &self.render_targets,
            compilation_options: wgpu::PipelineCompilationOptions::default(),
        };

        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
            label: self.label.get(),
            layout: Some(layout),
            vertex,
            fragment: Some(fragment),
            primitive: self.primitive,
            depth_stencil: self.depth_stencil.clone(),
            multisample: self.multisample,
            multiview: None,
            cache: None,
        });
        Ok(pipeline)
    }
}
/// Read-lock accessor over all render pipelines in a [`GpuRenderPipelinePool`], keyed by
/// [`GpuRenderPipelineHandle`].
pub type GpuRenderPipelinePoolAccessor<'a> =
StaticResourcePoolReadLockAccessor<'a, GpuRenderPipelineHandle, wgpu::RenderPipeline>;
/// Pool of render pipelines created from [`RenderPipelineDesc`]s.
///
/// This is a thin wrapper around a `StaticResourcePool`. It adds vertex-buffer sanity checks on
/// creation and rebuilds pipelines whose shader modules were recompiled (see `begin_frame`).
#[derive(Default)]
pub struct GpuRenderPipelinePool {
    // Maps descriptions to pipelines; handles index into this pool.
    pool: StaticResourcePool<GpuRenderPipelineHandle, RenderPipelineDesc, wgpu::RenderPipeline>,
}
impl GpuRenderPipelinePool {
pub fn get_or_create(
&self,
ctx: &RenderContext,
desc: &RenderPipelineDesc,
) -> GpuRenderPipelineHandle {
self.pool.get_or_create(desc, |desc| {
sanity_check_vertex_buffers(&desc.vertex_buffers);
desc.create_render_pipeline(
&ctx.device,
&ctx.gpu_resources.pipeline_layouts,
&ctx.gpu_resources.shader_modules,
)
.unwrap()
})
}
pub fn begin_frame(
&mut self,
device: &wgpu::Device,
frame_index: u64,
shader_modules: &GpuShaderModulePool,
pipeline_layouts: &GpuPipelineLayoutPool,
) {
re_tracing::profile_function!();
self.pool.current_frame_index = frame_index;
self.pool.recreate_resources(|desc| {
let frame_created = {
let shader_modules = shader_modules.resources();
let vertex_created = shader_modules
.get_statistics(desc.vertex_handle)
.map(|sm| sm.frame_created)
.unwrap_or(0);
let fragment_created = shader_modules
.get_statistics(desc.fragment_handle)
.map(|sm| sm.frame_created)
.unwrap_or(0);
u64::max(vertex_created, fragment_created)
};
if frame_created < frame_index {
return None;
}
match desc.create_render_pipeline(device, pipeline_layouts, shader_modules) {
Ok(sm) => {
re_log::info!(label = desc.label.get(), "recompiled render pipeline");
Some(sm)
}
Err(err) => {
re_log::error!("Failed to compile render pipeline: {}", err);
None
}
}
});
}
pub fn resources(
&self,
) -> StaticResourcePoolReadLockAccessor<'_, GpuRenderPipelineHandle, wgpu::RenderPipeline> {
self.pool.resources()
}
pub fn num_resources(&self) -> usize {
self.pool.num_resources()
}
}
/// Asserts that the attributes across all `buffers` use unique, gap-free shader locations.
///
/// With `n` attributes in total, the shader locations must be exactly `0..n`, in any order and
/// spread across any of the buffers.
///
/// # Panics
///
/// * On the first attribute (in buffer order) whose location was already used.
/// * Otherwise, on the lowest location in `0..n` that no attribute uses.
fn sanity_check_vertex_buffers(buffers: &[VertexBufferLayout]) {
    if buffers.is_empty() {
        return;
    }

    let mut used_locations = std::collections::BTreeSet::<u32>::new();
    let mut attribute_count: u32 = 0;

    let all_attributes = buffers.iter().flat_map(|buffer| buffer.attributes.iter());
    for attribute in all_attributes {
        attribute_count += 1;
        let location = attribute.shader_location;
        let first_use = used_locations.insert(location);
        assert!(
            first_use,
            "Duplicate shader location {} in vertex buffers",
            location
        );
    }

    let first_gap = (0..attribute_count).find(|location| !used_locations.contains(location));
    if let Some(i) = first_gap {
        panic!("Missing shader location {i} in vertex buffers");
    }
}