mtl_gpu/pipeline/compute_descriptor.rs
//! Compute pipeline descriptor.
//!
//! Corresponds to `MTL::ComputePipelineDescriptor`.
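//!
//! A minimal usage sketch (illustrative only, not compiled as a doctest since
//! it needs the Objective-C runtime and Metal at run time; it exercises only
//! methods defined in this module):
//!
//! ```ignore
//! let desc = ComputePipelineDescriptor::new().expect("failed to allocate descriptor");
//! desc.set_label("my compute kernel");
//! desc.set_max_total_threads_per_threadgroup(256);
//! desc.set_thread_group_size_is_multiple_of_thread_execution_width(true);
//! ```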

use std::ffi::c_void;
use std::ptr::NonNull;

use mtl_foundation::{Referencing, UInteger};
use mtl_sys::{msg_send_0, msg_send_1, sel};

use crate::enums::ShaderValidation;
use crate::types::Size;

use super::PipelineBufferDescriptorArray;

pub struct ComputePipelineDescriptor(pub(crate) NonNull<c_void>);

impl ComputePipelineDescriptor {
    /// Allocate a new compute pipeline descriptor.
    ///
    /// C++ equivalent: `static ComputePipelineDescriptor* alloc()`
    pub fn alloc() -> Option<Self> {
        unsafe {
            let class = mtl_sys::class!(MTLComputePipelineDescriptor);
            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
            Self::from_raw(ptr)
        }
    }

    /// Initialize the descriptor.
    ///
    /// C++ equivalent: `ComputePipelineDescriptor* init()`
    pub fn init(self) -> Option<Self> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(init));
            std::mem::forget(self);
            Self::from_raw(ptr)
        }
    }

    /// Create a new compute pipeline descriptor.
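    ///
    /// Illustrative sketch (marked `ignore` because it requires the
    /// Objective-C runtime at run time):
    ///
    /// ```ignore
    /// let desc = ComputePipelineDescriptor::new().expect("descriptor allocation failed");
    /// desc.set_label("image blur kernel");
    /// ```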
    pub fn new() -> Option<Self> {
        Self::alloc().and_then(|d| d.init())
    }

    /// Create from a raw pointer.
    ///
    /// # Safety
    ///
    /// The pointer must be a valid compute pipeline descriptor object.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Get the raw pointer.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Reset the descriptor to default values.
    ///
    /// C++ equivalent: `void reset()`
    #[inline]
    pub fn reset(&self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(reset));
        }
    }

    // =========================================================================
    // Basic Properties
    // =========================================================================

    /// Get the label.
    ///
    /// C++ equivalent: `NS::String* label() const`
    pub fn label(&self) -> Option<String> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
            if ptr.is_null() {
                return None;
            }
            let utf8_ptr: *const std::ffi::c_char =
                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
            if utf8_ptr.is_null() {
                return None;
            }
            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
            Some(c_str.to_string_lossy().into_owned())
        }
    }

    /// Set the label.
    ///
    /// C++ equivalent: `void setLabel(const NS::String* label)`
    pub fn set_label(&self, label: &str) {
        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
            unsafe {
                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
            }
        }
    }

    // =========================================================================
    // Compute Function
    // =========================================================================

    /// Get the compute function.
    ///
    /// C++ equivalent: `Function* computeFunction() const`
    pub fn compute_function(&self) -> Option<crate::Function> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(computeFunction));
            if ptr.is_null() {
                return None;
            }
            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
            crate::Function::from_raw(ptr)
        }
    }

    /// Set the compute function.
    ///
    /// C++ equivalent: `void setComputeFunction(const MTL::Function* computeFunction)`
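    ///
    /// Illustrative sketch (`kernel` is assumed to be a `crate::Function`
    /// obtained elsewhere, e.g. from a compiled library):
    ///
    /// ```ignore
    /// descriptor.set_compute_function(Some(&kernel));
    /// // Passing `None` clears the compute function.
    /// descriptor.set_compute_function(None);
    /// ```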
    pub fn set_compute_function(&self, function: Option<&crate::Function>) {
        unsafe {
            let ptr = function.map_or(std::ptr::null(), |f| f.as_ptr());
            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setComputeFunction:), ptr);
        }
    }

    // =========================================================================
    // Threadgroup Configuration
    // =========================================================================

    /// Get the maximum total threads per threadgroup.
    ///
    /// C++ equivalent: `NS::UInteger maxTotalThreadsPerThreadgroup() const`
    #[inline]
    pub fn max_total_threads_per_threadgroup(&self) -> UInteger {
        unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerThreadgroup)) }
    }

    /// Set the maximum total threads per threadgroup.
    ///
    /// C++ equivalent: `void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup)`
    #[inline]
    pub fn set_max_total_threads_per_threadgroup(&self, count: UInteger) {
        unsafe {
            msg_send_1::<(), UInteger>(
                self.as_ptr(),
                sel!(setMaxTotalThreadsPerThreadgroup:),
                count,
            );
        }
    }

    /// Get the required threads per threadgroup.
    ///
    /// C++ equivalent: `Size requiredThreadsPerThreadgroup() const`
    #[inline]
    pub fn required_threads_per_threadgroup(&self) -> Size {
        unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerThreadgroup)) }
    }

    /// Set the required threads per threadgroup.
    ///
    /// C++ equivalent: `void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup)`
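    ///
    /// Illustrative sketch (assumes `Size` exposes `width`/`height`/`depth`
    /// fields mirroring `MTLSize`):
    ///
    /// ```ignore
    /// descriptor.set_required_threads_per_threadgroup(Size {
    ///     width: 64,
    ///     height: 1,
    ///     depth: 1,
    /// });
    /// ```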
    #[inline]
    pub fn set_required_threads_per_threadgroup(&self, size: Size) {
        unsafe {
            msg_send_1::<(), Size>(self.as_ptr(), sel!(setRequiredThreadsPerThreadgroup:), size);
        }
    }

    /// Check whether the threadgroup size is a multiple of the thread execution width.
    ///
    /// C++ equivalent: `bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const`
    #[inline]
    pub fn thread_group_size_is_multiple_of_thread_execution_width(&self) -> bool {
        unsafe {
            msg_send_0(
                self.as_ptr(),
                sel!(threadGroupSizeIsMultipleOfThreadExecutionWidth),
            )
        }
    }

    /// Set whether the threadgroup size is a multiple of the thread execution width.
    ///
    /// C++ equivalent: `void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth)`
    #[inline]
    pub fn set_thread_group_size_is_multiple_of_thread_execution_width(&self, value: bool) {
        unsafe {
            msg_send_1::<(), bool>(
                self.as_ptr(),
                sel!(setThreadGroupSizeIsMultipleOfThreadExecutionWidth:),
                value,
            );
        }
    }

    // =========================================================================
    // Call Stack Depth
    // =========================================================================

    /// Get the maximum call stack depth.
    ///
    /// C++ equivalent: `NS::UInteger maxCallStackDepth() const`
    #[inline]
    pub fn max_call_stack_depth(&self) -> UInteger {
        unsafe { msg_send_0(self.as_ptr(), sel!(maxCallStackDepth)) }
    }

    /// Set the maximum call stack depth.
    ///
    /// C++ equivalent: `void setMaxCallStackDepth(NS::UInteger maxCallStackDepth)`
    #[inline]
    pub fn set_max_call_stack_depth(&self, depth: UInteger) {
        unsafe {
            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setMaxCallStackDepth:), depth);
        }
    }

    // =========================================================================
    // Indirect Command Buffers
    // =========================================================================

    /// Check whether the pipeline supports indirect command buffers.
    ///
    /// C++ equivalent: `bool supportIndirectCommandBuffers() const`
    #[inline]
    pub fn support_indirect_command_buffers(&self) -> bool {
        unsafe { msg_send_0(self.as_ptr(), sel!(supportIndirectCommandBuffers)) }
    }

    /// Set whether the pipeline supports indirect command buffers.
    ///
    /// C++ equivalent: `void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers)`
    #[inline]
    pub fn set_support_indirect_command_buffers(&self, support: bool) {
        unsafe {
            msg_send_1::<(), bool>(
                self.as_ptr(),
                sel!(setSupportIndirectCommandBuffers:),
                support,
            );
        }
    }

    /// Check whether binary functions can be added to the pipeline.
    ///
    /// C++ equivalent: `bool supportAddingBinaryFunctions() const`
    #[inline]
    pub fn support_adding_binary_functions(&self) -> bool {
        unsafe { msg_send_0(self.as_ptr(), sel!(supportAddingBinaryFunctions)) }
    }

    /// Set whether binary functions can be added to the pipeline.
    ///
    /// C++ equivalent: `void setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions)`
    #[inline]
    pub fn set_support_adding_binary_functions(&self, support: bool) {
        unsafe {
            msg_send_1::<(), bool>(
                self.as_ptr(),
                sel!(setSupportAddingBinaryFunctions:),
                support,
            );
        }
    }

    // =========================================================================
    // Shader Validation
    // =========================================================================

    /// Get the shader validation mode.
    ///
    /// C++ equivalent: `ShaderValidation shaderValidation() const`
    #[inline]
    pub fn shader_validation(&self) -> ShaderValidation {
        unsafe { msg_send_0(self.as_ptr(), sel!(shaderValidation)) }
    }

    /// Set the shader validation mode.
    ///
    /// C++ equivalent: `void setShaderValidation(MTL::ShaderValidation shaderValidation)`
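    ///
    /// Illustrative sketch (assumes `ShaderValidation` mirrors Metal's
    /// `Default` / `Enabled` / `Disabled` cases):
    ///
    /// ```ignore
    /// descriptor.set_shader_validation(ShaderValidation::Enabled);
    /// ```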
    #[inline]
    pub fn set_shader_validation(&self, validation: ShaderValidation) {
        unsafe {
            msg_send_1::<(), ShaderValidation>(
                self.as_ptr(),
                sel!(setShaderValidation:),
                validation,
            );
        }
    }

    // =========================================================================
    // Buffer Descriptors
    // =========================================================================

    /// Get the buffer descriptors array.
    ///
    /// C++ equivalent: `PipelineBufferDescriptorArray* buffers() const`
    pub fn buffers(&self) -> Option<PipelineBufferDescriptorArray> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(buffers));
            if ptr.is_null() {
                return None;
            }
            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
            PipelineBufferDescriptorArray::from_raw(ptr)
        }
    }

    // =========================================================================
    // Stage Input Descriptor
    // =========================================================================

    /// Get the stage input descriptor (raw pointer).
    ///
    /// C++ equivalent: `StageInputOutputDescriptor* stageInputDescriptor() const`
    pub fn stage_input_descriptor_raw(&self) -> *mut c_void {
        unsafe { msg_send_0(self.as_ptr(), sel!(stageInputDescriptor)) }
    }

    /// Set the stage input descriptor.
    ///
    /// C++ equivalent: `void setStageInputDescriptor(const StageInputOutputDescriptor*)`
    ///
    /// # Safety
    ///
    /// The pointer must be a valid StageInputOutputDescriptor object.
    pub unsafe fn set_stage_input_descriptor_raw(&self, descriptor: *const c_void) {
        unsafe {
            msg_send_1::<(), *const c_void>(
                self.as_ptr(),
                sel!(setStageInputDescriptor:),
                descriptor,
            );
        }
    }

    // =========================================================================
    // Linked Functions
    // =========================================================================

    /// Get the linked functions.
    ///
    /// C++ equivalent: `LinkedFunctions* linkedFunctions() const`
    pub fn linked_functions(&self) -> Option<crate::LinkedFunctions> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(linkedFunctions));
            if ptr.is_null() {
                return None;
            }
            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
            crate::LinkedFunctions::from_raw(ptr)
        }
    }

    /// Set the linked functions.
    ///
    /// C++ equivalent: `void setLinkedFunctions(const LinkedFunctions*)`
    pub fn set_linked_functions(&self, functions: Option<&crate::LinkedFunctions>) {
        let ptr = functions.map_or(std::ptr::null(), |f| f.as_ptr());
        unsafe {
            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLinkedFunctions:), ptr);
        }
    }

    // =========================================================================
    // Binary Archives
    // =========================================================================

    /// Get the binary archives (raw NSArray pointer).
    ///
    /// C++ equivalent: `NS::Array* binaryArchives() const`
    pub fn binary_archives_raw(&self) -> *mut c_void {
        unsafe { msg_send_0(self.as_ptr(), sel!(binaryArchives)) }
    }

    /// Set the binary archives.
    ///
    /// C++ equivalent: `void setBinaryArchives(const NS::Array*)`
    ///
    /// # Safety
    ///
    /// The pointer must be a valid NSArray of BinaryArchive objects.
    pub unsafe fn set_binary_archives_raw(&self, archives: *const c_void) {
        unsafe {
            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setBinaryArchives:), archives);
        }
    }

    // =========================================================================
    // Preloaded Libraries
    // =========================================================================

    /// Get the preloaded libraries (raw NSArray pointer).
    ///
    /// C++ equivalent: `NS::Array* preloadedLibraries() const`
    pub fn preloaded_libraries_raw(&self) -> *mut c_void {
        unsafe { msg_send_0(self.as_ptr(), sel!(preloadedLibraries)) }
    }

    /// Set the preloaded libraries.
    ///
    /// C++ equivalent: `void setPreloadedLibraries(const NS::Array*)`
    ///
    /// # Safety
    ///
    /// The pointer must be a valid NSArray of DynamicLibrary objects.
    pub unsafe fn set_preloaded_libraries_raw(&self, libraries: *const c_void) {
        unsafe {
            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setPreloadedLibraries:), libraries);
        }
    }

    // =========================================================================
    // Insert Libraries
    // =========================================================================

    /// Get the insert libraries (raw NSArray pointer).
    ///
    /// C++ equivalent: `NS::Array* insertLibraries() const`
    pub fn insert_libraries_raw(&self) -> *mut c_void {
        unsafe { msg_send_0(self.as_ptr(), sel!(insertLibraries)) }
    }

    /// Set the insert libraries.
    ///
    /// C++ equivalent: `void setInsertLibraries(const NS::Array*)`
    ///
    /// # Safety
    ///
    /// The pointer must be a valid NSArray of DynamicLibrary objects.
    pub unsafe fn set_insert_libraries_raw(&self, libraries: *const c_void) {
        unsafe {
            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setInsertLibraries:), libraries);
        }
    }
}

impl Clone for ComputePipelineDescriptor {
    fn clone(&self) -> Self {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(copy));
            Self::from_raw(ptr).expect("copy returned null")
        }
    }
}

impl Drop for ComputePipelineDescriptor {
    fn drop(&mut self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}

impl Default for ComputePipelineDescriptor {
    fn default() -> Self {
        Self::new().expect("failed to create compute pipeline descriptor")
    }
}

impl Referencing for ComputePipelineDescriptor {
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}

unsafe impl Send for ComputePipelineDescriptor {}
unsafe impl Sync for ComputePipelineDescriptor {}

impl std::fmt::Debug for ComputePipelineDescriptor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ComputePipelineDescriptor")
            .field("label", &self.label())
            .field(
                "max_total_threads_per_threadgroup",
                &self.max_total_threads_per_threadgroup(),
            )
            .finish()
    }
}