Skip to main content

mtl_gpu/stage_input_output/
mod.rs

1//! Stage input/output descriptors.
2//!
3//! Corresponds to `Metal/MTLStageInputOutputDescriptor.hpp`.
4//!
//! These descriptors describe how the stage-in data consumed by a compute
//! kernel (its `[[stage_in]]` argument) is laid out in buffers.
7
8use std::ffi::c_void;
9use std::ptr::NonNull;
10
11use mtl_foundation::{Referencing, UInteger};
12use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
13
14use crate::enums::{AttributeFormat, IndexType, StepFunction};
15
16// ============================================================================
17// BufferLayoutDescriptor
18// ============================================================================
19
20/// Describes the layout of data in a buffer for stage input/output.
21///
22/// C++ equivalent: `MTL::BufferLayoutDescriptor`
23#[repr(transparent)]
24pub struct BufferLayoutDescriptor(pub(crate) NonNull<c_void>);
25
26impl BufferLayoutDescriptor {
27    /// Create a new buffer layout descriptor.
28    ///
29    /// C++ equivalent: `BufferLayoutDescriptor::alloc()->init()`
30    pub fn new() -> Option<Self> {
31        unsafe {
32            let class = mtl_sys::class!(MTLBufferLayoutDescriptor);
33            let obj: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
34            if obj.is_null() {
35                return None;
36            }
37            let obj: *mut c_void = msg_send_0(obj, sel!(init));
38            Self::from_raw(obj)
39        }
40    }
41
42    /// Create from a raw pointer.
43    #[inline]
44    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
45        NonNull::new(ptr).map(Self)
46    }
47
48    /// Get the raw pointer.
49    #[inline]
50    pub fn as_raw(&self) -> *mut c_void {
51        self.0.as_ptr()
52    }
53
54    /// Get the step function.
55    ///
56    /// C++ equivalent: `StepFunction stepFunction() const`
57    #[inline]
58    pub fn step_function(&self) -> StepFunction {
59        unsafe { msg_send_0(self.as_ptr(), sel!(stepFunction)) }
60    }
61
62    /// Set the step function.
63    ///
64    /// C++ equivalent: `void setStepFunction(StepFunction)`
65    #[inline]
66    pub fn set_step_function(&self, step_function: StepFunction) {
67        unsafe {
68            msg_send_1::<(), StepFunction>(self.as_ptr(), sel!(setStepFunction:), step_function);
69        }
70    }
71
72    /// Get the step rate.
73    ///
74    /// C++ equivalent: `NS::UInteger stepRate() const`
75    #[inline]
76    pub fn step_rate(&self) -> UInteger {
77        unsafe { msg_send_0(self.as_ptr(), sel!(stepRate)) }
78    }
79
80    /// Set the step rate.
81    ///
82    /// C++ equivalent: `void setStepRate(NS::UInteger)`
83    #[inline]
84    pub fn set_step_rate(&self, step_rate: UInteger) {
85        unsafe {
86            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setStepRate:), step_rate);
87        }
88    }
89
90    /// Get the stride.
91    ///
92    /// C++ equivalent: `NS::UInteger stride() const`
93    #[inline]
94    pub fn stride(&self) -> UInteger {
95        unsafe { msg_send_0(self.as_ptr(), sel!(stride)) }
96    }
97
98    /// Set the stride.
99    ///
100    /// C++ equivalent: `void setStride(NS::UInteger)`
101    #[inline]
102    pub fn set_stride(&self, stride: UInteger) {
103        unsafe {
104            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setStride:), stride);
105        }
106    }
107}
108
109impl Default for BufferLayoutDescriptor {
110    fn default() -> Self {
111        Self::new().expect("failed to create BufferLayoutDescriptor")
112    }
113}
114
115impl Clone for BufferLayoutDescriptor {
116    fn clone(&self) -> Self {
117        unsafe {
118            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
119        }
120        Self(self.0)
121    }
122}
123
124impl Drop for BufferLayoutDescriptor {
125    fn drop(&mut self) {
126        unsafe {
127            msg_send_0::<()>(self.as_ptr(), sel!(release));
128        }
129    }
130}
131
132impl Referencing for BufferLayoutDescriptor {
133    #[inline]
134    fn as_ptr(&self) -> *const c_void {
135        self.0.as_ptr()
136    }
137}
138
139unsafe impl Send for BufferLayoutDescriptor {}
140unsafe impl Sync for BufferLayoutDescriptor {}
141
142// ============================================================================
143// BufferLayoutDescriptorArray
144// ============================================================================
145
146/// An array of buffer layout descriptors.
147///
148/// C++ equivalent: `MTL::BufferLayoutDescriptorArray`
149#[repr(transparent)]
150pub struct BufferLayoutDescriptorArray(pub(crate) NonNull<c_void>);
151
152impl BufferLayoutDescriptorArray {
153    /// Create from a raw pointer.
154    #[inline]
155    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
156        NonNull::new(ptr).map(Self)
157    }
158
159    /// Get the raw pointer.
160    #[inline]
161    pub fn as_raw(&self) -> *mut c_void {
162        self.0.as_ptr()
163    }
164
165    /// Get the descriptor at the specified index.
166    ///
167    /// C++ equivalent: `BufferLayoutDescriptor* object(NS::UInteger index)`
168    pub fn object_at(&self, index: UInteger) -> Option<BufferLayoutDescriptor> {
169        unsafe {
170            let ptr: *mut c_void =
171                msg_send_1(self.as_ptr(), sel!(objectAtIndexedSubscript:), index);
172            if ptr.is_null() {
173                return None;
174            }
175            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
176            BufferLayoutDescriptor::from_raw(ptr)
177        }
178    }
179
180    /// Set the descriptor at the specified index.
181    ///
182    /// C++ equivalent: `void setObject(const BufferLayoutDescriptor*, NS::UInteger)`
183    pub fn set_object_at(&self, descriptor: &BufferLayoutDescriptor, index: UInteger) {
184        unsafe {
185            msg_send_2::<(), *const c_void, UInteger>(
186                self.as_ptr(),
187                sel!(setObject: atIndexedSubscript:),
188                descriptor.as_ptr(),
189                index,
190            );
191        }
192    }
193}
194
195impl Clone for BufferLayoutDescriptorArray {
196    fn clone(&self) -> Self {
197        unsafe {
198            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
199        }
200        Self(self.0)
201    }
202}
203
204impl Drop for BufferLayoutDescriptorArray {
205    fn drop(&mut self) {
206        unsafe {
207            msg_send_0::<()>(self.as_ptr(), sel!(release));
208        }
209    }
210}
211
212impl Referencing for BufferLayoutDescriptorArray {
213    #[inline]
214    fn as_ptr(&self) -> *const c_void {
215        self.0.as_ptr()
216    }
217}
218
219unsafe impl Send for BufferLayoutDescriptorArray {}
220unsafe impl Sync for BufferLayoutDescriptorArray {}
221
222// ============================================================================
223// AttributeDescriptor
224// ============================================================================
225
226/// Describes an attribute in a stage input/output descriptor.
227///
228/// C++ equivalent: `MTL::AttributeDescriptor`
229#[repr(transparent)]
230pub struct AttributeDescriptor(pub(crate) NonNull<c_void>);
231
232impl AttributeDescriptor {
233    /// Create a new attribute descriptor.
234    ///
235    /// C++ equivalent: `AttributeDescriptor::alloc()->init()`
236    pub fn new() -> Option<Self> {
237        unsafe {
238            let class = mtl_sys::class!(MTLAttributeDescriptor);
239            let obj: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
240            if obj.is_null() {
241                return None;
242            }
243            let obj: *mut c_void = msg_send_0(obj, sel!(init));
244            Self::from_raw(obj)
245        }
246    }
247
248    /// Create from a raw pointer.
249    #[inline]
250    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
251        NonNull::new(ptr).map(Self)
252    }
253
254    /// Get the raw pointer.
255    #[inline]
256    pub fn as_raw(&self) -> *mut c_void {
257        self.0.as_ptr()
258    }
259
260    /// Get the format.
261    ///
262    /// C++ equivalent: `AttributeFormat format() const`
263    #[inline]
264    pub fn format(&self) -> AttributeFormat {
265        unsafe { msg_send_0(self.as_ptr(), sel!(format)) }
266    }
267
268    /// Set the format.
269    ///
270    /// C++ equivalent: `void setFormat(AttributeFormat)`
271    #[inline]
272    pub fn set_format(&self, format: AttributeFormat) {
273        unsafe {
274            msg_send_1::<(), AttributeFormat>(self.as_ptr(), sel!(setFormat:), format);
275        }
276    }
277
278    /// Get the offset.
279    ///
280    /// C++ equivalent: `NS::UInteger offset() const`
281    #[inline]
282    pub fn offset(&self) -> UInteger {
283        unsafe { msg_send_0(self.as_ptr(), sel!(offset)) }
284    }
285
286    /// Set the offset.
287    ///
288    /// C++ equivalent: `void setOffset(NS::UInteger)`
289    #[inline]
290    pub fn set_offset(&self, offset: UInteger) {
291        unsafe {
292            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setOffset:), offset);
293        }
294    }
295
296    /// Get the buffer index.
297    ///
298    /// C++ equivalent: `NS::UInteger bufferIndex() const`
299    #[inline]
300    pub fn buffer_index(&self) -> UInteger {
301        unsafe { msg_send_0(self.as_ptr(), sel!(bufferIndex)) }
302    }
303
304    /// Set the buffer index.
305    ///
306    /// C++ equivalent: `void setBufferIndex(NS::UInteger)`
307    #[inline]
308    pub fn set_buffer_index(&self, buffer_index: UInteger) {
309        unsafe {
310            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setBufferIndex:), buffer_index);
311        }
312    }
313}
314
315impl Default for AttributeDescriptor {
316    fn default() -> Self {
317        Self::new().expect("failed to create AttributeDescriptor")
318    }
319}
320
321impl Clone for AttributeDescriptor {
322    fn clone(&self) -> Self {
323        unsafe {
324            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
325        }
326        Self(self.0)
327    }
328}
329
330impl Drop for AttributeDescriptor {
331    fn drop(&mut self) {
332        unsafe {
333            msg_send_0::<()>(self.as_ptr(), sel!(release));
334        }
335    }
336}
337
338impl Referencing for AttributeDescriptor {
339    #[inline]
340    fn as_ptr(&self) -> *const c_void {
341        self.0.as_ptr()
342    }
343}
344
345unsafe impl Send for AttributeDescriptor {}
346unsafe impl Sync for AttributeDescriptor {}
347
348// ============================================================================
349// AttributeDescriptorArray
350// ============================================================================
351
352/// An array of attribute descriptors.
353///
354/// C++ equivalent: `MTL::AttributeDescriptorArray`
355#[repr(transparent)]
356pub struct AttributeDescriptorArray(pub(crate) NonNull<c_void>);
357
358impl AttributeDescriptorArray {
359    /// Create from a raw pointer.
360    #[inline]
361    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
362        NonNull::new(ptr).map(Self)
363    }
364
365    /// Get the raw pointer.
366    #[inline]
367    pub fn as_raw(&self) -> *mut c_void {
368        self.0.as_ptr()
369    }
370
371    /// Get the descriptor at the specified index.
372    ///
373    /// C++ equivalent: `AttributeDescriptor* object(NS::UInteger index)`
374    pub fn object_at(&self, index: UInteger) -> Option<AttributeDescriptor> {
375        unsafe {
376            let ptr: *mut c_void =
377                msg_send_1(self.as_ptr(), sel!(objectAtIndexedSubscript:), index);
378            if ptr.is_null() {
379                return None;
380            }
381            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
382            AttributeDescriptor::from_raw(ptr)
383        }
384    }
385
386    /// Set the descriptor at the specified index.
387    ///
388    /// C++ equivalent: `void setObject(const AttributeDescriptor*, NS::UInteger)`
389    pub fn set_object_at(&self, descriptor: &AttributeDescriptor, index: UInteger) {
390        unsafe {
391            msg_send_2::<(), *const c_void, UInteger>(
392                self.as_ptr(),
393                sel!(setObject: atIndexedSubscript:),
394                descriptor.as_ptr(),
395                index,
396            );
397        }
398    }
399}
400
401impl Clone for AttributeDescriptorArray {
402    fn clone(&self) -> Self {
403        unsafe {
404            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
405        }
406        Self(self.0)
407    }
408}
409
410impl Drop for AttributeDescriptorArray {
411    fn drop(&mut self) {
412        unsafe {
413            msg_send_0::<()>(self.as_ptr(), sel!(release));
414        }
415    }
416}
417
418impl Referencing for AttributeDescriptorArray {
419    #[inline]
420    fn as_ptr(&self) -> *const c_void {
421        self.0.as_ptr()
422    }
423}
424
425unsafe impl Send for AttributeDescriptorArray {}
426unsafe impl Sync for AttributeDescriptorArray {}
427
428// ============================================================================
429// StageInputOutputDescriptor
430// ============================================================================
431
432/// Describes the input and output data for a compute kernel stage.
433///
434/// C++ equivalent: `MTL::StageInputOutputDescriptor`
435#[repr(transparent)]
436pub struct StageInputOutputDescriptor(pub(crate) NonNull<c_void>);
437
438impl StageInputOutputDescriptor {
439    /// Create a new stage input/output descriptor.
440    ///
441    /// C++ equivalent: `StageInputOutputDescriptor::stageInputOutputDescriptor()`
442    pub fn new() -> Option<Self> {
443        unsafe {
444            let class = mtl_sys::class!(MTLStageInputOutputDescriptor);
445            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(stageInputOutputDescriptor));
446            Self::from_raw(ptr)
447        }
448    }
449
450    /// Create from a raw pointer.
451    #[inline]
452    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
453        NonNull::new(ptr).map(Self)
454    }
455
456    /// Get the raw pointer.
457    #[inline]
458    pub fn as_raw(&self) -> *mut c_void {
459        self.0.as_ptr()
460    }
461
462    /// Get the buffer layouts array.
463    ///
464    /// C++ equivalent: `BufferLayoutDescriptorArray* layouts() const`
465    pub fn layouts(&self) -> Option<BufferLayoutDescriptorArray> {
466        unsafe {
467            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(layouts));
468            if ptr.is_null() {
469                return None;
470            }
471            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
472            BufferLayoutDescriptorArray::from_raw(ptr)
473        }
474    }
475
476    /// Get the attributes array.
477    ///
478    /// C++ equivalent: `AttributeDescriptorArray* attributes() const`
479    pub fn attributes(&self) -> Option<AttributeDescriptorArray> {
480        unsafe {
481            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(attributes));
482            if ptr.is_null() {
483                return None;
484            }
485            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
486            AttributeDescriptorArray::from_raw(ptr)
487        }
488    }
489
490    /// Get the index type.
491    ///
492    /// C++ equivalent: `IndexType indexType() const`
493    #[inline]
494    pub fn index_type(&self) -> IndexType {
495        unsafe { msg_send_0(self.as_ptr(), sel!(indexType)) }
496    }
497
498    /// Set the index type.
499    ///
500    /// C++ equivalent: `void setIndexType(IndexType)`
501    #[inline]
502    pub fn set_index_type(&self, index_type: IndexType) {
503        unsafe {
504            msg_send_1::<(), IndexType>(self.as_ptr(), sel!(setIndexType:), index_type);
505        }
506    }
507
508    /// Get the index buffer index.
509    ///
510    /// C++ equivalent: `NS::UInteger indexBufferIndex() const`
511    #[inline]
512    pub fn index_buffer_index(&self) -> UInteger {
513        unsafe { msg_send_0(self.as_ptr(), sel!(indexBufferIndex)) }
514    }
515
516    /// Set the index buffer index.
517    ///
518    /// C++ equivalent: `void setIndexBufferIndex(NS::UInteger)`
519    #[inline]
520    pub fn set_index_buffer_index(&self, index: UInteger) {
521        unsafe {
522            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setIndexBufferIndex:), index);
523        }
524    }
525
526    /// Reset the descriptor to its default state.
527    ///
528    /// C++ equivalent: `void reset()`
529    #[inline]
530    pub fn reset(&self) {
531        unsafe {
532            msg_send_0::<()>(self.as_ptr(), sel!(reset));
533        }
534    }
535}
536
537impl Default for StageInputOutputDescriptor {
538    fn default() -> Self {
539        Self::new().expect("failed to create StageInputOutputDescriptor")
540    }
541}
542
543impl Clone for StageInputOutputDescriptor {
544    fn clone(&self) -> Self {
545        unsafe {
546            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
547        }
548        Self(self.0)
549    }
550}
551
552impl Drop for StageInputOutputDescriptor {
553    fn drop(&mut self) {
554        unsafe {
555            msg_send_0::<()>(self.as_ptr(), sel!(release));
556        }
557    }
558}
559
560impl Referencing for StageInputOutputDescriptor {
561    #[inline]
562    fn as_ptr(&self) -> *const c_void {
563        self.0.as_ptr()
564    }
565}
566
567unsafe impl Send for StageInputOutputDescriptor {}
568unsafe impl Sync for StageInputOutputDescriptor {}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `T` and `Option<T>` are both exactly pointer-sized.
    ///
    /// Every wrapper in this module is a `#[repr(transparent)]` newtype over
    /// `NonNull<c_void>`, so `Option<T>` should use the null-pointer niche.
    /// `from_raw` relies on that when it returns `Option<Self>`.
    fn assert_pointer_sized<T>() {
        let ptr_size = std::mem::size_of::<*mut c_void>();
        let name = std::any::type_name::<T>();
        assert_eq!(
            std::mem::size_of::<T>(),
            ptr_size,
            "`{name}` is not pointer-sized"
        );
        assert_eq!(
            std::mem::size_of::<Option<T>>(),
            ptr_size,
            "`Option<{name}>` does not use the null-pointer niche"
        );
    }

    #[test]
    fn test_type_sizes() {
        // Cover every wrapper, including the two array types the original
        // version of this test skipped.
        assert_pointer_sized::<BufferLayoutDescriptor>();
        assert_pointer_sized::<BufferLayoutDescriptorArray>();
        assert_pointer_sized::<AttributeDescriptor>();
        assert_pointer_sized::<AttributeDescriptorArray>();
        assert_pointer_sized::<StageInputOutputDescriptor>();
    }
}
585            std::mem::size_of::<StageInputOutputDescriptor>(),
586            std::mem::size_of::<*mut c_void>()
587        );
588    }
589}