1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
use smallvec::SmallVec;

use crate::{debug_label::DebugLabel, RenderContext};

use super::{
    pipeline_layout_pool::{GpuPipelineLayoutHandle, GpuPipelineLayoutPool},
    resource::PoolError,
    shader_module_pool::{GpuShaderModuleHandle, GpuShaderModulePool},
    static_resource_pool::{StaticResourcePool, StaticResourcePoolReadLockAccessor},
};

// Key type under which `GpuRenderPipelinePool` hands out and resolves render pipelines.
slotmap::new_key_type! { pub struct GpuRenderPipelineHandle; }

/// A copy of [`wgpu::VertexBufferLayout`] with a [`smallvec`] for the attributes.
///
/// Unlike the wgpu type, which only borrows its attribute slice, this struct owns its attributes.
/// Together with the derived `Hash`/`Eq` implementations this allows it to be stored inside of
/// [`RenderPipelineDesc`].
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct VertexBufferLayout {
    /// The stride, in bytes, between elements of this buffer.
    pub array_stride: wgpu::BufferAddress,

    /// How often this vertex buffer is "stepped" forward.
    pub step_mode: wgpu::VertexStepMode,

    /// The list of attributes which comprise a single vertex.
    ///
    /// Each attribute carries its own format, byte offset and shader location.
    pub attributes: SmallVec<[wgpu::VertexAttribute; 8]>,
}

impl VertexBufferLayout {
    /// Generates layouts with successive shader locations without gaps.
    pub fn from_formats(formats: impl Iterator<Item = wgpu::VertexFormat>) -> SmallVec<[Self; 4]> {
        formats
            .enumerate()
            .map(move |(location, format)| Self {
                array_stride: format.size(),
                step_mode: wgpu::VertexStepMode::Vertex,
                attributes: smallvec::smallvec![wgpu::VertexAttribute {
                    format,
                    offset: 0,
                    shader_location: location as u32,
                }],
            })
            .collect()
    }

    /// Generates attributes with successive shader locations without gaps
    pub fn attributes_from_formats(
        start_location: u32,
        formats: impl Iterator<Item = wgpu::VertexFormat>,
    ) -> SmallVec<[wgpu::VertexAttribute; 8]> {
        let mut offset = 0;
        formats
            .enumerate()
            .map(move |(location, format)| {
                let attribute = wgpu::VertexAttribute {
                    format,
                    offset,
                    shader_location: start_location + location as u32,
                };
                offset += format.size();
                attribute
            })
            .collect()
    }
}

impl VertexBufferLayout {
    /// Creates the equivalent [`wgpu::VertexBufferLayout`].
    ///
    /// Stride and step mode are copied, the attributes are borrowed from `self`.
    fn to_wgpu_desc(&self) -> wgpu::VertexBufferLayout<'_> {
        let Self {
            array_stride,
            step_mode,
            attributes,
        } = self;

        wgpu::VertexBufferLayout {
            array_stride: *array_stride,
            step_mode: *step_mode,
            attributes: attributes.as_slice(),
        }
    }
}

/// Render pipeline descriptor from which a [`wgpu::RenderPipeline`] can be created.
///
/// Unlike [`wgpu::RenderPipeline`] (which isn't hashable or comparable), this type derives
/// `Hash` and `Eq`. Shaders and pipeline layout are referenced via pool handles rather than
/// by the wgpu objects themselves.
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct RenderPipelineDesc {
    /// Debug label of the pipeline. This will show up in graphics debuggers for easy identification.
    pub label: DebugLabel,

    /// Handle of the pipeline layout, resolved via the pipeline layout pool on pipeline creation.
    pub pipeline_layout: GpuPipelineLayoutHandle,

    /// Name of the vertex shader entry point within the module referenced by `vertex_handle`.
    pub vertex_entrypoint: String,
    /// Handle of the shader module that contains the vertex shader.
    pub vertex_handle: GpuShaderModuleHandle,
    /// Name of the fragment shader entry point within the module referenced by `fragment_handle`.
    pub fragment_entrypoint: String,
    /// Handle of the shader module that contains the fragment shader.
    pub fragment_handle: GpuShaderModuleHandle,

    /// The format of any vertex buffers used with this pipeline.
    pub vertex_buffers: SmallVec<[VertexBufferLayout; 4]>,

    /// The color state of the render targets.
    pub render_targets: SmallVec<[Option<wgpu::ColorTargetState>; 4]>,

    /// The properties of the pipeline at the primitive assembly and rasterization level.
    pub primitive: wgpu::PrimitiveState,

    /// The effect of draw calls on the depth and stencil aspects of the output target, if any.
    pub depth_stencil: Option<wgpu::DepthStencilState>,

    /// The multi-sampling properties of the pipeline.
    pub multisample: wgpu::MultisampleState,
}

