Skip to main content

mtl_gpu/pipeline/
compute_descriptor.rs

1//! Compute pipeline descriptor.
2//!
3//! Corresponds to `MTL::ComputePipelineDescriptor`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use crate::enums::ShaderValidation;
12use crate::types::Size;
13
14use super::PipelineBufferDescriptorArray;
15
/// Owning handle to an Objective-C `MTLComputePipelineDescriptor`.
///
/// The handle sends `release` to the wrapped object when dropped, and cloning
/// produces an independent descriptor via `-copy`.
pub struct ComputePipelineDescriptor(
    /// Non-null pointer to the wrapped Objective-C descriptor object.
    pub(crate) NonNull<c_void>,
);
17
18impl ComputePipelineDescriptor {
19    /// Allocate a new compute pipeline descriptor.
20    ///
21    /// C++ equivalent: `static ComputePipelineDescriptor* alloc()`
22    pub fn alloc() -> Option<Self> {
23        unsafe {
24            let class = mtl_sys::class!(MTLComputePipelineDescriptor);
25            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
26            Self::from_raw(ptr)
27        }
28    }
29
30    /// Initialize the descriptor.
31    ///
32    /// C++ equivalent: `ComputePipelineDescriptor* init()`
33    pub fn init(self) -> Option<Self> {
34        unsafe {
35            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(init));
36            std::mem::forget(self);
37            Self::from_raw(ptr)
38        }
39    }
40
41    /// Create a new compute pipeline descriptor.
42    pub fn new() -> Option<Self> {
43        Self::alloc().and_then(|d| d.init())
44    }
45
46    /// Create from a raw pointer.
47    ///
48    /// # Safety
49    ///
50    /// The pointer must be a valid compute pipeline descriptor object.
51    #[inline]
52    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
53        NonNull::new(ptr).map(Self)
54    }
55
56    /// Get the raw pointer.
57    #[inline]
58    pub fn as_raw(&self) -> *mut c_void {
59        self.0.as_ptr()
60    }
61
62    /// Reset the descriptor to default values.
63    ///
64    /// C++ equivalent: `void reset()`
65    #[inline]
66    pub fn reset(&self) {
67        unsafe {
68            msg_send_0::<()>(self.as_ptr(), sel!(reset));
69        }
70    }
71
72    // =========================================================================
73    // Basic Properties
74    // =========================================================================
75
76    /// Get the label.
77    ///
78    /// C++ equivalent: `NS::String* label() const`
79    pub fn label(&self) -> Option<String> {
80        unsafe {
81            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
82            if ptr.is_null() {
83                return None;
84            }
85            let utf8_ptr: *const std::ffi::c_char =
86                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
87            if utf8_ptr.is_null() {
88                return None;
89            }
90            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
91            Some(c_str.to_string_lossy().into_owned())
92        }
93    }
94
95    /// Set the label.
96    ///
97    /// C++ equivalent: `void setLabel(const NS::String* label)`
98    pub fn set_label(&self, label: &str) {
99        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
100            unsafe {
101                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
102            }
103        }
104    }
105
106    // =========================================================================
107    // Compute Function
108    // =========================================================================
109
110    /// Get the compute function.
111    ///
112    /// C++ equivalent: `Function* computeFunction() const`
113    pub fn compute_function(&self) -> Option<crate::Function> {
114        unsafe {
115            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(computeFunction));
116            if ptr.is_null() {
117                return None;
118            }
119            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
120            crate::Function::from_raw(ptr)
121        }
122    }
123
124    /// Set the compute function.
125    ///
126    /// C++ equivalent: `void setComputeFunction(const MTL::Function* computeFunction)`
127    pub fn set_compute_function(&self, function: Option<&crate::Function>) {
128        unsafe {
129            let ptr = function.map_or(std::ptr::null(), |f| f.as_ptr());
130            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setComputeFunction:), ptr);
131        }
132    }
133
134    // =========================================================================
135    // Threadgroup Configuration
136    // =========================================================================
137
138    /// Get the maximum total threads per threadgroup.
139    ///
140    /// C++ equivalent: `NS::UInteger maxTotalThreadsPerThreadgroup() const`
141    #[inline]
142    pub fn max_total_threads_per_threadgroup(&self) -> UInteger {
143        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerThreadgroup)) }
144    }
145
146    /// Set the maximum total threads per threadgroup.
147    ///
148    /// C++ equivalent: `void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup)`
149    #[inline]
150    pub fn set_max_total_threads_per_threadgroup(&self, count: UInteger) {
151        unsafe {
152            msg_send_1::<(), UInteger>(
153                self.as_ptr(),
154                sel!(setMaxTotalThreadsPerThreadgroup:),
155                count,
156            );
157        }
158    }
159
160    /// Get the required threads per threadgroup.
161    ///
162    /// C++ equivalent: `Size requiredThreadsPerThreadgroup() const`
163    #[inline]
164    pub fn required_threads_per_threadgroup(&self) -> Size {
165        unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerThreadgroup)) }
166    }
167
168    /// Set the required threads per threadgroup.
169    ///
170    /// C++ equivalent: `void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup)`
171    #[inline]
172    pub fn set_required_threads_per_threadgroup(&self, size: Size) {
173        unsafe {
174            msg_send_1::<(), Size>(self.as_ptr(), sel!(setRequiredThreadsPerThreadgroup:), size);
175        }
176    }
177
178    /// Check if thread group size is multiple of thread execution width.
179    ///
180    /// C++ equivalent: `bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const`
181    #[inline]
182    pub fn thread_group_size_is_multiple_of_thread_execution_width(&self) -> bool {
183        unsafe {
184            msg_send_0(
185                self.as_ptr(),
186                sel!(threadGroupSizeIsMultipleOfThreadExecutionWidth),
187            )
188        }
189    }
190
191    /// Set thread group size is multiple of thread execution width.
192    ///
193    /// C++ equivalent: `void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth)`
194    #[inline]
195    pub fn set_thread_group_size_is_multiple_of_thread_execution_width(&self, value: bool) {
196        unsafe {
197            msg_send_1::<(), bool>(
198                self.as_ptr(),
199                sel!(setThreadGroupSizeIsMultipleOfThreadExecutionWidth:),
200                value,
201            );
202        }
203    }
204
205    // =========================================================================
206    // Call Stack Depth
207    // =========================================================================
208
209    /// Get the maximum call stack depth.
210    ///
211    /// C++ equivalent: `NS::UInteger maxCallStackDepth() const`
212    #[inline]
213    pub fn max_call_stack_depth(&self) -> UInteger {
214        unsafe { msg_send_0(self.as_ptr(), sel!(maxCallStackDepth)) }
215    }
216
217    /// Set the maximum call stack depth.
218    ///
219    /// C++ equivalent: `void setMaxCallStackDepth(NS::UInteger maxCallStackDepth)`
220    #[inline]
221    pub fn set_max_call_stack_depth(&self, depth: UInteger) {
222        unsafe {
223            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setMaxCallStackDepth:), depth);
224        }
225    }
226
227    // =========================================================================
228    // Indirect Command Buffers
229    // =========================================================================
230
231    /// Check if the pipeline supports indirect command buffers.
232    ///
233    /// C++ equivalent: `bool supportIndirectCommandBuffers() const`
234    #[inline]
235    pub fn support_indirect_command_buffers(&self) -> bool {
236        unsafe { msg_send_0(self.as_ptr(), sel!(supportIndirectCommandBuffers)) }
237    }
238
239    /// Set indirect command buffer support.
240    ///
241    /// C++ equivalent: `void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers)`
242    #[inline]
243    pub fn set_support_indirect_command_buffers(&self, support: bool) {
244        unsafe {
245            msg_send_1::<(), bool>(
246                self.as_ptr(),
247                sel!(setSupportIndirectCommandBuffers:),
248                support,
249            );
250        }
251    }
252
253    /// Check if support adding binary functions is enabled.
254    ///
255    /// C++ equivalent: `bool supportAddingBinaryFunctions() const`
256    #[inline]
257    pub fn support_adding_binary_functions(&self) -> bool {
258        unsafe { msg_send_0(self.as_ptr(), sel!(supportAddingBinaryFunctions)) }
259    }
260
261    /// Set support adding binary functions.
262    ///
263    /// C++ equivalent: `void setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions)`
264    #[inline]
265    pub fn set_support_adding_binary_functions(&self, support: bool) {
266        unsafe {
267            msg_send_1::<(), bool>(
268                self.as_ptr(),
269                sel!(setSupportAddingBinaryFunctions:),
270                support,
271            );
272        }
273    }
274
275    // =========================================================================
276    // Shader Validation
277    // =========================================================================
278
279    /// Get the shader validation mode.
280    ///
281    /// C++ equivalent: `ShaderValidation shaderValidation() const`
282    #[inline]
283    pub fn shader_validation(&self) -> ShaderValidation {
284        unsafe { msg_send_0(self.as_ptr(), sel!(shaderValidation)) }
285    }
286
287    /// Set the shader validation mode.
288    ///
289    /// C++ equivalent: `void setShaderValidation(MTL::ShaderValidation shaderValidation)`
290    #[inline]
291    pub fn set_shader_validation(&self, validation: ShaderValidation) {
292        unsafe {
293            msg_send_1::<(), ShaderValidation>(
294                self.as_ptr(),
295                sel!(setShaderValidation:),
296                validation,
297            );
298        }
299    }
300
301    // =========================================================================
302    // Buffer Descriptors
303    // =========================================================================
304
305    /// Get the buffer descriptors array.
306    ///
307    /// C++ equivalent: `PipelineBufferDescriptorArray* buffers() const`
308    pub fn buffers(&self) -> Option<PipelineBufferDescriptorArray> {
309        unsafe {
310            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(buffers));
311            if ptr.is_null() {
312                return None;
313            }
314            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
315            PipelineBufferDescriptorArray::from_raw(ptr)
316        }
317    }
318
319    // =========================================================================
320    // Stage Input Descriptor
321    // =========================================================================
322
323    /// Get the stage input descriptor.
324    ///
325    /// C++ equivalent: `StageInputOutputDescriptor* stageInputDescriptor() const`
326    pub fn stage_input_descriptor_raw(&self) -> *mut c_void {
327        unsafe { msg_send_0(self.as_ptr(), sel!(stageInputDescriptor)) }
328    }
329
330    /// Set the stage input descriptor.
331    ///
332    /// C++ equivalent: `void setStageInputDescriptor(const StageInputOutputDescriptor*)`
333    ///
334    /// # Safety
335    ///
336    /// The pointer must be a valid StageInputOutputDescriptor object.
337    pub unsafe fn set_stage_input_descriptor_raw(&self, descriptor: *const c_void) {
338        unsafe {
339            msg_send_1::<(), *const c_void>(
340                self.as_ptr(),
341                sel!(setStageInputDescriptor:),
342                descriptor,
343            );
344        }
345    }
346
347    // =========================================================================
348    // Linked Functions
349    // =========================================================================
350
351    /// Get the linked functions.
352    ///
353    /// C++ equivalent: `LinkedFunctions* linkedFunctions() const`
354    pub fn linked_functions(&self) -> Option<crate::LinkedFunctions> {
355        unsafe {
356            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(linkedFunctions));
357            if ptr.is_null() {
358                return None;
359            }
360            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
361            crate::LinkedFunctions::from_raw(ptr)
362        }
363    }
364
365    /// Set the linked functions.
366    ///
367    /// C++ equivalent: `void setLinkedFunctions(const LinkedFunctions*)`
368    pub fn set_linked_functions(&self, functions: Option<&crate::LinkedFunctions>) {
369        let ptr = functions.map(|f| f.as_ptr()).unwrap_or(std::ptr::null());
370        unsafe {
371            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLinkedFunctions:), ptr);
372        }
373    }
374
375    // =========================================================================
376    // Binary Archives
377    // =========================================================================
378
379    /// Get the binary archives (raw NSArray pointer).
380    ///
381    /// C++ equivalent: `NS::Array* binaryArchives() const`
382    pub fn binary_archives_raw(&self) -> *mut c_void {
383        unsafe { msg_send_0(self.as_ptr(), sel!(binaryArchives)) }
384    }
385
386    /// Set the binary archives.
387    ///
388    /// C++ equivalent: `void setBinaryArchives(const NS::Array*)`
389    ///
390    /// # Safety
391    ///
392    /// The pointer must be a valid NSArray of BinaryArchive objects.
393    pub unsafe fn set_binary_archives_raw(&self, archives: *const c_void) {
394        unsafe {
395            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setBinaryArchives:), archives);
396        }
397    }
398
399    // =========================================================================
400    // Preloaded Libraries
401    // =========================================================================
402
403    /// Get the preloaded libraries (raw NSArray pointer).
404    ///
405    /// C++ equivalent: `NS::Array* preloadedLibraries() const`
406    pub fn preloaded_libraries_raw(&self) -> *mut c_void {
407        unsafe { msg_send_0(self.as_ptr(), sel!(preloadedLibraries)) }
408    }
409
410    /// Set the preloaded libraries.
411    ///
412    /// C++ equivalent: `void setPreloadedLibraries(const NS::Array*)`
413    ///
414    /// # Safety
415    ///
416    /// The pointer must be a valid NSArray of DynamicLibrary objects.
417    pub unsafe fn set_preloaded_libraries_raw(&self, libraries: *const c_void) {
418        unsafe {
419            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setPreloadedLibraries:), libraries);
420        }
421    }
422
423    // =========================================================================
424    // Insert Libraries
425    // =========================================================================
426
427    /// Get the insert libraries (raw NSArray pointer).
428    ///
429    /// C++ equivalent: `NS::Array* insertLibraries() const`
430    pub fn insert_libraries_raw(&self) -> *mut c_void {
431        unsafe { msg_send_0(self.as_ptr(), sel!(insertLibraries)) }
432    }
433
434    /// Set the insert libraries.
435    ///
436    /// C++ equivalent: `void setInsertLibraries(const NS::Array*)`
437    ///
438    /// # Safety
439    ///
440    /// The pointer must be a valid NSArray of DynamicLibrary objects.
441    pub unsafe fn set_insert_libraries_raw(&self, libraries: *const c_void) {
442        unsafe {
443            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setInsertLibraries:), libraries);
444        }
445    }
446}
447
448impl Clone for ComputePipelineDescriptor {
449    fn clone(&self) -> Self {
450        unsafe {
451            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(copy));
452            Self::from_raw(ptr).expect("copy returned null")
453        }
454    }
455}
456
457impl Drop for ComputePipelineDescriptor {
458    fn drop(&mut self) {
459        unsafe {
460            msg_send_0::<()>(self.as_ptr(), sel!(release));
461        }
462    }
463}
464
465impl Default for ComputePipelineDescriptor {
466    fn default() -> Self {
467        Self::new().expect("failed to create compute pipeline descriptor")
468    }
469}
470
471impl Referencing for ComputePipelineDescriptor {
472    #[inline]
473    fn as_ptr(&self) -> *const c_void {
474        self.0.as_ptr()
475    }
476}
477
478unsafe impl Send for ComputePipelineDescriptor {}
479unsafe impl Sync for ComputePipelineDescriptor {}
480
481impl std::fmt::Debug for ComputePipelineDescriptor {
482    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        f.debug_struct("ComputePipelineDescriptor")
484            .field("label", &self.label())
485            .field(
486                "max_total_threads_per_threadgroup",
487                &self.max_total_threads_per_threadgroup(),
488            )
489            .finish()
490    }
491}