Skip to main content

mtl_gpu/mtl4/
binary_function.rs

1//! MTL4 BinaryFunction implementation.
2//!
3//! Corresponds to `Metal/MTL4BinaryFunction.hpp` and `Metal/MTL4BinaryFunctionDescriptor.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use super::FunctionDescriptor;
12use super::enums::BinaryFunctionOptions;
13use crate::FunctionType;
14
15// ============================================================
16// BinaryFunction
17// ============================================================
18
19/// A compiled binary function for linking.
20///
21/// C++ equivalent: `MTL4::BinaryFunction`
22///
23/// BinaryFunction represents a precompiled function that can be
24/// linked with pipelines at runtime.
25#[repr(transparent)]
26pub struct BinaryFunction(NonNull<c_void>);
27
28impl BinaryFunction {
29    /// Create a BinaryFunction from a raw pointer.
30    #[inline]
31    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
32        NonNull::new(ptr).map(Self)
33    }
34
35    /// Get the raw pointer.
36    #[inline]
37    pub fn as_raw(&self) -> *mut c_void {
38        self.0.as_ptr()
39    }
40
41    /// Get the function type.
42    ///
43    /// C++ equivalent: `MTL::FunctionType functionType() const`
44    pub fn function_type(&self) -> FunctionType {
45        unsafe { msg_send_0(self.as_ptr(), sel!(functionType)) }
46    }
47
48    /// Get the function name.
49    ///
50    /// C++ equivalent: `NS::String* name() const`
51    pub fn name(&self) -> Option<String> {
52        unsafe {
53            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(name));
54            if ns_string.is_null() {
55                return None;
56            }
57            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
58            if c_str.is_null() {
59                return None;
60            }
61            Some(
62                std::ffi::CStr::from_ptr(c_str)
63                    .to_string_lossy()
64                    .into_owned(),
65            )
66        }
67    }
68}
69
70impl Clone for BinaryFunction {
71    fn clone(&self) -> Self {
72        unsafe {
73            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
74        }
75        Self(self.0)
76    }
77}
78
79impl Drop for BinaryFunction {
80    fn drop(&mut self) {
81        unsafe {
82            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
83        }
84    }
85}
86
87impl Referencing for BinaryFunction {
88    #[inline]
89    fn as_ptr(&self) -> *const c_void {
90        self.0.as_ptr()
91    }
92}
93
94unsafe impl Send for BinaryFunction {}
95unsafe impl Sync for BinaryFunction {}
96
97impl std::fmt::Debug for BinaryFunction {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("BinaryFunction")
100            .field("name", &self.name())
101            .field("function_type", &self.function_type())
102            .finish()
103    }
104}
105
106// ============================================================
107// BinaryFunctionDescriptor
108// ============================================================
109
110/// Descriptor for creating a binary function.
111///
112/// C++ equivalent: `MTL4::BinaryFunctionDescriptor`
113#[repr(transparent)]
114pub struct BinaryFunctionDescriptor(NonNull<c_void>);
115
116impl BinaryFunctionDescriptor {
117    /// Create a BinaryFunctionDescriptor from a raw pointer.
118    #[inline]
119    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
120        NonNull::new(ptr).map(Self)
121    }
122
123    /// Get the raw pointer.
124    #[inline]
125    pub fn as_raw(&self) -> *mut c_void {
126        self.0.as_ptr()
127    }
128
129    /// Create a new binary function descriptor.
130    pub fn new() -> Option<Self> {
131        unsafe {
132            let class = mtl_sys::Class::get("MTL4BinaryFunctionDescriptor")?;
133            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
134            if ptr.is_null() {
135                return None;
136            }
137            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
138            Self::from_raw(ptr)
139        }
140    }
141
142    /// Get the name.
143    ///
144    /// C++ equivalent: `NS::String* name() const`
145    pub fn name(&self) -> Option<String> {
146        unsafe {
147            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(name));
148            if ns_string.is_null() {
149                return None;
150            }
151            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
152            if c_str.is_null() {
153                return None;
154            }
155            Some(
156                std::ffi::CStr::from_ptr(c_str)
157                    .to_string_lossy()
158                    .into_owned(),
159            )
160        }
161    }
162
163    /// Set the name.
164    ///
165    /// C++ equivalent: `void setName(const NS::String*)`
166    pub fn set_name(&self, name: &str) {
167        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
168            unsafe {
169                let _: () = msg_send_1(self.as_ptr(), sel!(setName:), ns_name.as_ptr());
170            }
171        }
172    }
173
174    /// Get the function descriptor.
175    ///
176    /// C++ equivalent: `FunctionDescriptor* functionDescriptor() const`
177    pub fn function_descriptor(&self) -> Option<FunctionDescriptor> {
178        unsafe {
179            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(functionDescriptor));
180            FunctionDescriptor::from_raw(ptr)
181        }
182    }
183
184    /// Set the function descriptor.
185    ///
186    /// C++ equivalent: `void setFunctionDescriptor(const MTL4::FunctionDescriptor*)`
187    pub fn set_function_descriptor(&self, descriptor: &FunctionDescriptor) {
188        unsafe {
189            let _: () = msg_send_1(
190                self.as_ptr(),
191                sel!(setFunctionDescriptor:),
192                descriptor.as_ptr(),
193            );
194        }
195    }
196
197    /// Get the options.
198    ///
199    /// C++ equivalent: `BinaryFunctionOptions options() const`
200    pub fn options(&self) -> BinaryFunctionOptions {
201        unsafe { msg_send_0(self.as_ptr(), sel!(options)) }
202    }
203
204    /// Set the options.
205    ///
206    /// C++ equivalent: `void setOptions(MTL4::BinaryFunctionOptions)`
207    pub fn set_options(&self, options: BinaryFunctionOptions) {
208        unsafe {
209            let _: () = msg_send_1(self.as_ptr(), sel!(setOptions:), options);
210        }
211    }
212}
213
214impl Clone for BinaryFunctionDescriptor {
215    fn clone(&self) -> Self {
216        unsafe {
217            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
218        }
219        Self(self.0)
220    }
221}
222
223impl Drop for BinaryFunctionDescriptor {
224    fn drop(&mut self) {
225        unsafe {
226            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
227        }
228    }
229}
230
231impl Referencing for BinaryFunctionDescriptor {
232    #[inline]
233    fn as_ptr(&self) -> *const c_void {
234        self.0.as_ptr()
235    }
236}
237
238unsafe impl Send for BinaryFunctionDescriptor {}
239unsafe impl Sync for BinaryFunctionDescriptor {}
240
241impl std::fmt::Debug for BinaryFunctionDescriptor {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        f.debug_struct("BinaryFunctionDescriptor")
244            .field("name", &self.name())
245            .field("options", &self.options())
246            .finish()
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem;

    /// Asserts that `T` occupies exactly one object pointer.
    ///
    /// This is required for the `#[repr(transparent)]` handles to stand in
    /// for raw pointers across FFI.
    fn assert_pointer_sized<T>() {
        assert_eq!(mem::size_of::<T>(), mem::size_of::<*mut c_void>());
    }

    #[test]
    fn test_binary_function_size() {
        assert_pointer_sized::<BinaryFunction>();
    }

    #[test]
    fn test_binary_function_descriptor_size() {
        assert_pointer_sized::<BinaryFunctionDescriptor>();
    }
}