/// Reasons for which a render pipeline could not be created from a [`RenderPipelineDesc`].
///
/// Every variant wraps the [`PoolError`] that was reported while resolving one of the handles
/// referenced by the descriptor.
#[derive(thiserror::Error, Debug)]
pub enum RenderPipelineCreationError {
    /// Resolving [`RenderPipelineDesc::pipeline_layout`] failed.
    #[error("Referenced pipeline layout not found: {0}")]
    PipelineLayout(PoolError),

    /// Resolving [`RenderPipelineDesc::vertex_handle`] failed.
    #[error("Referenced vertex shader not found: {0}")]
    VertexShaderNotFound(PoolError),

    /// Resolving [`RenderPipelineDesc::fragment_handle`] failed.
    #[error("Referenced fragment shader not found: {0}")]
    FragmentShaderNotFound(PoolError),
}

impl RenderPipelineDesc {
    fn create_render_pipeline(
        &self,
        device: &wgpu::Device,
        pipeline_layouts: &GpuPipelineLayoutPool,
        shader_modules: &GpuShaderModulePool,
    ) -> Result<wgpu::RenderPipeline, RenderPipelineCreationError> {
        let pipeline_layouts = pipeline_layouts.resources();
        let pipeline_layout = pipeline_layouts
            .get(self.pipeline_layout)
            .map_err(RenderPipelineCreationError::PipelineLayout)?;

        let shader_modules = shader_modules.resources();
        let vertex_shader_module = shader_modules
            .get(self.vertex_handle)
            .map_err(RenderPipelineCreationError::VertexShaderNotFound)?;

        let fragment_shader_module = shader_modules
            .get(self.fragment_handle)
            .map_err(RenderPipelineCreationError::FragmentShaderNotFound)?;

        let buffers = self
            .vertex_buffers
            .iter()
            .map(|b| b.to_wgpu_desc())
            .collect::<Vec<_>>();

        Ok(
            device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
                label: self.label.get(),
                layout: Some(pipeline_layout),
                vertex: wgpu::VertexState {
                    module: vertex_shader_module,
                    entry_point: &self.vertex_entrypoint,
                    buffers: &buffers,
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                },
                fragment: wgpu::FragmentState {
                    module: fragment_shader_module,
                    entry_point: &self.fragment_entrypoint,
                    targets: &self.render_targets,
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                }
                .into(),
                primitive: self.primitive,
                depth_stencil: self.depth_stencil.clone(),
                multisample: self.multisample,
                multiview: None, // Multi-layered render target support isn't widespread
                cache: None,
            }),
        )
    }
}

/// Read-lock accessor with which a [`GpuRenderPipelineHandle`] is resolved to its
/// [`wgpu::RenderPipeline`].
pub type GpuRenderPipelinePoolAccessor<'a> =
    StaticResourcePoolReadLockAccessor<'a, GpuRenderPipelineHandle, wgpu::RenderPipeline>;

