Skip to main content

mtl_gpu/pipeline/
compute_state.rs

1//! Compute pipeline state.
2//!
3//! Corresponds to `MTL::ComputePipelineState`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
10
11use crate::enums::ShaderValidation;
12use crate::types::{ResourceID, Size};
13
14use super::ComputePipelineReflection;
15
/// Owning handle to a compiled compute pipeline configuration.
///
/// The handle holds one retained reference to the underlying Objective-C
/// object: cloning sends `retain`, dropping sends `release`. Being
/// `#[repr(transparent)]` over a `NonNull`, it is exactly pointer-sized and
/// `Option<ComputePipelineState>` costs nothing extra.
///
/// C++ equivalent: `MTL::ComputePipelineState`
#[repr(transparent)]
pub struct ComputePipelineState(pub(crate) NonNull<c_void>);
21
22impl ComputePipelineState {
23    /// Create a ComputePipelineState from a raw pointer.
24    ///
25    /// # Safety
26    ///
27    /// The pointer must be a valid Metal compute pipeline state object.
28    #[inline]
29    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
30        NonNull::new(ptr).map(Self)
31    }
32
33    /// Get the raw pointer.
34    #[inline]
35    pub fn as_raw(&self) -> *mut c_void {
36        self.0.as_ptr()
37    }
38
39    // =========================================================================
40    // Properties
41    // =========================================================================
42
43    /// Get the label.
44    ///
45    /// C++ equivalent: `NS::String* label() const`
46    pub fn label(&self) -> Option<String> {
47        unsafe {
48            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
49            if ptr.is_null() {
50                return None;
51            }
52            let utf8_ptr: *const std::ffi::c_char =
53                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
54            if utf8_ptr.is_null() {
55                return None;
56            }
57            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
58            Some(c_str.to_string_lossy().into_owned())
59        }
60    }
61
62    /// Get the device.
63    ///
64    /// C++ equivalent: `Device* device() const`
65    pub fn device(&self) -> crate::Device {
66        unsafe {
67            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
68            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
69            crate::Device::from_raw(ptr).expect("pipeline state has no device")
70        }
71    }
72
73    /// Get the maximum total threads per threadgroup.
74    ///
75    /// C++ equivalent: `NS::UInteger maxTotalThreadsPerThreadgroup() const`
76    #[inline]
77    pub fn max_total_threads_per_threadgroup(&self) -> UInteger {
78        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerThreadgroup)) }
79    }
80
81    /// Get the thread execution width.
82    ///
83    /// C++ equivalent: `NS::UInteger threadExecutionWidth() const`
84    #[inline]
85    pub fn thread_execution_width(&self) -> UInteger {
86        unsafe { msg_send_0(self.as_ptr(), sel!(threadExecutionWidth)) }
87    }
88
89    /// Get the static threadgroup memory length.
90    ///
91    /// C++ equivalent: `NS::UInteger staticThreadgroupMemoryLength() const`
92    #[inline]
93    pub fn static_threadgroup_memory_length(&self) -> UInteger {
94        unsafe { msg_send_0(self.as_ptr(), sel!(staticThreadgroupMemoryLength)) }
95    }
96
97    /// Check if the pipeline supports indirect command buffers.
98    ///
99    /// C++ equivalent: `bool supportIndirectCommandBuffers() const`
100    #[inline]
101    pub fn support_indirect_command_buffers(&self) -> bool {
102        unsafe { msg_send_0(self.as_ptr(), sel!(supportIndirectCommandBuffers)) }
103    }
104
105    /// Get the imageblock memory length for given imageblock dimensions.
106    ///
107    /// C++ equivalent: `NS::UInteger imageblockMemoryLength(MTL::Size imageblockDimensions)`
108    #[inline]
109    pub fn imageblock_memory_length(&self, dimensions: Size) -> UInteger {
110        unsafe {
111            msg_send_1(
112                self.as_ptr(),
113                sel!(imageblockMemoryLengthForDimensions:),
114                dimensions,
115            )
116        }
117    }
118
119    /// Get the GPU resource ID for bindless access.
120    ///
121    /// C++ equivalent: `ResourceID gpuResourceID() const`
122    #[inline]
123    pub fn gpu_resource_id(&self) -> ResourceID {
124        unsafe { msg_send_0(self.as_ptr(), sel!(gpuResourceID)) }
125    }
126
127    /// Get the shader validation mode.
128    ///
129    /// C++ equivalent: `ShaderValidation shaderValidation() const`
130    #[inline]
131    pub fn shader_validation(&self) -> ShaderValidation {
132        unsafe { msg_send_0(self.as_ptr(), sel!(shaderValidation)) }
133    }
134
135    /// Get the required threads per threadgroup.
136    ///
137    /// C++ equivalent: `Size requiredThreadsPerThreadgroup() const`
138    #[inline]
139    pub fn required_threads_per_threadgroup(&self) -> Size {
140        unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerThreadgroup)) }
141    }
142
143    // =========================================================================
144    // Function Handles
145    // =========================================================================
146
147    /// Get a function handle for a function by name.
148    ///
149    /// C++ equivalent: `FunctionHandle* functionHandle(const NS::String* name)`
150    pub fn function_handle_with_name(&self, name: &str) -> Option<crate::FunctionHandle> {
151        let ns_name = mtl_foundation::String::from_str(name)?;
152        unsafe {
153            let ptr: *mut c_void = msg_send_1(
154                self.as_ptr(),
155                sel!(functionHandleWithName:),
156                ns_name.as_ptr(),
157            );
158            crate::FunctionHandle::from_raw(ptr)
159        }
160    }
161
162    /// Get a function handle for a function.
163    ///
164    /// C++ equivalent: `FunctionHandle* functionHandle(const MTL::Function* function)`
165    pub fn function_handle_with_function(
166        &self,
167        function: &crate::Function,
168    ) -> Option<crate::FunctionHandle> {
169        unsafe {
170            let ptr: *mut c_void = msg_send_1(
171                self.as_ptr(),
172                sel!(functionHandleWithFunction:),
173                function.as_ptr(),
174            );
175            crate::FunctionHandle::from_raw(ptr)
176        }
177    }
178
179    /// Get a function handle for a binary function.
180    ///
181    /// C++ equivalent: `FunctionHandle* functionHandle(const MTL4::BinaryFunction* function)`
182    pub fn function_handle_with_binary_function(
183        &self,
184        binary_function: *const c_void,
185    ) -> Option<crate::FunctionHandle> {
186        unsafe {
187            let ptr: *mut c_void = msg_send_1(
188                self.as_ptr(),
189                sel!(functionHandleWithBinaryFunction:),
190                binary_function,
191            );
192            crate::FunctionHandle::from_raw(ptr)
193        }
194    }
195
196    // =========================================================================
197    // Pipeline State Creation
198    // =========================================================================
199
200    /// Create a new compute pipeline state with additional functions.
201    ///
202    /// C++ equivalent: `ComputePipelineState* newComputePipelineState(const NS::Array* functions, NS::Error** error)`
203    pub fn new_compute_pipeline_state_with_functions(
204        &self,
205        functions: *const c_void,
206    ) -> Result<ComputePipelineState, mtl_foundation::Error> {
207        unsafe {
208            let mut error: *mut c_void = std::ptr::null_mut();
209            let ptr: *mut c_void = msg_send_2(
210                self.as_ptr(),
211                sel!(newComputePipelineStateWithFunctions:error:),
212                functions,
213                &mut error,
214            );
215            if ptr.is_null() {
216                if !error.is_null() {
217                    return Err(
218                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
219                    );
220                }
221                return Err(mtl_foundation::Error::error(
222                    std::ptr::null_mut(),
223                    -1,
224                    std::ptr::null_mut(),
225                )
226                .expect("failed to create error"));
227            }
228            Ok(ComputePipelineState::from_raw(ptr).unwrap())
229        }
230    }
231
232    /// Create a new compute pipeline state with additional binary functions.
233    ///
234    /// C++ equivalent: `ComputePipelineState* newComputePipelineStateWithBinaryFunctions(const NS::Array* additionalBinaryFunctions, NS::Error** error)`
235    pub fn new_compute_pipeline_state_with_binary_functions(
236        &self,
237        binary_functions: *const c_void,
238    ) -> Result<ComputePipelineState, mtl_foundation::Error> {
239        unsafe {
240            let mut error: *mut c_void = std::ptr::null_mut();
241            let ptr: *mut c_void = msg_send_2(
242                self.as_ptr(),
243                sel!(newComputePipelineStateWithBinaryFunctions:error:),
244                binary_functions,
245                &mut error,
246            );
247            if ptr.is_null() {
248                if !error.is_null() {
249                    return Err(
250                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
251                    );
252                }
253                return Err(mtl_foundation::Error::error(
254                    std::ptr::null_mut(),
255                    -1,
256                    std::ptr::null_mut(),
257                )
258                .expect("failed to create error"));
259            }
260            Ok(ComputePipelineState::from_raw(ptr).unwrap())
261        }
262    }
263
264    // =========================================================================
265    // Function Tables
266    // =========================================================================
267
268    /// Create a new intersection function table.
269    ///
270    /// C++ equivalent: `IntersectionFunctionTable* newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor)`
271    pub fn new_intersection_function_table(
272        &self,
273        descriptor: &crate::IntersectionFunctionTableDescriptor,
274    ) -> Option<crate::IntersectionFunctionTable> {
275        unsafe {
276            let ptr: *mut c_void = msg_send_1(
277                self.as_ptr(),
278                sel!(newIntersectionFunctionTableWithDescriptor:),
279                descriptor.as_ptr(),
280            );
281            crate::IntersectionFunctionTable::from_raw(ptr)
282        }
283    }
284
285    /// Create a new visible function table.
286    ///
287    /// C++ equivalent: `VisibleFunctionTable* newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor)`
288    pub fn new_visible_function_table(
289        &self,
290        descriptor: &crate::VisibleFunctionTableDescriptor,
291    ) -> Option<crate::VisibleFunctionTable> {
292        unsafe {
293            let ptr: *mut c_void = msg_send_1(
294                self.as_ptr(),
295                sel!(newVisibleFunctionTableWithDescriptor:),
296                descriptor.as_ptr(),
297            );
298            crate::VisibleFunctionTable::from_raw(ptr)
299        }
300    }
301
302    // =========================================================================
303    // Reflection
304    // =========================================================================
305
306    /// Get the pipeline reflection information.
307    ///
308    /// C++ equivalent: `ComputePipelineReflection* reflection()`
309    pub fn reflection(&self) -> Option<ComputePipelineReflection> {
310        unsafe {
311            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(reflection));
312            ComputePipelineReflection::from_raw(ptr)
313        }
314    }
315}
316
317impl Clone for ComputePipelineState {
318    fn clone(&self) -> Self {
319        unsafe {
320            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
321        }
322        Self(self.0)
323    }
324}
325
326impl Drop for ComputePipelineState {
327    fn drop(&mut self) {
328        unsafe {
329            msg_send_0::<()>(self.as_ptr(), sel!(release));
330        }
331    }
332}
333
334impl Referencing for ComputePipelineState {
335    #[inline]
336    fn as_ptr(&self) -> *const c_void {
337        self.0.as_ptr()
338    }
339}
340
341unsafe impl Send for ComputePipelineState {}
342unsafe impl Sync for ComputePipelineState {}
343
344impl std::fmt::Debug for ComputePipelineState {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        f.debug_struct("ComputePipelineState")
347            .field("label", &self.label())
348            .field(
349                "max_total_threads_per_threadgroup",
350                &self.max_total_threads_per_threadgroup(),
351            )
352            .field("thread_execution_width", &self.thread_execution_width())
353            .finish()
354    }
355}