Skip to main content

mtl_gpu/mtl4/
stitched_function_descriptor.rs

//! MTL4 StitchedFunctionDescriptor implementation.
//!
//! Corresponds to `Metal/MTL4StitchedFunctionDescriptor.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use crate::FunctionStitchingGraph;
12
// ============================================================
// StitchedFunctionDescriptor
// ============================================================
16
/// Owning handle to an `MTL4StitchedFunctionDescriptor` Objective-C object.
///
/// C++ equivalent: `MTL4::StitchedFunctionDescriptor`
///
/// StitchedFunctionDescriptor extends FunctionDescriptor to support
/// function stitching with a graph-based composition model.
///
/// The handle is nothing more than a non-null object pointer
/// (`#[repr(transparent)]` over `NonNull`), so both the handle and
/// `Option<StitchedFunctionDescriptor>` are exactly pointer-sized.
/// Each handle holds one strong reference, which its `Drop` impl releases.
#[repr(transparent)]
pub struct StitchedFunctionDescriptor(NonNull<c_void>);
25
26impl StitchedFunctionDescriptor {
27    /// Create a StitchedFunctionDescriptor from a raw pointer.
28    #[inline]
29    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
30        NonNull::new(ptr).map(Self)
31    }
32
33    /// Get the raw pointer.
34    #[inline]
35    pub fn as_raw(&self) -> *mut c_void {
36        self.0.as_ptr()
37    }
38
39    /// Create a new stitched function descriptor.
40    pub fn new() -> Option<Self> {
41        unsafe {
42            let class = mtl_sys::Class::get("MTL4StitchedFunctionDescriptor")?;
43            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
44            if ptr.is_null() {
45                return None;
46            }
47            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
48            Self::from_raw(ptr)
49        }
50    }
51
52    /// Get the function graph.
53    ///
54    /// C++ equivalent: `MTL::FunctionStitchingGraph* functionGraph() const`
55    pub fn function_graph(&self) -> Option<FunctionStitchingGraph> {
56        unsafe {
57            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(functionGraph));
58            FunctionStitchingGraph::from_raw(ptr)
59        }
60    }
61
62    /// Set the function graph.
63    ///
64    /// C++ equivalent: `void setFunctionGraph(const MTL::FunctionStitchingGraph*)`
65    pub fn set_function_graph(&self, graph: &FunctionStitchingGraph) {
66        unsafe {
67            let _: () = msg_send_1(self.as_ptr(), sel!(setFunctionGraph:), graph.as_ptr());
68        }
69    }
70
71    /// Get the function descriptors array (as raw pointer to NSArray).
72    ///
73    /// C++ equivalent: `NS::Array* functionDescriptors() const`
74    pub fn function_descriptors_raw(&self) -> *mut c_void {
75        unsafe { msg_send_0(self.as_ptr(), sel!(functionDescriptors)) }
76    }
77
78    /// Set the function descriptors array (from raw pointer to NSArray).
79    ///
80    /// C++ equivalent: `void setFunctionDescriptors(const NS::Array*)`
81    pub fn set_function_descriptors_raw(&self, descriptors: *const c_void) {
82        unsafe {
83            let _: () = msg_send_1(self.as_ptr(), sel!(setFunctionDescriptors:), descriptors);
84        }
85    }
86}
87
88impl Default for StitchedFunctionDescriptor {
89    fn default() -> Self {
90        Self::new().expect("Failed to create MTL4StitchedFunctionDescriptor")
91    }
92}
93
94impl Clone for StitchedFunctionDescriptor {
95    fn clone(&self) -> Self {
96        unsafe {
97            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
98        }
99        Self(self.0)
100    }
101}
102
103impl Drop for StitchedFunctionDescriptor {
104    fn drop(&mut self) {
105        unsafe {
106            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
107        }
108    }
109}
110
111impl Referencing for StitchedFunctionDescriptor {
112    #[inline]
113    fn as_ptr(&self) -> *const c_void {
114        self.0.as_ptr()
115    }
116}
117
// SAFETY(review): these impls assert that the wrapped Objective-C descriptor
// may be moved to and shared between threads. Nothing in this file
// synchronises access. The setters on this type take `&self`, so `Sync`
// permits concurrent mutation through shared references.
// NOTE(review): confirm that Metal's MTL4 descriptor objects tolerate this,
// or document that callers must serialise mutation.
unsafe impl Send for StitchedFunctionDescriptor {}
unsafe impl Sync for StitchedFunctionDescriptor {}
120
121impl std::fmt::Debug for StitchedFunctionDescriptor {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("StitchedFunctionDescriptor").finish()
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// The `#[repr(transparent)]` wrapper must stay exactly pointer-sized so
    /// it can stand in for a raw object pointer.
    #[test]
    fn test_stitched_function_descriptor_size() {
        let wrapper_bytes = std::mem::size_of::<StitchedFunctionDescriptor>();
        let pointer_bytes = std::mem::size_of::<*mut c_void>();
        assert_eq!(wrapper_bytes, pointer_bytes);
    }
}
138}