mtl_gpu/mtl4/
compute_pipeline.rs1use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use super::enums::IndirectCommandBufferSupportState;
12use super::{FunctionDescriptor, PipelineOptions, StaticLinkingDescriptor};
13use crate::Size;
14
/// Owning handle to an Objective-C `MTL4ComputePipelineDescriptor` object.
///
/// The handle stores nothing but the non-null object pointer, and
/// `#[repr(transparent)]` gives it the same layout as that pointer.
#[repr(transparent)]
pub struct ComputePipelineDescriptor(NonNull<c_void>);
27
28impl ComputePipelineDescriptor {
29 #[inline]
31 pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
32 NonNull::new(ptr).map(Self)
33 }
34
35 #[inline]
37 pub fn as_raw(&self) -> *mut c_void {
38 self.0.as_ptr()
39 }
40
41 pub fn new() -> Option<Self> {
43 unsafe {
44 let class = mtl_sys::Class::get("MTL4ComputePipelineDescriptor")?;
45 let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
46 if ptr.is_null() {
47 return None;
48 }
49 let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
50 Self::from_raw(ptr)
51 }
52 }
53
54 pub fn label(&self) -> Option<String> {
60 unsafe {
61 let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
62 if ns_string.is_null() {
63 return None;
64 }
65 let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
66 if c_str.is_null() {
67 return None;
68 }
69 Some(
70 std::ffi::CStr::from_ptr(c_str)
71 .to_string_lossy()
72 .into_owned(),
73 )
74 }
75 }
76
77 pub fn set_label(&self, label: &str) {
81 if let Some(ns_label) = mtl_foundation::String::from_str(label) {
82 unsafe {
83 let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
84 }
85 }
86 }
87
88 pub fn options(&self) -> Option<PipelineOptions> {
92 unsafe {
93 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(options));
94 PipelineOptions::from_raw(ptr)
95 }
96 }
97
98 pub fn set_options(&self, options: &PipelineOptions) {
102 unsafe {
103 let _: () = msg_send_1(self.as_ptr(), sel!(setOptions:), options.as_ptr());
104 }
105 }
106
107 pub fn compute_function_descriptor(&self) -> Option<FunctionDescriptor> {
113 unsafe {
114 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(computeFunctionDescriptor));
115 FunctionDescriptor::from_raw(ptr)
116 }
117 }
118
119 pub fn set_compute_function_descriptor(&self, descriptor: &FunctionDescriptor) {
123 unsafe {
124 let _: () = msg_send_1(
125 self.as_ptr(),
126 sel!(setComputeFunctionDescriptor:),
127 descriptor.as_ptr(),
128 );
129 }
130 }
131
132 pub fn max_total_threads_per_threadgroup(&self) -> UInteger {
136 unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerThreadgroup)) }
137 }
138
139 pub fn set_max_total_threads_per_threadgroup(&self, max_threads: UInteger) {
143 unsafe {
144 let _: () = msg_send_1(
145 self.as_ptr(),
146 sel!(setMaxTotalThreadsPerThreadgroup:),
147 max_threads,
148 );
149 }
150 }
151
152 pub fn required_threads_per_threadgroup(&self) -> Size {
156 unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerThreadgroup)) }
157 }
158
159 pub fn set_required_threads_per_threadgroup(&self, size: Size) {
163 unsafe {
164 let _: () = msg_send_1(self.as_ptr(), sel!(setRequiredThreadsPerThreadgroup:), size);
165 }
166 }
167
168 pub fn static_linking_descriptor(&self) -> Option<StaticLinkingDescriptor> {
172 unsafe {
173 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(staticLinkingDescriptor));
174 StaticLinkingDescriptor::from_raw(ptr)
175 }
176 }
177
178 pub fn set_static_linking_descriptor(&self, descriptor: &StaticLinkingDescriptor) {
182 unsafe {
183 let _: () = msg_send_1(
184 self.as_ptr(),
185 sel!(setStaticLinkingDescriptor:),
186 descriptor.as_ptr(),
187 );
188 }
189 }
190
191 pub fn support_binary_linking(&self) -> bool {
195 unsafe { msg_send_0(self.as_ptr(), sel!(supportBinaryLinking)) }
196 }
197
198 pub fn set_support_binary_linking(&self, support: bool) {
202 unsafe {
203 let _: () = msg_send_1(self.as_ptr(), sel!(setSupportBinaryLinking:), support);
204 }
205 }
206
207 pub fn support_indirect_command_buffers(&self) -> IndirectCommandBufferSupportState {
211 unsafe { msg_send_0(self.as_ptr(), sel!(supportIndirectCommandBuffers)) }
212 }
213
214 pub fn set_support_indirect_command_buffers(&self, state: IndirectCommandBufferSupportState) {
218 unsafe {
219 let _: () = msg_send_1(
220 self.as_ptr(),
221 sel!(setSupportIndirectCommandBuffers:),
222 state,
223 );
224 }
225 }
226
227 pub fn thread_group_size_is_multiple_of_thread_execution_width(&self) -> bool {
231 unsafe {
232 msg_send_0(
233 self.as_ptr(),
234 sel!(threadGroupSizeIsMultipleOfThreadExecutionWidth),
235 )
236 }
237 }
238
239 pub fn set_thread_group_size_is_multiple_of_thread_execution_width(&self, value: bool) {
243 unsafe {
244 let _: () = msg_send_1(
245 self.as_ptr(),
246 sel!(setThreadGroupSizeIsMultipleOfThreadExecutionWidth:),
247 value,
248 );
249 }
250 }
251
252 pub fn reset(&self) {
256 unsafe {
257 let _: () = msg_send_0(self.as_ptr(), sel!(reset));
258 }
259 }
260}
261
262impl Clone for ComputePipelineDescriptor {
263 fn clone(&self) -> Self {
264 unsafe {
265 mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
266 }
267 Self(self.0)
268 }
269}
270
271impl Drop for ComputePipelineDescriptor {
272 fn drop(&mut self) {
273 unsafe {
274 mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
275 }
276 }
277}
278
279impl Referencing for ComputePipelineDescriptor {
280 #[inline]
281 fn as_ptr(&self) -> *const c_void {
282 self.0.as_ptr()
283 }
284}
285
// SAFETY: these impls assert that the wrapped Objective-C object may be moved
// to, and shared between, threads; nothing in this file enforces that.
// NOTE(review): `Sync` allows several threads to call the `&self` setters
// (`set_label`, `set_options`, ...) at the same time. Confirm that the
// underlying descriptor tolerates concurrent mutation, or that callers are
// expected to synchronise externally.
unsafe impl Send for ComputePipelineDescriptor {}
unsafe impl Sync for ComputePipelineDescriptor {}
288
289impl std::fmt::Debug for ComputePipelineDescriptor {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("ComputePipelineDescriptor")
292 .field("label", &self.label())
293 .field(
294 "max_total_threads_per_threadgroup",
295 &self.max_total_threads_per_threadgroup(),
296 )
297 .field(
298 "required_threads_per_threadgroup",
299 &self.required_threads_per_threadgroup(),
300 )
301 .finish()
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// The wrapper must stay exactly pointer-sized so it can stand in for the
    /// raw object pointer.
    #[test]
    fn test_compute_pipeline_descriptor_size() {
        let wrapper_size = std::mem::size_of::<ComputePipelineDescriptor>();
        let pointer_size = std::mem::size_of::<*mut c_void>();
        assert_eq!(wrapper_size, pointer_size);
    }
}