Skip to main content

mtl_gpu/mtl4/
archive.rs

1//! MTL4 Archive implementation.
2//!
3//! Corresponds to `Metal/MTL4Archive.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_3, sel};
10
11use super::{
12    BinaryFunction, BinaryFunctionDescriptor, ComputePipelineDescriptor, PipelineDescriptor,
13    PipelineStageDynamicLinkingDescriptor, RenderPipelineDynamicLinkingDescriptor,
14};
15use crate::{ComputePipelineState, RenderPipelineState};
16
17/// Helper to create a generic error.
18fn generic_error() -> mtl_foundation::Error {
19    mtl_foundation::Error::error(std::ptr::null_mut(), -1, std::ptr::null_mut())
20        .expect("failed to create error object")
21}
22
23// ============================================================
24// Archive
25// ============================================================
26
/// Handle to a Metal 4 pipeline archive.
///
/// C++ equivalent: `MTL4::Archive`
///
/// An archive serves pre-compiled pipelines and binary functions, so pipeline
/// creation can skip runtime compilation.
///
/// The handle is one non-null object pointer laid out transparently, so it is
/// exactly pointer-sized and `Option<Archive>` is pointer-sized as well.
#[repr(transparent)]
pub struct Archive(NonNull<c_void>);
35
36impl Archive {
37    /// Create an Archive from a raw pointer.
38    #[inline]
39    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
40        NonNull::new(ptr).map(Self)
41    }
42
43    /// Get the raw pointer.
44    #[inline]
45    pub fn as_raw(&self) -> *mut c_void {
46        self.0.as_ptr()
47    }
48
49    /// Get the label.
50    ///
51    /// C++ equivalent: `NS::String* label() const`
52    pub fn label(&self) -> Option<String> {
53        unsafe {
54            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
55            if ns_string.is_null() {
56                return None;
57            }
58            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
59            if c_str.is_null() {
60                return None;
61            }
62            Some(
63                std::ffi::CStr::from_ptr(c_str)
64                    .to_string_lossy()
65                    .into_owned(),
66            )
67        }
68    }
69
70    /// Set the label.
71    ///
72    /// C++ equivalent: `void setLabel(const NS::String*)`
73    pub fn set_label(&self, label: &str) {
74        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
75            unsafe {
76                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
77            }
78        }
79    }
80
81    // ========== Binary Function Creation ==========
82
83    /// Create a new binary function from the archive.
84    ///
85    /// C++ equivalent: `BinaryFunction* newBinaryFunction(const MTL4::BinaryFunctionDescriptor*, NS::Error**)`
86    pub fn new_binary_function(
87        &self,
88        descriptor: &BinaryFunctionDescriptor,
89    ) -> Result<BinaryFunction, mtl_foundation::Error> {
90        unsafe {
91            let mut error: *mut c_void = std::ptr::null_mut();
92            let ptr: *mut c_void = msg_send_2(
93                self.as_ptr(),
94                sel!(newBinaryFunctionWithDescriptor:error:),
95                descriptor.as_ptr(),
96                &mut error as *mut _,
97            );
98            if !error.is_null() {
99                if let Some(err) = mtl_foundation::Error::from_ptr(error) {
100                    return Err(err);
101                }
102            }
103            BinaryFunction::from_raw(ptr).ok_or_else(generic_error)
104        }
105    }
106
107    // ========== Compute Pipeline Creation ==========
108
109    /// Create a new compute pipeline state from the archive.
110    ///
111    /// C++ equivalent: `MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor*, NS::Error**)`
112    pub fn new_compute_pipeline_state(
113        &self,
114        descriptor: &ComputePipelineDescriptor,
115    ) -> Result<ComputePipelineState, mtl_foundation::Error> {
116        unsafe {
117            let mut error: *mut c_void = std::ptr::null_mut();
118            let ptr: *mut c_void = msg_send_2(
119                self.as_ptr(),
120                sel!(newComputePipelineStateWithDescriptor:error:),
121                descriptor.as_ptr(),
122                &mut error as *mut _,
123            );
124            if !error.is_null() {
125                if let Some(err) = mtl_foundation::Error::from_ptr(error) {
126                    return Err(err);
127                }
128            }
129            ComputePipelineState::from_raw(ptr).ok_or_else(generic_error)
130        }
131    }
132
133    /// Create a new compute pipeline state with dynamic linking from the archive.
134    ///
135    /// C++ equivalent: `MTL::ComputePipelineState* newComputePipelineState(..., dynamicLinkingDescriptor, ...)`
136    pub fn new_compute_pipeline_state_with_dynamic_linking(
137        &self,
138        descriptor: &ComputePipelineDescriptor,
139        dynamic_linking: &PipelineStageDynamicLinkingDescriptor,
140    ) -> Result<ComputePipelineState, mtl_foundation::Error> {
141        unsafe {
142            let mut error: *mut c_void = std::ptr::null_mut();
143            let ptr: *mut c_void = msg_send_3(
144                self.as_ptr(),
145                sel!(newComputePipelineStateWithDescriptor:dynamicLinkingDescriptor:error:),
146                descriptor.as_ptr(),
147                dynamic_linking.as_ptr(),
148                &mut error as *mut _,
149            );
150            if !error.is_null() {
151                if let Some(err) = mtl_foundation::Error::from_ptr(error) {
152                    return Err(err);
153                }
154            }
155            ComputePipelineState::from_raw(ptr).ok_or_else(generic_error)
156        }
157    }
158
159    // ========== Render Pipeline Creation ==========
160
161    /// Create a new render pipeline state from the archive.
162    ///
163    /// C++ equivalent: `MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor*, NS::Error**)`
164    pub fn new_render_pipeline_state(
165        &self,
166        descriptor: &PipelineDescriptor,
167    ) -> Result<RenderPipelineState, mtl_foundation::Error> {
168        unsafe {
169            let mut error: *mut c_void = std::ptr::null_mut();
170            let ptr: *mut c_void = msg_send_2(
171                self.as_ptr(),
172                sel!(newRenderPipelineStateWithDescriptor:error:),
173                descriptor.as_ptr(),
174                &mut error as *mut _,
175            );
176            if !error.is_null() {
177                if let Some(err) = mtl_foundation::Error::from_ptr(error) {
178                    return Err(err);
179                }
180            }
181            RenderPipelineState::from_raw(ptr).ok_or_else(generic_error)
182        }
183    }
184
185    /// Create a new render pipeline state with dynamic linking from the archive.
186    ///
187    /// C++ equivalent: `MTL::RenderPipelineState* newRenderPipelineState(..., dynamicLinkingDescriptor, ...)`
188    pub fn new_render_pipeline_state_with_dynamic_linking(
189        &self,
190        descriptor: &PipelineDescriptor,
191        dynamic_linking: &RenderPipelineDynamicLinkingDescriptor,
192    ) -> Result<RenderPipelineState, mtl_foundation::Error> {
193        unsafe {
194            let mut error: *mut c_void = std::ptr::null_mut();
195            let ptr: *mut c_void = msg_send_3(
196                self.as_ptr(),
197                sel!(newRenderPipelineStateWithDescriptor:dynamicLinkingDescriptor:error:),
198                descriptor.as_ptr(),
199                dynamic_linking.as_ptr(),
200                &mut error as *mut _,
201            );
202            if !error.is_null() {
203                if let Some(err) = mtl_foundation::Error::from_ptr(error) {
204                    return Err(err);
205                }
206            }
207            RenderPipelineState::from_raw(ptr).ok_or_else(generic_error)
208        }
209    }
210}
211
212impl Clone for Archive {
213    fn clone(&self) -> Self {
214        unsafe {
215            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
216        }
217        Self(self.0)
218    }
219}
220
221impl Drop for Archive {
222    fn drop(&mut self) {
223        unsafe {
224            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
225        }
226    }
227}
228
229impl Referencing for Archive {
230    #[inline]
231    fn as_ptr(&self) -> *const c_void {
232        self.0.as_ptr()
233    }
234}
235
236unsafe impl Send for Archive {}
237unsafe impl Sync for Archive {}
238
239impl std::fmt::Debug for Archive {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        f.debug_struct("Archive")
242            .field("label", &self.label())
243            .finish()
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// The handle must stay a bare pointer so it can cross the FFI boundary as-is.
    #[test]
    fn test_archive_size() {
        let archive_size = std::mem::size_of::<Archive>();
        let pointer_size = std::mem::size_of::<*mut c_void>();
        assert_eq!(archive_size, pointer_size);
    }
}