Skip to main content

mtl_gpu/mtl4/
compute_pipeline.rs

1//! MTL4 ComputePipeline implementation.
2//!
3//! Corresponds to `Metal/MTL4ComputePipeline.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use super::enums::IndirectCommandBufferSupportState;
12use super::{FunctionDescriptor, PipelineOptions, StaticLinkingDescriptor};
13use crate::Size;
14
15// ============================================================
16// ComputePipelineDescriptor
17// ============================================================
18
/// Descriptor for MTL4 compute pipelines.
///
/// C++ equivalent: `MTL4::ComputePipelineDescriptor`
///
/// ComputePipelineDescriptor configures the compute function, threadgroup
/// sizing, linking options and indirect-command-buffer support for a compute
/// pipeline.
///
/// The wrapper holds a single non-null Objective-C object pointer. Cloning it
/// sends `retain` and dropping it sends `release`, so every instance owns one
/// reference to the underlying object.
#[repr(transparent)]
pub struct ComputePipelineDescriptor(NonNull<c_void>);
27
28impl ComputePipelineDescriptor {
29    /// Create a ComputePipelineDescriptor from a raw pointer.
30    #[inline]
31    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
32        NonNull::new(ptr).map(Self)
33    }
34
35    /// Get the raw pointer.
36    #[inline]
37    pub fn as_raw(&self) -> *mut c_void {
38        self.0.as_ptr()
39    }
40
41    /// Create a new compute pipeline descriptor.
42    pub fn new() -> Option<Self> {
43        unsafe {
44            let class = mtl_sys::Class::get("MTL4ComputePipelineDescriptor")?;
45            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
46            if ptr.is_null() {
47                return None;
48            }
49            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
50            Self::from_raw(ptr)
51        }
52    }
53
54    // ========== Base Pipeline Properties ==========
55
56    /// Get the label.
57    ///
58    /// C++ equivalent: `NS::String* label() const`
59    pub fn label(&self) -> Option<String> {
60        unsafe {
61            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
62            if ns_string.is_null() {
63                return None;
64            }
65            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
66            if c_str.is_null() {
67                return None;
68            }
69            Some(
70                std::ffi::CStr::from_ptr(c_str)
71                    .to_string_lossy()
72                    .into_owned(),
73            )
74        }
75    }
76
77    /// Set the label.
78    ///
79    /// C++ equivalent: `void setLabel(const NS::String*)`
80    pub fn set_label(&self, label: &str) {
81        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
82            unsafe {
83                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
84            }
85        }
86    }
87
88    /// Get the pipeline options.
89    ///
90    /// C++ equivalent: `PipelineOptions* options() const`
91    pub fn options(&self) -> Option<PipelineOptions> {
92        unsafe {
93            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(options));
94            PipelineOptions::from_raw(ptr)
95        }
96    }
97
98    /// Set the pipeline options.
99    ///
100    /// C++ equivalent: `void setOptions(const MTL4::PipelineOptions*)`
101    pub fn set_options(&self, options: &PipelineOptions) {
102        unsafe {
103            let _: () = msg_send_1(self.as_ptr(), sel!(setOptions:), options.as_ptr());
104        }
105    }
106
107    // ========== Compute-Specific Properties ==========
108
109    /// Get the compute function descriptor.
110    ///
111    /// C++ equivalent: `FunctionDescriptor* computeFunctionDescriptor() const`
112    pub fn compute_function_descriptor(&self) -> Option<FunctionDescriptor> {
113        unsafe {
114            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(computeFunctionDescriptor));
115            FunctionDescriptor::from_raw(ptr)
116        }
117    }
118
119    /// Set the compute function descriptor.
120    ///
121    /// C++ equivalent: `void setComputeFunctionDescriptor(const MTL4::FunctionDescriptor*)`
122    pub fn set_compute_function_descriptor(&self, descriptor: &FunctionDescriptor) {
123        unsafe {
124            let _: () = msg_send_1(
125                self.as_ptr(),
126                sel!(setComputeFunctionDescriptor:),
127                descriptor.as_ptr(),
128            );
129        }
130    }
131
132    /// Get the maximum total threads per threadgroup.
133    ///
134    /// C++ equivalent: `NS::UInteger maxTotalThreadsPerThreadgroup() const`
135    pub fn max_total_threads_per_threadgroup(&self) -> UInteger {
136        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerThreadgroup)) }
137    }
138
139    /// Set the maximum total threads per threadgroup.
140    ///
141    /// C++ equivalent: `void setMaxTotalThreadsPerThreadgroup(NS::UInteger)`
142    pub fn set_max_total_threads_per_threadgroup(&self, max_threads: UInteger) {
143        unsafe {
144            let _: () = msg_send_1(
145                self.as_ptr(),
146                sel!(setMaxTotalThreadsPerThreadgroup:),
147                max_threads,
148            );
149        }
150    }
151
152    /// Get the required threads per threadgroup.
153    ///
154    /// C++ equivalent: `MTL::Size requiredThreadsPerThreadgroup() const`
155    pub fn required_threads_per_threadgroup(&self) -> Size {
156        unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerThreadgroup)) }
157    }
158
159    /// Set the required threads per threadgroup.
160    ///
161    /// C++ equivalent: `void setRequiredThreadsPerThreadgroup(MTL::Size)`
162    pub fn set_required_threads_per_threadgroup(&self, size: Size) {
163        unsafe {
164            let _: () = msg_send_1(self.as_ptr(), sel!(setRequiredThreadsPerThreadgroup:), size);
165        }
166    }
167
168    /// Get the static linking descriptor.
169    ///
170    /// C++ equivalent: `StaticLinkingDescriptor* staticLinkingDescriptor() const`
171    pub fn static_linking_descriptor(&self) -> Option<StaticLinkingDescriptor> {
172        unsafe {
173            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(staticLinkingDescriptor));
174            StaticLinkingDescriptor::from_raw(ptr)
175        }
176    }
177
178    /// Set the static linking descriptor.
179    ///
180    /// C++ equivalent: `void setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor*)`
181    pub fn set_static_linking_descriptor(&self, descriptor: &StaticLinkingDescriptor) {
182        unsafe {
183            let _: () = msg_send_1(
184                self.as_ptr(),
185                sel!(setStaticLinkingDescriptor:),
186                descriptor.as_ptr(),
187            );
188        }
189    }
190
191    /// Get whether binary linking is supported.
192    ///
193    /// C++ equivalent: `bool supportBinaryLinking() const`
194    pub fn support_binary_linking(&self) -> bool {
195        unsafe { msg_send_0(self.as_ptr(), sel!(supportBinaryLinking)) }
196    }
197
198    /// Set whether binary linking is supported.
199    ///
200    /// C++ equivalent: `void setSupportBinaryLinking(bool)`
201    pub fn set_support_binary_linking(&self, support: bool) {
202        unsafe {
203            let _: () = msg_send_1(self.as_ptr(), sel!(setSupportBinaryLinking:), support);
204        }
205    }
206
207    /// Get the indirect command buffer support state.
208    ///
209    /// C++ equivalent: `IndirectCommandBufferSupportState supportIndirectCommandBuffers() const`
210    pub fn support_indirect_command_buffers(&self) -> IndirectCommandBufferSupportState {
211        unsafe { msg_send_0(self.as_ptr(), sel!(supportIndirectCommandBuffers)) }
212    }
213
214    /// Set the indirect command buffer support state.
215    ///
216    /// C++ equivalent: `void setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState)`
217    pub fn set_support_indirect_command_buffers(&self, state: IndirectCommandBufferSupportState) {
218        unsafe {
219            let _: () = msg_send_1(
220                self.as_ptr(),
221                sel!(setSupportIndirectCommandBuffers:),
222                state,
223            );
224        }
225    }
226
227    /// Get whether threadgroup size is a multiple of thread execution width.
228    ///
229    /// C++ equivalent: `bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const`
230    pub fn thread_group_size_is_multiple_of_thread_execution_width(&self) -> bool {
231        unsafe {
232            msg_send_0(
233                self.as_ptr(),
234                sel!(threadGroupSizeIsMultipleOfThreadExecutionWidth),
235            )
236        }
237    }
238
239    /// Set whether threadgroup size is a multiple of thread execution width.
240    ///
241    /// C++ equivalent: `void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool)`
242    pub fn set_thread_group_size_is_multiple_of_thread_execution_width(&self, value: bool) {
243        unsafe {
244            let _: () = msg_send_1(
245                self.as_ptr(),
246                sel!(setThreadGroupSizeIsMultipleOfThreadExecutionWidth:),
247                value,
248            );
249        }
250    }
251
252    /// Reset the descriptor to its default state.
253    ///
254    /// C++ equivalent: `void reset()`
255    pub fn reset(&self) {
256        unsafe {
257            let _: () = msg_send_0(self.as_ptr(), sel!(reset));
258        }
259    }
260}
261
262impl Clone for ComputePipelineDescriptor {
263    fn clone(&self) -> Self {
264        unsafe {
265            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
266        }
267        Self(self.0)
268    }
269}
270
271impl Drop for ComputePipelineDescriptor {
272    fn drop(&mut self) {
273        unsafe {
274            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
275        }
276    }
277}
278
279impl Referencing for ComputePipelineDescriptor {
280    #[inline]
281    fn as_ptr(&self) -> *const c_void {
282        self.0.as_ptr()
283    }
284}
285
// SAFETY: the wrapper holds only a retained Objective-C object pointer; moving
// it to another thread moves that single reference. This assumes the runtime's
// retain/release are safe to call from any thread — confirm.
// NOTE(review): `Sync` allows `&ComputePipelineDescriptor` to be shared across
// threads, but every setter on this type takes `&self` and mutates the
// underlying descriptor with no locking. Confirm the Metal descriptor tolerates
// concurrent access, or that callers never mutate it concurrently.
unsafe impl Send for ComputePipelineDescriptor {}
unsafe impl Sync for ComputePipelineDescriptor {}
288
289impl std::fmt::Debug for ComputePipelineDescriptor {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        f.debug_struct("ComputePipelineDescriptor")
292            .field("label", &self.label())
293            .field(
294                "max_total_threads_per_threadgroup",
295                &self.max_total_threads_per_threadgroup(),
296            )
297            .field(
298                "required_threads_per_threadgroup",
299                &self.required_threads_per_threadgroup(),
300            )
301            .finish()
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// The wrapper is `repr(transparent)` over one `NonNull`, so it must be
    /// exactly pointer-sized.
    #[test]
    fn test_compute_pipeline_descriptor_size() {
        let wrapper_bytes = std::mem::size_of::<ComputePipelineDescriptor>();
        let pointer_bytes = std::mem::size_of::<*mut c_void>();
        assert_eq!(wrapper_bytes, pointer_bytes);
    }
}