Skip to main content

mtl_gpu/
binary_archive.rs

1//! Metal binary archive for caching compiled pipeline state.
2//!
3//! Corresponds to `Metal/MTLBinaryArchive.hpp`.
4//!
5//! Binary archives store compiled pipeline functions for faster loading.
6
7use std::ffi::c_void;
8use std::ptr::NonNull;
9
10use mtl_foundation::{Referencing, UInteger};
11use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_3, sel};
12
13use crate::Device;
14
15// ============================================================================
16// BinaryArchiveError enum
17// ============================================================================
18
19/// Binary archive error codes.
20///
21/// C++ equivalent: `MTL::BinaryArchiveError`
22#[repr(transparent)]
23#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
24pub struct BinaryArchiveError(pub UInteger);
25
26impl BinaryArchiveError {
27    pub const NONE: Self = Self(0);
28    pub const INVALID_FILE: Self = Self(1);
29    pub const UNEXPECTED_ELEMENT: Self = Self(2);
30    pub const COMPILATION_FAILURE: Self = Self(3);
31    pub const INTERNAL_ERROR: Self = Self(4);
32}
33
34// ============================================================================
35// BinaryArchiveDescriptor
36// ============================================================================
37
38/// Descriptor for creating a binary archive.
39///
40/// C++ equivalent: `MTL::BinaryArchiveDescriptor`
41#[repr(transparent)]
42pub struct BinaryArchiveDescriptor(pub(crate) NonNull<c_void>);
43
44impl BinaryArchiveDescriptor {
45    /// Create a new binary archive descriptor.
46    ///
47    /// C++ equivalent: `BinaryArchiveDescriptor* alloc()->init()`
48    pub fn new() -> Option<Self> {
49        unsafe {
50            let class = mtl_sys::Class::get("MTLBinaryArchiveDescriptor")?;
51            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
52            if ptr.is_null() {
53                return None;
54            }
55            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
56            Self::from_raw(ptr)
57        }
58    }
59
60    /// Create from a raw pointer.
61    ///
62    /// # Safety
63    ///
64    /// The pointer must be a valid Metal BinaryArchiveDescriptor.
65    #[inline]
66    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
67        NonNull::new(ptr).map(Self)
68    }
69
70    /// Get the raw pointer.
71    #[inline]
72    pub fn as_raw(&self) -> *mut c_void {
73        self.0.as_ptr()
74    }
75
76    /// Get the URL to load the archive from.
77    ///
78    /// C++ equivalent: `NS::URL* url() const`
79    pub fn url(&self) -> Option<mtl_foundation::Url> {
80        unsafe {
81            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(url));
82            if ptr.is_null() {
83                return None;
84            }
85            // Retain for our reference
86            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
87            mtl_foundation::Url::from_ptr(ptr)
88        }
89    }
90
91    /// Set the URL to load the archive from.
92    ///
93    /// C++ equivalent: `void setUrl(const NS::URL* url)`
94    pub fn set_url(&self, url: &mtl_foundation::Url) {
95        unsafe {
96            let _: () = msg_send_1(self.as_ptr(), sel!(setUrl:), url.as_ptr());
97        }
98    }
99}
100
101impl Default for BinaryArchiveDescriptor {
102    fn default() -> Self {
103        Self::new().expect("failed to create BinaryArchiveDescriptor")
104    }
105}
106
107impl Clone for BinaryArchiveDescriptor {
108    fn clone(&self) -> Self {
109        unsafe {
110            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(copy));
111            Self::from_raw(ptr).expect("copy should succeed")
112        }
113    }
114}
115
116impl Drop for BinaryArchiveDescriptor {
117    fn drop(&mut self) {
118        unsafe {
119            msg_send_0::<()>(self.as_ptr(), sel!(release));
120        }
121    }
122}
123
124impl Referencing for BinaryArchiveDescriptor {
125    #[inline]
126    fn as_ptr(&self) -> *const c_void {
127        self.0.as_ptr()
128    }
129}
130
131unsafe impl Send for BinaryArchiveDescriptor {}
132unsafe impl Sync for BinaryArchiveDescriptor {}
133
134// ============================================================================
135// BinaryArchive
136// ============================================================================
137
138/// A binary archive for caching compiled pipeline state.
139///
140/// C++ equivalent: `MTL::BinaryArchive`
141#[repr(transparent)]
142pub struct BinaryArchive(pub(crate) NonNull<c_void>);
143
144impl BinaryArchive {
145    /// Create from a raw pointer.
146    ///
147    /// # Safety
148    ///
149    /// The pointer must be a valid Metal BinaryArchive.
150    #[inline]
151    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
152        NonNull::new(ptr).map(Self)
153    }
154
155    /// Get the raw pointer.
156    #[inline]
157    pub fn as_raw(&self) -> *mut c_void {
158        self.0.as_ptr()
159    }
160
161    /// Get the device that created this archive.
162    ///
163    /// C++ equivalent: `Device* device() const`
164    pub fn device(&self) -> Device {
165        unsafe {
166            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
167            // Retain for our reference
168            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
169            Device::from_raw(ptr).expect("device should be valid")
170        }
171    }
172
173    /// Get the label.
174    ///
175    /// C++ equivalent: `NS::String* label() const`
176    pub fn label(&self) -> Option<String> {
177        unsafe {
178            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
179            if ptr.is_null() {
180                return None;
181            }
182            let utf8_ptr: *const std::ffi::c_char =
183                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
184            if utf8_ptr.is_null() {
185                return None;
186            }
187            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
188            Some(c_str.to_string_lossy().into_owned())
189        }
190    }
191
192    /// Set the label.
193    ///
194    /// C++ equivalent: `void setLabel(const NS::String* label)`
195    pub fn set_label(&self, label: &str) {
196        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
197            unsafe {
198                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
199            }
200        }
201    }
202
203    /// Add compute pipeline functions to the archive.
204    ///
205    /// C++ equivalent: `bool addComputePipelineFunctions(const MTL::ComputePipelineDescriptor*, NS::Error**)`
206    pub fn add_compute_pipeline_functions(
207        &self,
208        descriptor: &crate::ComputePipelineDescriptor,
209    ) -> Result<(), mtl_foundation::Error> {
210        unsafe {
211            let mut error: *mut c_void = std::ptr::null_mut();
212            let result: bool = msg_send_2(
213                self.as_ptr(),
214                sel!(addComputePipelineFunctionsWithDescriptor:error:),
215                descriptor.as_ptr(),
216                &mut error as *mut _,
217            );
218            if !result {
219                if !error.is_null() {
220                    return Err(
221                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
222                    );
223                }
224                return Err(mtl_foundation::Error::error(
225                    std::ptr::null_mut(),
226                    -1,
227                    std::ptr::null_mut(),
228                )
229                .expect("failed to create error"));
230            }
231            Ok(())
232        }
233    }
234
235    /// Add render pipeline functions to the archive.
236    ///
237    /// C++ equivalent: `bool addRenderPipelineFunctions(const MTL::RenderPipelineDescriptor*, NS::Error**)`
238    pub fn add_render_pipeline_functions(
239        &self,
240        descriptor: &crate::RenderPipelineDescriptor,
241    ) -> Result<(), mtl_foundation::Error> {
242        unsafe {
243            let mut error: *mut c_void = std::ptr::null_mut();
244            let result: bool = msg_send_2(
245                self.as_ptr(),
246                sel!(addRenderPipelineFunctionsWithDescriptor:error:),
247                descriptor.as_ptr(),
248                &mut error as *mut _,
249            );
250            if !result {
251                if !error.is_null() {
252                    return Err(
253                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
254                    );
255                }
256                return Err(mtl_foundation::Error::error(
257                    std::ptr::null_mut(),
258                    -1,
259                    std::ptr::null_mut(),
260                )
261                .expect("failed to create error"));
262            }
263            Ok(())
264        }
265    }
266
267    /// Add a function to the archive.
268    ///
269    /// C++ equivalent: `bool addFunction(const MTL::FunctionDescriptor*, const MTL::Library*, NS::Error**)`
270    pub fn add_function(
271        &self,
272        descriptor: *const c_void, // FunctionDescriptor not yet implemented
273        library: &crate::Library,
274    ) -> Result<(), mtl_foundation::Error> {
275        unsafe {
276            let mut error: *mut c_void = std::ptr::null_mut();
277            let result: bool = msg_send_3(
278                self.as_ptr(),
279                sel!(addFunctionWithDescriptor:library:error:),
280                descriptor,
281                library.as_ptr(),
282                &mut error as *mut _,
283            );
284            if !result {
285                if !error.is_null() {
286                    return Err(
287                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
288                    );
289                }
290                return Err(mtl_foundation::Error::error(
291                    std::ptr::null_mut(),
292                    -1,
293                    std::ptr::null_mut(),
294                )
295                .expect("failed to create error"));
296            }
297            Ok(())
298        }
299    }
300
301    /// Add a stitched library to the archive.
302    ///
303    /// C++ equivalent: `bool addLibrary(const MTL::StitchedLibraryDescriptor*, NS::Error**)`
304    pub fn add_library_ptr(
305        &self,
306        descriptor: *const c_void,
307    ) -> Result<(), mtl_foundation::Error> {
308        unsafe {
309            let mut error: *mut c_void = std::ptr::null_mut();
310            let result: bool = msg_send_2(
311                self.as_ptr(),
312                sel!(addLibraryWithDescriptor:error:),
313                descriptor,
314                &mut error as *mut _,
315            );
316            if !result {
317                if !error.is_null() {
318                    return Err(
319                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
320                    );
321                }
322                return Err(mtl_foundation::Error::error(
323                    std::ptr::null_mut(),
324                    -1,
325                    std::ptr::null_mut(),
326                )
327                .expect("failed to create error"));
328            }
329            Ok(())
330        }
331    }
332
333    /// Add mesh render pipeline functions to the archive.
334    ///
335    /// C++ equivalent: `bool addMeshRenderPipelineFunctions(const MTL::MeshRenderPipelineDescriptor*, NS::Error**)`
336    pub fn add_mesh_render_pipeline_functions_ptr(
337        &self,
338        descriptor: *const c_void,
339    ) -> Result<(), mtl_foundation::Error> {
340        unsafe {
341            let mut error: *mut c_void = std::ptr::null_mut();
342            let result: bool = msg_send_2(
343                self.as_ptr(),
344                sel!(addMeshRenderPipelineFunctionsWithDescriptor:error:),
345                descriptor,
346                &mut error as *mut _,
347            );
348            if !result {
349                if !error.is_null() {
350                    return Err(
351                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
352                    );
353                }
354                return Err(mtl_foundation::Error::error(
355                    std::ptr::null_mut(),
356                    -1,
357                    std::ptr::null_mut(),
358                )
359                .expect("failed to create error"));
360            }
361            Ok(())
362        }
363    }
364
365    /// Add tile render pipeline functions to the archive.
366    ///
367    /// C++ equivalent: `bool addTileRenderPipelineFunctions(const MTL::TileRenderPipelineDescriptor*, NS::Error**)`
368    pub fn add_tile_render_pipeline_functions_ptr(
369        &self,
370        descriptor: *const c_void,
371    ) -> Result<(), mtl_foundation::Error> {
372        unsafe {
373            let mut error: *mut c_void = std::ptr::null_mut();
374            let result: bool = msg_send_2(
375                self.as_ptr(),
376                sel!(addTileRenderPipelineFunctionsWithDescriptor:error:),
377                descriptor,
378                &mut error as *mut _,
379            );
380            if !result {
381                if !error.is_null() {
382                    return Err(
383                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
384                    );
385                }
386                return Err(mtl_foundation::Error::error(
387                    std::ptr::null_mut(),
388                    -1,
389                    std::ptr::null_mut(),
390                )
391                .expect("failed to create error"));
392            }
393            Ok(())
394        }
395    }
396
397    /// Serialize the archive to a URL.
398    ///
399    /// C++ equivalent: `bool serializeToURL(const NS::URL*, NS::Error**)`
400    pub fn serialize_to_url(
401        &self,
402        url: &mtl_foundation::Url,
403    ) -> Result<(), mtl_foundation::Error> {
404        unsafe {
405            let mut error: *mut c_void = std::ptr::null_mut();
406            let result: bool = msg_send_2(
407                self.as_ptr(),
408                sel!(serializeToURL:error:),
409                url.as_ptr(),
410                &mut error as *mut _,
411            );
412            if !result {
413                if !error.is_null() {
414                    return Err(
415                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
416                    );
417                }
418                return Err(mtl_foundation::Error::error(
419                    std::ptr::null_mut(),
420                    -1,
421                    std::ptr::null_mut(),
422                )
423                .expect("failed to create error"));
424            }
425            Ok(())
426        }
427    }
428}
429
430impl Clone for BinaryArchive {
431    fn clone(&self) -> Self {
432        unsafe {
433            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
434        }
435        Self(self.0)
436    }
437}
438
439impl Drop for BinaryArchive {
440    fn drop(&mut self) {
441        unsafe {
442            msg_send_0::<()>(self.as_ptr(), sel!(release));
443        }
444    }
445}
446
447impl Referencing for BinaryArchive {
448    #[inline]
449    fn as_ptr(&self) -> *const c_void {
450        self.0.as_ptr()
451    }
452}
453
454unsafe impl Send for BinaryArchive {}
455unsafe impl Sync for BinaryArchive {}
456
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    #[test]
    fn test_binary_archive_descriptor_creation() {
        // Needs the Metal runtime so the Objective-C class can be resolved.
        assert!(BinaryArchiveDescriptor::new().is_some());
    }

    #[test]
    fn test_binary_archive_descriptor_size() {
        // `#[repr(transparent)]` over `NonNull` keeps the wrapper pointer-sized.
        assert_eq!(size_of::<BinaryArchiveDescriptor>(), size_of::<*mut c_void>());
    }

    #[test]
    fn test_binary_archive_size() {
        assert_eq!(size_of::<BinaryArchive>(), size_of::<*mut c_void>());
    }

    #[test]
    fn test_binary_archive_error_values() {
        let cases = [
            (BinaryArchiveError::NONE, 0),
            (BinaryArchiveError::INVALID_FILE, 1),
            (BinaryArchiveError::UNEXPECTED_ELEMENT, 2),
            (BinaryArchiveError::COMPILATION_FAILURE, 3),
            (BinaryArchiveError::INTERNAL_ERROR, 4),
        ];
        for (code, expected) in cases {
            assert_eq!(code.0, expected);
        }
    }
}