Skip to main content

mtl_gpu/pipeline/
render_state.rs

1//! Render pipeline state.
2//!
3//! Corresponds to `MTL::RenderPipelineState`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use crate::enums::{RenderStages, ShaderValidation};
12use crate::types::{ResourceID, Size};
13
14use super::RenderPipelineFunctionsDescriptor;
15use super::RenderPipelineReflection;
16
/// Owning handle to a compiled render pipeline configuration.
///
/// Wraps a non-null pointer to the underlying Objective-C object. Because of
/// `#[repr(transparent)]` the handle has exactly the layout of the wrapped
/// pointer, and `Option<RenderPipelineState>` stays pointer-sized too.
///
/// C++ equivalent: `MTL::RenderPipelineState`
#[repr(transparent)]
pub struct RenderPipelineState(pub(crate) std::ptr::NonNull<std::ffi::c_void>);
22
23impl RenderPipelineState {
24    /// Create a RenderPipelineState from a raw pointer.
25    ///
26    /// # Safety
27    ///
28    /// The pointer must be a valid Metal render pipeline state object.
29    #[inline]
30    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
31        NonNull::new(ptr).map(Self)
32    }
33
34    /// Get the raw pointer.
35    #[inline]
36    pub fn as_raw(&self) -> *mut c_void {
37        self.0.as_ptr()
38    }
39
40    // =========================================================================
41    // Properties
42    // =========================================================================
43
44    /// Get the label.
45    ///
46    /// C++ equivalent: `NS::String* label() const`
47    pub fn label(&self) -> Option<String> {
48        unsafe {
49            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
50            if ptr.is_null() {
51                return None;
52            }
53            let utf8_ptr: *const std::ffi::c_char =
54                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
55            if utf8_ptr.is_null() {
56                return None;
57            }
58            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
59            Some(c_str.to_string_lossy().into_owned())
60        }
61    }
62
63    /// Get the device.
64    ///
65    /// C++ equivalent: `Device* device() const`
66    pub fn device(&self) -> crate::Device {
67        unsafe {
68            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
69            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
70            crate::Device::from_raw(ptr).expect("pipeline state has no device")
71        }
72    }
73
74    /// Get the maximum total threads per threadgroup.
75    ///
76    /// C++ equivalent: `NS::UInteger maxTotalThreadsPerThreadgroup() const`
77    #[inline]
78    pub fn max_total_threads_per_threadgroup(&self) -> UInteger {
79        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerThreadgroup)) }
80    }
81
82    /// Check if threadgroup size is multiple of thread execution width.
83    ///
84    /// C++ equivalent: `bool threadgroupSizeMatchesTileSize() const`
85    #[inline]
86    pub fn threadgroup_size_matches_tile_size(&self) -> bool {
87        unsafe { msg_send_0(self.as_ptr(), sel!(threadgroupSizeMatchesTileSize)) }
88    }
89
90    /// Get the imageblock sample length.
91    ///
92    /// C++ equivalent: `NS::UInteger imageblockSampleLength() const`
93    #[inline]
94    pub fn imageblock_sample_length(&self) -> UInteger {
95        unsafe { msg_send_0(self.as_ptr(), sel!(imageblockSampleLength)) }
96    }
97
98    /// Check if alpha-to-coverage is supported.
99    ///
100    /// C++ equivalent: `bool supportIndirectCommandBuffers() const`
101    #[inline]
102    pub fn support_indirect_command_buffers(&self) -> bool {
103        unsafe { msg_send_0(self.as_ptr(), sel!(supportIndirectCommandBuffers)) }
104    }
105
106    /// Get the GPU resource ID for bindless access.
107    ///
108    /// C++ equivalent: `ResourceID gpuResourceID() const`
109    #[inline]
110    pub fn gpu_resource_id(&self) -> ResourceID {
111        unsafe { msg_send_0(self.as_ptr(), sel!(gpuResourceID)) }
112    }
113
114    /// Get the shader validation mode.
115    ///
116    /// C++ equivalent: `ShaderValidation shaderValidation() const`
117    #[inline]
118    pub fn shader_validation(&self) -> ShaderValidation {
119        unsafe { msg_send_0(self.as_ptr(), sel!(shaderValidation)) }
120    }
121
122    /// Get the imageblock memory length for given imageblock dimensions.
123    ///
124    /// C++ equivalent: `NS::UInteger imageblockMemoryLength(MTL::Size imageblockDimensions)`
125    #[inline]
126    pub fn imageblock_memory_length(&self, dimensions: Size) -> UInteger {
127        unsafe {
128            msg_send_1(
129                self.as_ptr(),
130                sel!(imageblockMemoryLengthForDimensions:),
131                dimensions,
132            )
133        }
134    }
135
136    // =========================================================================
137    // Mesh/Object Shader Properties
138    // =========================================================================
139
140    /// Get the maximum total threadgroups per mesh grid.
141    ///
142    /// C++ equivalent: `NS::UInteger maxTotalThreadgroupsPerMeshGrid() const`
143    #[inline]
144    pub fn max_total_threadgroups_per_mesh_grid(&self) -> UInteger {
145        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadgroupsPerMeshGrid)) }
146    }
147
148    /// Get the maximum total threads per mesh threadgroup.
149    ///
150    /// C++ equivalent: `NS::UInteger maxTotalThreadsPerMeshThreadgroup() const`
151    #[inline]
152    pub fn max_total_threads_per_mesh_threadgroup(&self) -> UInteger {
153        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerMeshThreadgroup)) }
154    }
155
156    /// Get the maximum total threads per object threadgroup.
157    ///
158    /// C++ equivalent: `NS::UInteger maxTotalThreadsPerObjectThreadgroup() const`
159    #[inline]
160    pub fn max_total_threads_per_object_threadgroup(&self) -> UInteger {
161        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerObjectThreadgroup)) }
162    }
163
164    /// Get the mesh thread execution width.
165    ///
166    /// C++ equivalent: `NS::UInteger meshThreadExecutionWidth() const`
167    #[inline]
168    pub fn mesh_thread_execution_width(&self) -> UInteger {
169        unsafe { msg_send_0(self.as_ptr(), sel!(meshThreadExecutionWidth)) }
170    }
171
172    /// Get the object thread execution width.
173    ///
174    /// C++ equivalent: `NS::UInteger objectThreadExecutionWidth() const`
175    #[inline]
176    pub fn object_thread_execution_width(&self) -> UInteger {
177        unsafe { msg_send_0(self.as_ptr(), sel!(objectThreadExecutionWidth)) }
178    }
179
180    /// Get the required threads per mesh threadgroup.
181    ///
182    /// C++ equivalent: `Size requiredThreadsPerMeshThreadgroup() const`
183    #[inline]
184    pub fn required_threads_per_mesh_threadgroup(&self) -> Size {
185        unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerMeshThreadgroup)) }
186    }
187
188    /// Get the required threads per object threadgroup.
189    ///
190    /// C++ equivalent: `Size requiredThreadsPerObjectThreadgroup() const`
191    #[inline]
192    pub fn required_threads_per_object_threadgroup(&self) -> Size {
193        unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerObjectThreadgroup)) }
194    }
195
196    /// Get the required threads per tile threadgroup.
197    ///
198    /// C++ equivalent: `Size requiredThreadsPerTileThreadgroup() const`
199    #[inline]
200    pub fn required_threads_per_tile_threadgroup(&self) -> Size {
201        unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerTileThreadgroup)) }
202    }
203
204    // =========================================================================
205    // Function Handles
206    // =========================================================================
207
208    /// Get a function handle by name and stage.
209    ///
210    /// C++ equivalent: `FunctionHandle* functionHandle(const NS::String* name, MTL::RenderStages stage)`
211    pub fn function_handle_with_name(
212        &self,
213        name: &str,
214        stage: RenderStages,
215    ) -> Option<crate::FunctionHandle> {
216        let ns_name = mtl_foundation::String::from_str(name)?;
217        unsafe {
218            let ptr: *mut c_void = mtl_sys::msg_send_2(
219                self.as_ptr(),
220                sel!(functionHandleWithFunction: stage:),
221                ns_name.as_ptr(),
222                stage,
223            );
224            if ptr.is_null() {
225                return None;
226            }
227            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
228            crate::FunctionHandle::from_raw(ptr)
229        }
230    }
231
232    /// Get a function handle from a function and stage.
233    ///
234    /// C++ equivalent: `FunctionHandle* functionHandle(const MTL::Function* function, MTL::RenderStages stage)`
235    pub fn function_handle_with_function(
236        &self,
237        function: &crate::Function,
238        stage: RenderStages,
239    ) -> Option<crate::FunctionHandle> {
240        unsafe {
241            let ptr: *mut c_void = mtl_sys::msg_send_2(
242                self.as_ptr(),
243                sel!(functionHandleWithFunction: stage:),
244                function.as_ptr(),
245                stage,
246            );
247            if ptr.is_null() {
248                return None;
249            }
250            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
251            crate::FunctionHandle::from_raw(ptr)
252        }
253    }
254
255    // =========================================================================
256    // Function Tables
257    // =========================================================================
258
259    /// Create a new intersection function table.
260    ///
261    /// C++ equivalent: `IntersectionFunctionTable* newIntersectionFunctionTable(const IntersectionFunctionTableDescriptor*, RenderStages)`
262    pub fn new_intersection_function_table(
263        &self,
264        descriptor: &crate::IntersectionFunctionTableDescriptor,
265        stage: RenderStages,
266    ) -> Option<crate::IntersectionFunctionTable> {
267        unsafe {
268            let ptr: *mut c_void = mtl_sys::msg_send_2(
269                self.as_ptr(),
270                sel!(newIntersectionFunctionTableWithDescriptor: stage:),
271                descriptor.as_ptr(),
272                stage,
273            );
274            if ptr.is_null() {
275                None
276            } else {
277                crate::IntersectionFunctionTable::from_raw(ptr)
278            }
279        }
280    }
281
282    /// Create a new visible function table.
283    ///
284    /// C++ equivalent: `VisibleFunctionTable* newVisibleFunctionTable(const VisibleFunctionTableDescriptor*, RenderStages)`
285    pub fn new_visible_function_table(
286        &self,
287        descriptor: &crate::VisibleFunctionTableDescriptor,
288        stage: RenderStages,
289    ) -> Option<crate::VisibleFunctionTable> {
290        unsafe {
291            let ptr: *mut c_void = mtl_sys::msg_send_2(
292                self.as_ptr(),
293                sel!(newVisibleFunctionTableWithDescriptor: stage:),
294                descriptor.as_ptr(),
295                stage,
296            );
297            if ptr.is_null() {
298                None
299            } else {
300                crate::VisibleFunctionTable::from_raw(ptr)
301            }
302        }
303    }
304
305    // =========================================================================
306    // Pipeline State Creation
307    // =========================================================================
308
309    /// Create a new render pipeline state with additional binary functions.
310    ///
311    /// C++ equivalent: `RenderPipelineState* newRenderPipelineState(const RenderPipelineFunctionsDescriptor*, NS::Error**)`
312    pub fn new_render_pipeline_state(
313        &self,
314        additional_functions: &RenderPipelineFunctionsDescriptor,
315    ) -> Result<RenderPipelineState, mtl_foundation::Error> {
316        unsafe {
317            let mut error: *mut c_void = std::ptr::null_mut();
318            let ptr: *mut c_void = mtl_sys::msg_send_2(
319                self.as_ptr(),
320                sel!(newRenderPipelineStateWithAdditionalBinaryFunctions: error:),
321                additional_functions.as_ptr(),
322                &mut error as *mut _,
323            );
324            if ptr.is_null() {
325                if !error.is_null() {
326                    return Err(
327                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
328                    );
329                }
330                return Err(mtl_foundation::Error::error(
331                    std::ptr::null_mut(),
332                    -1,
333                    std::ptr::null_mut(),
334                )
335                .expect("failed to create error"));
336            }
337            Ok(RenderPipelineState::from_raw(ptr).expect("failed to create pipeline state"))
338        }
339    }
340
341    // =========================================================================
342    // Reflection
343    // =========================================================================
344
345    /// Get the pipeline reflection information.
346    ///
347    /// C++ equivalent: `RenderPipelineReflection* reflection() const`
348    pub fn reflection(&self) -> Option<RenderPipelineReflection> {
349        unsafe {
350            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(reflection));
351            if ptr.is_null() {
352                return None;
353            }
354            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
355            RenderPipelineReflection::from_raw(ptr)
356        }
357    }
358}
359
360impl Clone for RenderPipelineState {
361    fn clone(&self) -> Self {
362        unsafe {
363            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
364        }
365        Self(self.0)
366    }
367}
368
369impl Drop for RenderPipelineState {
370    fn drop(&mut self) {
371        unsafe {
372            msg_send_0::<()>(self.as_ptr(), sel!(release));
373        }
374    }
375}
376
377impl Referencing for RenderPipelineState {
378    #[inline]
379    fn as_ptr(&self) -> *const c_void {
380        self.0.as_ptr()
381    }
382}
383
384unsafe impl Send for RenderPipelineState {}
385unsafe impl Sync for RenderPipelineState {}
386
387impl std::fmt::Debug for RenderPipelineState {
388    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389        f.debug_struct("RenderPipelineState")
390            .field("label", &self.label())
391            .finish()
392    }
393}