Skip to main content

mtl_gpu/library/
function_descriptor.rs

1//! Descriptor for creating specialized functions.
2
3use std::ffi::c_void;
4use std::ptr::NonNull;
5
6use mtl_foundation::Referencing;
7use mtl_sys::{msg_send_0, msg_send_1, sel};
8
9use crate::enums::FunctionOptions;
10
11use super::FunctionConstantValues;
12
/// Describes a function to be created from a Metal library, optionally
/// specialized with a different name, constant values and options.
///
/// C++ equivalent: `MTL::FunctionDescriptor`
///
/// The wrapper holds exactly one non-null Objective-C object pointer. Because it
/// is `#[repr(transparent)]` over `NonNull`, it has the same size as a raw
/// pointer, and `Option<FunctionDescriptor>` is pointer-sized as well.
#[repr(transparent)]
pub struct FunctionDescriptor(pub(crate) NonNull<c_void>);
20
21impl FunctionDescriptor {
22    /// Allocate a new function descriptor.
23    ///
24    /// C++ equivalent: `static FunctionDescriptor* alloc()`
25    pub fn alloc() -> Option<Self> {
26        unsafe {
27            let cls = mtl_sys::Class::get("MTLFunctionDescriptor")?;
28            let ptr: *mut c_void = msg_send_0(cls.as_ptr(), sel!(alloc));
29            Self::from_raw(ptr)
30        }
31    }
32
33    /// Initialize an allocated function descriptor.
34    ///
35    /// C++ equivalent: `FunctionDescriptor* init()`
36    pub fn init(&self) -> Option<Self> {
37        unsafe {
38            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(init));
39            Self::from_raw(ptr)
40        }
41    }
42
43    /// Create a new function descriptor.
44    pub fn new() -> Option<Self> {
45        Self::alloc()?.init()
46    }
47
48    /// Create a function descriptor using the factory method.
49    ///
50    /// C++ equivalent: `static FunctionDescriptor* functionDescriptor()`
51    pub fn function_descriptor() -> Option<Self> {
52        unsafe {
53            let cls = mtl_sys::Class::get("MTLFunctionDescriptor")?;
54            let ptr: *mut c_void = msg_send_0(cls.as_ptr(), sel!(functionDescriptor));
55            if ptr.is_null() {
56                return None;
57            }
58            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
59            Self::from_raw(ptr)
60        }
61    }
62
63    /// Create from a raw pointer.
64    ///
65    /// # Safety
66    ///
67    /// The pointer must be a valid Metal function descriptor object.
68    #[inline]
69    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
70        NonNull::new(ptr).map(Self)
71    }
72
73    /// Get the raw pointer.
74    #[inline]
75    pub fn as_raw(&self) -> *mut c_void {
76        self.0.as_ptr()
77    }
78
79    /// Get the function name.
80    ///
81    /// C++ equivalent: `NS::String* name() const`
82    pub fn name(&self) -> Option<String> {
83        unsafe {
84            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(name));
85            if ptr.is_null() {
86                return None;
87            }
88            let utf8_ptr: *const std::ffi::c_char =
89                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
90            if utf8_ptr.is_null() {
91                return None;
92            }
93            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
94            Some(c_str.to_string_lossy().into_owned())
95        }
96    }
97
98    /// Set the function name.
99    ///
100    /// C++ equivalent: `void setName(const NS::String*)`
101    pub fn set_name(&self, name: &str) {
102        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
103            unsafe {
104                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setName:), ns_name.as_ptr());
105            }
106        }
107    }
108
109    /// Get the specialized name.
110    ///
111    /// C++ equivalent: `NS::String* specializedName() const`
112    pub fn specialized_name(&self) -> Option<String> {
113        unsafe {
114            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(specializedName));
115            if ptr.is_null() {
116                return None;
117            }
118            let utf8_ptr: *const std::ffi::c_char =
119                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
120            if utf8_ptr.is_null() {
121                return None;
122            }
123            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
124            Some(c_str.to_string_lossy().into_owned())
125        }
126    }
127
128    /// Set the specialized name.
129    ///
130    /// C++ equivalent: `void setSpecializedName(const NS::String*)`
131    pub fn set_specialized_name(&self, name: &str) {
132        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
133            unsafe {
134                msg_send_1::<(), *const c_void>(
135                    self.as_ptr(),
136                    sel!(setSpecializedName:),
137                    ns_name.as_ptr(),
138                );
139            }
140        }
141    }
142
143    /// Get the function constant values.
144    ///
145    /// C++ equivalent: `FunctionConstantValues* constantValues() const`
146    pub fn constant_values(&self) -> Option<FunctionConstantValues> {
147        unsafe {
148            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(constantValues));
149            if ptr.is_null() {
150                return None;
151            }
152            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
153            FunctionConstantValues::from_raw(ptr)
154        }
155    }
156
157    /// Set the function constant values.
158    ///
159    /// C++ equivalent: `void setConstantValues(const FunctionConstantValues*)`
160    pub fn set_constant_values(&self, values: Option<&FunctionConstantValues>) {
161        let ptr = values.map(|v| v.as_ptr()).unwrap_or(std::ptr::null());
162        unsafe {
163            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setConstantValues:), ptr);
164        }
165    }
166
167    /// Get the function options.
168    ///
169    /// C++ equivalent: `FunctionOptions options() const`
170    #[inline]
171    pub fn options(&self) -> FunctionOptions {
172        unsafe { msg_send_0(self.as_ptr(), sel!(options)) }
173    }
174
175    /// Set the function options.
176    ///
177    /// C++ equivalent: `void setOptions(FunctionOptions)`
178    #[inline]
179    pub fn set_options(&self, options: FunctionOptions) {
180        unsafe {
181            msg_send_1::<(), FunctionOptions>(self.as_ptr(), sel!(setOptions:), options);
182        }
183    }
184
185    /// Get the binary archives (raw NSArray pointer).
186    ///
187    /// C++ equivalent: `NS::Array* binaryArchives() const`
188    pub fn binary_archives_raw(&self) -> *mut c_void {
189        unsafe { msg_send_0(self.as_ptr(), sel!(binaryArchives)) }
190    }
191
192    /// Set the binary archives.
193    ///
194    /// C++ equivalent: `void setBinaryArchives(const NS::Array*)`
195    ///
196    /// # Safety
197    ///
198    /// The archives pointer must be a valid NSArray of BinaryArchive objects or null.
199    pub unsafe fn set_binary_archives_raw(&self, archives: *const c_void) {
200        unsafe {
201            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setBinaryArchives:), archives);
202        }
203    }
204}
205
206impl Default for FunctionDescriptor {
207    fn default() -> Self {
208        Self::new().expect("failed to create FunctionDescriptor")
209    }
210}
211
212impl Clone for FunctionDescriptor {
213    fn clone(&self) -> Self {
214        unsafe {
215            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(copy));
216            Self::from_raw(ptr).expect("failed to copy FunctionDescriptor")
217        }
218    }
219}
220
221impl Drop for FunctionDescriptor {
222    fn drop(&mut self) {
223        unsafe {
224            msg_send_0::<()>(self.as_ptr(), sel!(release));
225        }
226    }
227}
228
229impl Referencing for FunctionDescriptor {
230    #[inline]
231    fn as_ptr(&self) -> *const c_void {
232        self.0.as_ptr()
233    }
234}
235
236unsafe impl Send for FunctionDescriptor {}
237unsafe impl Sync for FunctionDescriptor {}
238
239impl std::fmt::Debug for FunctionDescriptor {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        f.debug_struct("FunctionDescriptor")
242            .field("name", &self.name())
243            .field("specialized_name", &self.specialized_name())
244            .field("options", &self.options())
245            .finish()
246    }
247}