// mtl_gpu/mtl4/argument_table.rs

1//! MTL4 ArgumentTable implementation.
2//!
3//! Corresponds to `Metal/MTL4ArgumentTable.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_3, sel};
10
11use crate::Device;
12
13// ============================================================
14// ArgumentTableDescriptor
15// ============================================================
16
17/// Descriptor for creating an argument table.
18///
19/// C++ equivalent: `MTL4::ArgumentTableDescriptor`
20#[repr(transparent)]
21pub struct ArgumentTableDescriptor(NonNull<c_void>);
22
23impl ArgumentTableDescriptor {
24    /// Create an ArgumentTableDescriptor from a raw pointer.
25    #[inline]
26    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
27        NonNull::new(ptr).map(Self)
28    }
29
30    /// Get the raw pointer.
31    #[inline]
32    pub fn as_raw(&self) -> *mut c_void {
33        self.0.as_ptr()
34    }
35
36    /// Create a new argument table descriptor.
37    pub fn new() -> Option<Self> {
38        unsafe {
39            let class = mtl_sys::Class::get("MTL4ArgumentTableDescriptor")?;
40            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
41            if ptr.is_null() {
42                return None;
43            }
44            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
45            Self::from_raw(ptr)
46        }
47    }
48
49    /// Get the label.
50    ///
51    /// C++ equivalent: `NS::String* label() const`
52    pub fn label(&self) -> Option<String> {
53        unsafe {
54            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
55            if ns_string.is_null() {
56                return None;
57            }
58            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
59            if c_str.is_null() {
60                return None;
61            }
62            Some(
63                std::ffi::CStr::from_ptr(c_str)
64                    .to_string_lossy()
65                    .into_owned(),
66            )
67        }
68    }
69
70    /// Set the label.
71    ///
72    /// C++ equivalent: `void setLabel(const NS::String*)`
73    pub fn set_label(&self, label: &str) {
74        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
75            unsafe {
76                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
77            }
78        }
79    }
80
81    /// Get the maximum buffer bind count.
82    ///
83    /// C++ equivalent: `NS::UInteger maxBufferBindCount() const`
84    pub fn max_buffer_bind_count(&self) -> UInteger {
85        unsafe { msg_send_0(self.as_ptr(), sel!(maxBufferBindCount)) }
86    }
87
88    /// Set the maximum buffer bind count.
89    ///
90    /// C++ equivalent: `void setMaxBufferBindCount(NS::UInteger)`
91    pub fn set_max_buffer_bind_count(&self, count: UInteger) {
92        unsafe {
93            let _: () = msg_send_1(self.as_ptr(), sel!(setMaxBufferBindCount:), count);
94        }
95    }
96
97    /// Get the maximum texture bind count.
98    ///
99    /// C++ equivalent: `NS::UInteger maxTextureBindCount() const`
100    pub fn max_texture_bind_count(&self) -> UInteger {
101        unsafe { msg_send_0(self.as_ptr(), sel!(maxTextureBindCount)) }
102    }
103
104    /// Set the maximum texture bind count.
105    ///
106    /// C++ equivalent: `void setMaxTextureBindCount(NS::UInteger)`
107    pub fn set_max_texture_bind_count(&self, count: UInteger) {
108        unsafe {
109            let _: () = msg_send_1(self.as_ptr(), sel!(setMaxTextureBindCount:), count);
110        }
111    }
112
113    /// Get the maximum sampler state bind count.
114    ///
115    /// C++ equivalent: `NS::UInteger maxSamplerStateBindCount() const`
116    pub fn max_sampler_state_bind_count(&self) -> UInteger {
117        unsafe { msg_send_0(self.as_ptr(), sel!(maxSamplerStateBindCount)) }
118    }
119
120    /// Set the maximum sampler state bind count.
121    ///
122    /// C++ equivalent: `void setMaxSamplerStateBindCount(NS::UInteger)`
123    pub fn set_max_sampler_state_bind_count(&self, count: UInteger) {
124        unsafe {
125            let _: () = msg_send_1(self.as_ptr(), sel!(setMaxSamplerStateBindCount:), count);
126        }
127    }
128
129    /// Check if bindings should be initialized.
130    ///
131    /// C++ equivalent: `bool initializeBindings() const`
132    pub fn initialize_bindings(&self) -> bool {
133        unsafe { msg_send_0(self.as_ptr(), sel!(initializeBindings)) }
134    }
135
136    /// Set whether bindings should be initialized.
137    ///
138    /// C++ equivalent: `void setInitializeBindings(bool)`
139    pub fn set_initialize_bindings(&self, initialize: bool) {
140        unsafe {
141            let _: () = msg_send_1(self.as_ptr(), sel!(setInitializeBindings:), initialize);
142        }
143    }
144
145    /// Check if attribute strides are supported.
146    ///
147    /// C++ equivalent: `bool supportAttributeStrides() const`
148    pub fn support_attribute_strides(&self) -> bool {
149        unsafe { msg_send_0(self.as_ptr(), sel!(supportAttributeStrides)) }
150    }
151
152    /// Set whether attribute strides are supported.
153    ///
154    /// C++ equivalent: `void setSupportAttributeStrides(bool)`
155    pub fn set_support_attribute_strides(&self, support: bool) {
156        unsafe {
157            let _: () = msg_send_1(self.as_ptr(), sel!(setSupportAttributeStrides:), support);
158        }
159    }
160}
161
162impl Default for ArgumentTableDescriptor {
163    fn default() -> Self {
164        Self::new().expect("Failed to create MTL4ArgumentTableDescriptor")
165    }
166}
167
168impl Clone for ArgumentTableDescriptor {
169    fn clone(&self) -> Self {
170        unsafe {
171            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
172        }
173        Self(self.0)
174    }
175}
176
177impl Drop for ArgumentTableDescriptor {
178    fn drop(&mut self) {
179        unsafe {
180            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
181        }
182    }
183}
184
185impl Referencing for ArgumentTableDescriptor {
186    #[inline]
187    fn as_ptr(&self) -> *const c_void {
188        self.0.as_ptr()
189    }
190}
191
// NOTE(review): these impls assert that the wrapped object may be moved to and
// shared between threads. Nothing visible in this file establishes that, and
// the setters above mutate the object through `&self` — confirm against the
// Metal threading guarantees for `MTL4ArgumentTableDescriptor`.
unsafe impl Send for ArgumentTableDescriptor {}
unsafe impl Sync for ArgumentTableDescriptor {}
194
195impl std::fmt::Debug for ArgumentTableDescriptor {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        f.debug_struct("ArgumentTableDescriptor")
198            .field("label", &self.label())
199            .field("max_buffer_bind_count", &self.max_buffer_bind_count())
200            .field("max_texture_bind_count", &self.max_texture_bind_count())
201            .field(
202                "max_sampler_state_bind_count",
203                &self.max_sampler_state_bind_count(),
204            )
205            .finish()
206    }
207}
208
209// ============================================================
210// ArgumentTable
211// ============================================================
212
213/// Argument table for GPU resource binding.
214///
215/// C++ equivalent: `MTL4::ArgumentTable`
216///
217/// ArgumentTable provides a way to bind resources (buffers, textures, samplers)
218/// at specific indices for use in shaders.
219#[repr(transparent)]
220pub struct ArgumentTable(NonNull<c_void>);
221
222impl ArgumentTable {
223    /// Create an ArgumentTable from a raw pointer.
224    #[inline]
225    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
226        NonNull::new(ptr).map(Self)
227    }
228
229    /// Get the raw pointer.
230    #[inline]
231    pub fn as_raw(&self) -> *mut c_void {
232        self.0.as_ptr()
233    }
234
235    /// Get the device.
236    ///
237    /// C++ equivalent: `MTL::Device* device() const`
238    pub fn device(&self) -> Option<Device> {
239        unsafe {
240            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
241            Device::from_raw(ptr)
242        }
243    }
244
245    /// Get the label.
246    ///
247    /// C++ equivalent: `NS::String* label() const`
248    pub fn label(&self) -> Option<String> {
249        unsafe {
250            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
251            if ns_string.is_null() {
252                return None;
253            }
254            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
255            if c_str.is_null() {
256                return None;
257            }
258            Some(
259                std::ffi::CStr::from_ptr(c_str)
260                    .to_string_lossy()
261                    .into_owned(),
262            )
263        }
264    }
265
266    /// Set a GPU address at the specified binding index.
267    ///
268    /// C++ equivalent: `void setAddress(MTL::GPUAddress, NS::UInteger)`
269    pub fn set_address(&self, gpu_address: u64, binding_index: UInteger) {
270        unsafe {
271            let _: () = msg_send_2(
272                self.as_ptr(),
273                sel!(setAddress:atIndex:),
274                gpu_address,
275                binding_index,
276            );
277        }
278    }
279
280    /// Set a GPU address with stride at the specified binding index.
281    ///
282    /// C++ equivalent: `void setAddress(MTL::GPUAddress, NS::UInteger, NS::UInteger)`
283    pub fn set_address_with_stride(
284        &self,
285        gpu_address: u64,
286        stride: UInteger,
287        binding_index: UInteger,
288    ) {
289        unsafe {
290            let _: () = msg_send_3(
291                self.as_ptr(),
292                sel!(setAddress:attributeStride:atIndex:),
293                gpu_address,
294                stride,
295                binding_index,
296            );
297        }
298    }
299
300    /// Set a resource at the specified buffer index.
301    ///
302    /// C++ equivalent: `void setResource(MTL::ResourceID, NS::UInteger)`
303    pub fn set_resource(&self, resource_id: u64, binding_index: UInteger) {
304        unsafe {
305            let _: () = msg_send_2(
306                self.as_ptr(),
307                sel!(setResource:atBufferIndex:),
308                resource_id,
309                binding_index,
310            );
311        }
312    }
313
314    /// Set a texture at the specified binding index.
315    ///
316    /// C++ equivalent: `void setTexture(MTL::ResourceID, NS::UInteger)`
317    pub fn set_texture(&self, resource_id: u64, binding_index: UInteger) {
318        unsafe {
319            let _: () = msg_send_2(
320                self.as_ptr(),
321                sel!(setTexture:atIndex:),
322                resource_id,
323                binding_index,
324            );
325        }
326    }
327
328    /// Set a sampler state at the specified binding index.
329    ///
330    /// C++ equivalent: `void setSamplerState(MTL::ResourceID, NS::UInteger)`
331    pub fn set_sampler_state(&self, resource_id: u64, binding_index: UInteger) {
332        unsafe {
333            let _: () = msg_send_2(
334                self.as_ptr(),
335                sel!(setSamplerState:atIndex:),
336                resource_id,
337                binding_index,
338            );
339        }
340    }
341}
342
343impl Clone for ArgumentTable {
344    fn clone(&self) -> Self {
345        unsafe {
346            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
347        }
348        Self(self.0)
349    }
350}
351
352impl Drop for ArgumentTable {
353    fn drop(&mut self) {
354        unsafe {
355            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
356        }
357    }
358}
359
360impl Referencing for ArgumentTable {
361    #[inline]
362    fn as_ptr(&self) -> *const c_void {
363        self.0.as_ptr()
364    }
365}
366
// NOTE(review): these impls assert that the wrapped object may be moved to and
// shared between threads. Nothing visible in this file establishes that, and
// the binding setters above mutate the table through `&self` — confirm against
// the Metal threading guarantees for `MTL4ArgumentTable`.
unsafe impl Send for ArgumentTable {}
unsafe impl Sync for ArgumentTable {}
369
370impl std::fmt::Debug for ArgumentTable {
371    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372        f.debug_struct("ArgumentTable")
373            .field("label", &self.label())
374            .finish()
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    /// Size of a raw object pointer, which each transparent wrapper must match.
    const POINTER_SIZE: usize = size_of::<*mut c_void>();

    /// The descriptor wrapper must be exactly pointer-sized.
    #[test]
    fn test_argument_table_descriptor_size() {
        assert_eq!(size_of::<ArgumentTableDescriptor>(), POINTER_SIZE);
    }

    /// The argument table wrapper must be exactly pointer-sized.
    #[test]
    fn test_argument_table_size() {
        assert_eq!(size_of::<ArgumentTable>(), POINTER_SIZE);
    }
}