Skip to main content

mtl_gpu/mtl4/
specialized_function_descriptor.rs

//! MTL4 SpecializedFunctionDescriptor implementation.
//!
//! Corresponds to `Metal/MTL4SpecializedFunctionDescriptor.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use super::FunctionDescriptor;
12
// ============================================================
// SpecializedFunctionDescriptor
// ============================================================
16
/// Descriptor for specialized functions with constant values.
///
/// C++ equivalent: `MTL4::SpecializedFunctionDescriptor`
///
/// SpecializedFunctionDescriptor extends [`FunctionDescriptor`] to support
/// function specialization with constant values.
///
/// # Ownership
///
/// The wrapper holds a non-null pointer to an Objective-C object and owns one
/// reference to it: cloning sends `retain` and dropping sends `release`.
/// `#[repr(transparent)]` gives it the same size and layout as a raw pointer.
#[repr(transparent)]
pub struct SpecializedFunctionDescriptor(NonNull<c_void>);
25
26impl SpecializedFunctionDescriptor {
27    /// Create a SpecializedFunctionDescriptor from a raw pointer.
28    #[inline]
29    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
30        NonNull::new(ptr).map(Self)
31    }
32
33    /// Get the raw pointer.
34    #[inline]
35    pub fn as_raw(&self) -> *mut c_void {
36        self.0.as_ptr()
37    }
38
39    /// Create a new specialized function descriptor.
40    pub fn new() -> Option<Self> {
41        unsafe {
42            let class = mtl_sys::Class::get("MTL4SpecializedFunctionDescriptor")?;
43            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
44            if ptr.is_null() {
45                return None;
46            }
47            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
48            Self::from_raw(ptr)
49        }
50    }
51
52    /// Get the function descriptor.
53    ///
54    /// C++ equivalent: `FunctionDescriptor* functionDescriptor() const`
55    pub fn function_descriptor(&self) -> Option<FunctionDescriptor> {
56        unsafe {
57            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(functionDescriptor));
58            FunctionDescriptor::from_raw(ptr)
59        }
60    }
61
62    /// Set the function descriptor.
63    ///
64    /// C++ equivalent: `void setFunctionDescriptor(const MTL4::FunctionDescriptor*)`
65    pub fn set_function_descriptor(&self, descriptor: &FunctionDescriptor) {
66        unsafe {
67            let _: () = msg_send_1(
68                self.as_ptr(),
69                sel!(setFunctionDescriptor:),
70                descriptor.as_ptr(),
71            );
72        }
73    }
74
75    /// Get the constant values (as raw pointer).
76    ///
77    /// C++ equivalent: `MTL::FunctionConstantValues* constantValues() const`
78    pub fn constant_values_raw(&self) -> *mut c_void {
79        unsafe { msg_send_0(self.as_ptr(), sel!(constantValues)) }
80    }
81
82    /// Set the constant values (from raw pointer).
83    ///
84    /// C++ equivalent: `void setConstantValues(const MTL::FunctionConstantValues*)`
85    pub fn set_constant_values_raw(&self, values: *const c_void) {
86        unsafe {
87            let _: () = msg_send_1(self.as_ptr(), sel!(setConstantValues:), values);
88        }
89    }
90
91    /// Get the specialized name.
92    ///
93    /// C++ equivalent: `NS::String* specializedName() const`
94    pub fn specialized_name(&self) -> Option<String> {
95        unsafe {
96            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(specializedName));
97            if ns_string.is_null() {
98                return None;
99            }
100            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
101            if c_str.is_null() {
102                return None;
103            }
104            Some(
105                std::ffi::CStr::from_ptr(c_str)
106                    .to_string_lossy()
107                    .into_owned(),
108            )
109        }
110    }
111
112    /// Set the specialized name.
113    ///
114    /// C++ equivalent: `void setSpecializedName(const NS::String*)`
115    pub fn set_specialized_name(&self, name: &str) {
116        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
117            unsafe {
118                let _: () = msg_send_1(self.as_ptr(), sel!(setSpecializedName:), ns_name.as_ptr());
119            }
120        }
121    }
122}
123
124impl Default for SpecializedFunctionDescriptor {
125    fn default() -> Self {
126        Self::new().expect("Failed to create MTL4SpecializedFunctionDescriptor")
127    }
128}
129
130impl Clone for SpecializedFunctionDescriptor {
131    fn clone(&self) -> Self {
132        unsafe {
133            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
134        }
135        Self(self.0)
136    }
137}
138
139impl Drop for SpecializedFunctionDescriptor {
140    fn drop(&mut self) {
141        unsafe {
142            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
143        }
144    }
145}
146
147impl Referencing for SpecializedFunctionDescriptor {
148    #[inline]
149    fn as_ptr(&self) -> *const c_void {
150        self.0.as_ptr()
151    }
152}
153
// NOTE(review): these impls assert that the wrapped Objective-C object may be
// moved to, and shared between, threads. The setters on this type take `&self`,
// so `Sync` lets two threads mutate the same descriptor concurrently without
// synchronization. Confirm that the underlying Metal descriptor tolerates this,
// or mutate a given descriptor from only one thread at a time.
unsafe impl Send for SpecializedFunctionDescriptor {}
unsafe impl Sync for SpecializedFunctionDescriptor {}
156
157impl std::fmt::Debug for SpecializedFunctionDescriptor {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("SpecializedFunctionDescriptor")
160            .field("specialized_name", &self.specialized_name())
161            .finish()
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// `#[repr(transparent)]` over `NonNull<c_void>` must keep the wrapper pointer-sized.
    #[test]
    fn test_specialized_function_descriptor_size() {
        assert_eq!(
            std::mem::size_of::<SpecializedFunctionDescriptor>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    /// `from_raw` and `new` return `Option<Self>`. The `NonNull` niche must keep
    /// that option pointer-sized, so wrapping a pointer costs nothing.
    #[test]
    fn test_optional_descriptor_is_pointer_sized() {
        assert_eq!(
            std::mem::size_of::<Option<SpecializedFunctionDescriptor>>(),
            std::mem::size_of::<*mut c_void>()
        );
    }
}
176}