/// Pool of [`wgpu::RenderPipeline`]s that are created from [`RenderPipelineDesc`]s and
/// referred to via [`GpuRenderPipelineHandle`]s.
#[derive(Default)]
pub struct GpuRenderPipelinePool {
    // Generic pool that does the actual bookkeeping, this type only adds pipeline specifics.
    pool: StaticResourcePool<GpuRenderPipelineHandle, RenderPipelineDesc, wgpu::RenderPipeline>,
}

impl GpuRenderPipelinePool {
    pub fn get_or_create(
        &self,
        ctx: &RenderContext,
        desc: &RenderPipelineDesc,
    ) -> GpuRenderPipelineHandle {
        self.pool.get_or_create(desc, |desc| {
            sanity_check_vertex_buffers(&desc.vertex_buffers);

            // TODO(cmc): certainly not unwrapping here
            desc.create_render_pipeline(
                &ctx.device,
                &ctx.gpu_resources.pipeline_layouts,
                &ctx.gpu_resources.shader_modules,
            )
            .unwrap()
        })
    }

    pub fn begin_frame(
        &mut self,
        device: &wgpu::Device,
        frame_index: u64,
        shader_modules: &GpuShaderModulePool,
        pipeline_layouts: &GpuPipelineLayoutPool,
    ) {
        re_tracing::profile_function!();
        self.pool.current_frame_index = frame_index;

        // Recompile render pipelines referencing shader modules that have been recompiled this frame.
        self.pool.recreate_resources(|desc| {
            let frame_created = {
                let shader_modules = shader_modules.resources();
                let vertex_created = shader_modules
                    .get_statistics(desc.vertex_handle)
                    .map(|sm| sm.frame_created)
                    .unwrap_or(0);
                let fragment_created = shader_modules
                    .get_statistics(desc.fragment_handle)
                    .map(|sm| sm.frame_created)
                    .unwrap_or(0);
                u64::max(vertex_created, fragment_created)
            };
            // The frame counter just got bumped by one. So any shader that has `frame_created`,
            // equal the current frame now, must have been recompiled since the user didn't have a
            // chance yet to add new shaders for this frame!
            // (note that this assumes that shader `begin_frame` happens before pipeline `begin_frame`)
            if frame_created < frame_index {
                return None;
            }

            match desc.create_render_pipeline(device, pipeline_layouts, shader_modules) {
                Ok(sm) => {
                    // We don't know yet if this actually succeeded.
                    // But it's good to get feedback to the user that _something_ happened!
                    re_log::info!(label = desc.label.get(), "recompiled render pipeline");
                    Some(sm)
                }
                Err(err) => {
                    re_log::error!("Failed to compile render pipeline: {}", err);
                    None
                }
            }
        });
    }

    /// Locks the resource pool for resolving handles.
    ///
    /// While it is locked, no new resources can be added.
    pub fn resources(
        &self,
    ) -> StaticResourcePoolReadLockAccessor<'_, GpuRenderPipelineHandle, wgpu::RenderPipeline> {
        self.pool.resources()
    }

    pub fn num_resources(&self) -> usize {
        self.pool.num_resources()
    }
}

/// Checks that the shader locations used across all `buffers` are exactly `0..n`, with `n`
/// being the total number of attributes. An empty list of buffers passes trivially.
///
/// # Panics
///
/// Panics if a shader location is used by more than one attribute, or if any location in
/// `0..n` is not used at all.
fn sanity_check_vertex_buffers(buffers: &[VertexBufferLayout]) {
    let mut seen_locations = std::collections::BTreeSet::<u32>::new();
    let mut attribute_count: u32 = 0;

    let all_attributes = buffers.iter().flat_map(|buffer| buffer.attributes.iter());
    for attribute in all_attributes {
        attribute_count += 1;
        assert!(
            seen_locations.insert(attribute.shader_location),
            "Duplicate shader location {} in vertex buffers",
            attribute.shader_location
        );
    }

    // Gaps in the shader locations are technically allowed, but weird.
    let first_gap = (0..attribute_count).find(|location| !seen_locations.contains(location));
    if let Some(i) = first_gap {
        panic!("Missing shader location {i} in vertex buffers");
    }
}