// mtl_gpu/library/function.rs

//! A compiled shader function.
2
3use std::ffi::c_void;
4use std::ptr::NonNull;
5
6use mtl_foundation::{Referencing, UInteger};
7use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
8
9use crate::enums::{FunctionOptions, FunctionType, PatchType};
10
/// Owning handle to a compiled Metal shader function (`MTLFunction`).
///
/// The wrapped pointer is never null, and `#[repr(transparent)]` gives the
/// handle the same layout as the raw object pointer. Each `Function` value
/// holds one retain count on the object: cloning retains, dropping releases.
///
/// C++ equivalent: `MTL::Function`
#[repr(transparent)]
pub struct Function(pub(crate) NonNull<c_void>);
16
17impl Function {
18    /// Create a Function from a raw pointer.
19    ///
20    /// # Safety
21    ///
22    /// The pointer must be a valid Metal function object.
23    #[inline]
24    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
25        NonNull::new(ptr).map(Self)
26    }
27
28    /// Get the raw pointer to the function.
29    #[inline]
30    pub fn as_raw(&self) -> *mut c_void {
31        self.0.as_ptr()
32    }
33
34    // =========================================================================
35    // Properties
36    // =========================================================================
37
38    /// Get the label for this function.
39    ///
40    /// C++ equivalent: `NS::String* label() const`
41    pub fn label(&self) -> Option<String> {
42        unsafe {
43            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
44            if ptr.is_null() {
45                return None;
46            }
47            let utf8_ptr: *const std::ffi::c_char =
48                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
49            if utf8_ptr.is_null() {
50                return None;
51            }
52            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
53            Some(c_str.to_string_lossy().into_owned())
54        }
55    }
56
57    /// Set the label for this function.
58    ///
59    /// C++ equivalent: `void setLabel(const NS::String*)`
60    pub fn set_label(&self, label: &str) {
61        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
62            unsafe {
63                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
64            }
65        }
66    }
67
68    /// Get the device that created this function.
69    ///
70    /// C++ equivalent: `Device* device() const`
71    pub fn device(&self) -> crate::Device {
72        unsafe {
73            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
74            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
75            crate::Device::from_raw(ptr).expect("function has no device")
76        }
77    }
78
79    /// Get the function name.
80    ///
81    /// C++ equivalent: `NS::String* name() const`
82    pub fn name(&self) -> Option<String> {
83        unsafe {
84            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(name));
85            if ptr.is_null() {
86                return None;
87            }
88            let utf8_ptr: *const std::ffi::c_char =
89                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
90            if utf8_ptr.is_null() {
91                return None;
92            }
93            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
94            Some(c_str.to_string_lossy().into_owned())
95        }
96    }
97
98    /// Get the function type.
99    ///
100    /// C++ equivalent: `FunctionType functionType() const`
101    #[inline]
102    pub fn function_type(&self) -> FunctionType {
103        unsafe { msg_send_0(self.as_ptr(), sel!(functionType)) }
104    }
105
106    /// Get the patch type (for tessellation shaders).
107    ///
108    /// C++ equivalent: `PatchType patchType() const`
109    #[inline]
110    pub fn patch_type(&self) -> PatchType {
111        unsafe { msg_send_0(self.as_ptr(), sel!(patchType)) }
112    }
113
114    /// Get the patch control point count (for tessellation shaders).
115    ///
116    /// C++ equivalent: `NS::Integer patchControlPointCount() const`
117    #[inline]
118    pub fn patch_control_point_count(&self) -> mtl_foundation::Integer {
119        unsafe { msg_send_0(self.as_ptr(), sel!(patchControlPointCount)) }
120    }
121
122    /// Get the function options.
123    ///
124    /// C++ equivalent: `FunctionOptions options() const`
125    #[inline]
126    pub fn options(&self) -> FunctionOptions {
127        unsafe { msg_send_0(self.as_ptr(), sel!(options)) }
128    }
129
130    // =========================================================================
131    // Reflection
132    // =========================================================================
133
134    /// Get the function constants dictionary (raw NSDictionary pointer).
135    ///
136    /// C++ equivalent: `NS::Dictionary* functionConstantsDictionary() const`
137    pub fn function_constants_dictionary_raw(&self) -> *mut c_void {
138        unsafe { msg_send_0(self.as_ptr(), sel!(functionConstantsDictionary)) }
139    }
140
141    /// Get the stage input attributes (raw NSArray pointer).
142    ///
143    /// C++ equivalent: `NS::Array* stageInputAttributes() const`
144    ///
145    /// Returns an array of Attribute objects.
146    pub fn stage_input_attributes_raw(&self) -> *mut c_void {
147        unsafe { msg_send_0(self.as_ptr(), sel!(stageInputAttributes)) }
148    }
149
150    /// Get the vertex attributes (raw NSArray pointer).
151    ///
152    /// C++ equivalent: `NS::Array* vertexAttributes() const`
153    ///
154    /// Returns an array of VertexAttribute objects (for vertex functions).
155    pub fn vertex_attributes_raw(&self) -> *mut c_void {
156        unsafe { msg_send_0(self.as_ptr(), sel!(vertexAttributes)) }
157    }
158
159    // =========================================================================
160    // Argument Encoder Creation
161    // =========================================================================
162
163    /// Create a new argument encoder for the buffer at the specified index.
164    ///
165    /// C++ equivalent: `ArgumentEncoder* newArgumentEncoder(NS::UInteger bufferIndex)`
166    pub fn new_argument_encoder(&self, buffer_index: UInteger) -> Option<crate::ArgumentEncoder> {
167        unsafe {
168            let ptr: *mut c_void = msg_send_1(
169                self.as_ptr(),
170                sel!(newArgumentEncoderWithBufferIndex:),
171                buffer_index,
172            );
173            crate::ArgumentEncoder::from_raw(ptr)
174        }
175    }
176
177    /// Create a new argument encoder with reflection for the buffer at the specified index.
178    ///
179    /// C++ equivalent: `ArgumentEncoder* newArgumentEncoder(NS::UInteger bufferIndex, const AutoreleasedArgument* reflection)`
180    ///
181    /// # Safety
182    ///
183    /// The reflection pointer must be a valid pointer to an Argument pointer, or null.
184    pub unsafe fn new_argument_encoder_with_reflection(
185        &self,
186        buffer_index: UInteger,
187        reflection: *mut *mut c_void,
188    ) -> Option<crate::ArgumentEncoder> {
189        unsafe {
190            let ptr: *mut c_void = msg_send_2(
191                self.as_ptr(),
192                sel!(newArgumentEncoderWithBufferIndex: reflection:),
193                buffer_index,
194                reflection,
195            );
196            crate::ArgumentEncoder::from_raw(ptr)
197        }
198    }
199}
200
201impl Clone for Function {
202    fn clone(&self) -> Self {
203        unsafe {
204            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
205        }
206        Self(self.0)
207    }
208}
209
210impl Drop for Function {
211    fn drop(&mut self) {
212        unsafe {
213            msg_send_0::<()>(self.as_ptr(), sel!(release));
214        }
215    }
216}
217
218impl Referencing for Function {
219    #[inline]
220    fn as_ptr(&self) -> *const c_void {
221        self.0.as_ptr()
222    }
223}
224
225unsafe impl Send for Function {}
226unsafe impl Sync for Function {}
227
228impl std::fmt::Debug for Function {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        f.debug_struct("Function")
231            .field("name", &self.name())
232            .field("function_type", &self.function_type())
233            .field("label", &self.label())
234            .finish()
235    }
236}