Skip to main content

mtl_gpu/library/
library.rs

1//! A collection of compiled shader functions.
2
3use std::ffi::c_void;
4use std::ptr::NonNull;
5
6use mtl_foundation::{Referencing, UInteger};
7use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
8
9use crate::enums::LibraryType;
10
11use super::{
12    Function, FunctionConstantValues, FunctionDescriptor, FunctionReflection,
13    IntersectionFunctionDescriptor,
14};
15
/// A collection of compiled shader functions.
///
/// A thin, pointer-sized handle (`#[repr(transparent)]` over a non-null
/// pointer) to the underlying Metal library object. Each handle owns one
/// reference: `Clone` sends `retain` and `Drop` sends `release`.
///
/// C++ equivalent: `MTL::Library`
#[repr(transparent)]
pub struct Library(pub(crate) NonNull<c_void>);
21
22impl Library {
23    /// Create a Library from a raw pointer.
24    ///
25    /// # Safety
26    ///
27    /// The pointer must be a valid Metal library object.
28    #[inline]
29    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
30        NonNull::new(ptr).map(Self)
31    }
32
33    /// Get the raw pointer to the library.
34    #[inline]
35    pub fn as_raw(&self) -> *mut c_void {
36        self.0.as_ptr()
37    }
38
39    // =========================================================================
40    // Properties
41    // =========================================================================
42
43    /// Get the label for this library.
44    ///
45    /// C++ equivalent: `NS::String* label() const`
46    pub fn label(&self) -> Option<String> {
47        unsafe {
48            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
49            if ptr.is_null() {
50                return None;
51            }
52            let utf8_ptr: *const std::ffi::c_char =
53                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
54            if utf8_ptr.is_null() {
55                return None;
56            }
57            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
58            Some(c_str.to_string_lossy().into_owned())
59        }
60    }
61
62    /// Set the label for this library.
63    ///
64    /// C++ equivalent: `void setLabel(const NS::String*)`
65    pub fn set_label(&self, label: &str) {
66        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
67            unsafe {
68                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
69            }
70        }
71    }
72
73    /// Get the device that created this library.
74    ///
75    /// C++ equivalent: `Device* device() const`
76    pub fn device(&self) -> crate::Device {
77        unsafe {
78            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
79            if ptr.is_null() {
80                panic!("library has no device");
81            }
82            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
83            crate::Device::from_raw(ptr).expect("library has no device")
84        }
85    }
86
87    /// Get the library type.
88    ///
89    /// C++ equivalent: `LibraryType type() const`
90    #[inline]
91    pub fn library_type(&self) -> LibraryType {
92        unsafe { msg_send_0(self.as_ptr(), sel!(type)) }
93    }
94
95    /// Get the install name (for dynamic libraries).
96    ///
97    /// C++ equivalent: `NS::String* installName() const`
98    pub fn install_name(&self) -> Option<String> {
99        unsafe {
100            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(installName));
101            if ptr.is_null() {
102                return None;
103            }
104            let utf8_ptr: *const std::ffi::c_char =
105                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
106            if utf8_ptr.is_null() {
107                return None;
108            }
109            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
110            Some(c_str.to_string_lossy().into_owned())
111        }
112    }
113
114    // =========================================================================
115    // Function Retrieval
116    // =========================================================================
117
118    /// Get a function by name.
119    ///
120    /// C++ equivalent: `Function* newFunction(const NS::String* name)`
121    pub fn new_function_with_name(&self, name: &str) -> Option<Function> {
122        let ns_name = mtl_foundation::String::from_str(name)?;
123        unsafe {
124            let ptr: *mut c_void =
125                msg_send_1(self.as_ptr(), sel!(newFunctionWithName:), ns_name.as_ptr());
126            Function::from_raw(ptr)
127        }
128    }
129
130    /// Get a function by name with constant values.
131    ///
132    /// C++ equivalent: `Function* newFunction(const NS::String*, const FunctionConstantValues*, NS::Error**)`
133    ///
134    /// # Safety
135    ///
136    /// The constant_values pointer must be valid if not null.
137    pub unsafe fn new_function_with_name_and_constants(
138        &self,
139        name: &str,
140        constant_values: *const c_void,
141    ) -> Result<Function, mtl_foundation::Error> {
142        let ns_name = mtl_foundation::String::from_str(name).ok_or_else(|| {
143            mtl_foundation::Error::error(std::ptr::null_mut(), 0, std::ptr::null_mut())
144                .expect("failed to create error for invalid string")
145        })?;
146
147        let mut error: *mut c_void = std::ptr::null_mut();
148        unsafe {
149            let ptr: *mut c_void = mtl_sys::msg_send_3(
150                self.as_ptr(),
151                sel!(newFunctionWithName: constantValues: error:),
152                ns_name.as_ptr(),
153                constant_values,
154                &mut error as *mut _,
155            );
156
157            if ptr.is_null() {
158                if !error.is_null() {
159                    let _: *mut c_void = msg_send_0(error, sel!(retain));
160                    return Err(mtl_foundation::Error::from_ptr(error).unwrap());
161                }
162                return Err(mtl_foundation::Error::error(
163                    std::ptr::null_mut(),
164                    -1,
165                    std::ptr::null_mut(),
166                )
167                .expect("failed to create error object"));
168            }
169
170            Ok(Function::from_raw(ptr).unwrap())
171        }
172    }
173
174    /// Get all function names in the library.
175    ///
176    /// C++ equivalent: `NS::Array* functionNames() const`
177    pub fn function_names(&self) -> Vec<String> {
178        unsafe {
179            let array: *mut c_void = msg_send_0(self.as_ptr(), sel!(functionNames));
180            if array.is_null() {
181                return Vec::new();
182            }
183
184            let count: UInteger = msg_send_0(array, sel!(count));
185            let mut names = Vec::with_capacity(count as usize);
186
187            for i in 0..count {
188                let obj: *mut c_void = msg_send_1(array, sel!(objectAtIndex:), i);
189                if !obj.is_null() {
190                    let utf8_ptr: *const std::ffi::c_char =
191                        mtl_sys::msg_send_0(obj as *const c_void, sel!(UTF8String));
192                    if !utf8_ptr.is_null() {
193                        let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
194                        names.push(c_str.to_string_lossy().into_owned());
195                    }
196                }
197            }
198
199            names
200        }
201    }
202
203    /// Create a function with a descriptor.
204    ///
205    /// C++ equivalent: `Function* newFunction(const FunctionDescriptor*, NS::Error**)`
206    pub fn new_function_with_descriptor(
207        &self,
208        descriptor: &FunctionDescriptor,
209    ) -> Result<Function, mtl_foundation::Error> {
210        unsafe {
211            let mut error: *mut c_void = std::ptr::null_mut();
212            let ptr: *mut c_void = msg_send_2(
213                self.as_ptr(),
214                sel!(newFunctionWithDescriptor: error:),
215                descriptor.as_ptr(),
216                &mut error as *mut _,
217            );
218
219            if ptr.is_null() {
220                if !error.is_null() {
221                    return Err(
222                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
223                    );
224                }
225                return Err(mtl_foundation::Error::error(
226                    std::ptr::null_mut(),
227                    -1,
228                    std::ptr::null_mut(),
229                )
230                .expect("failed to create error object"));
231            }
232
233            Ok(Function::from_raw(ptr).expect("failed to create function"))
234        }
235    }
236
237    /// Create an intersection function with a descriptor.
238    ///
239    /// C++ equivalent: `Function* newIntersectionFunction(const IntersectionFunctionDescriptor*, NS::Error**)`
240    pub fn new_intersection_function(
241        &self,
242        descriptor: &IntersectionFunctionDescriptor,
243    ) -> Result<Function, mtl_foundation::Error> {
244        unsafe {
245            let mut error: *mut c_void = std::ptr::null_mut();
246            let ptr: *mut c_void = msg_send_2(
247                self.as_ptr(),
248                sel!(newIntersectionFunctionWithDescriptor: error:),
249                descriptor.as_ptr(),
250                &mut error as *mut _,
251            );
252
253            if ptr.is_null() {
254                if !error.is_null() {
255                    return Err(
256                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
257                    );
258                }
259                return Err(mtl_foundation::Error::error(
260                    std::ptr::null_mut(),
261                    -1,
262                    std::ptr::null_mut(),
263                )
264                .expect("failed to create error object"));
265            }
266
267            Ok(Function::from_raw(ptr).expect("failed to create intersection function"))
268        }
269    }
270
271    /// Get reflection information for a function by name.
272    ///
273    /// C++ equivalent: `FunctionReflection* reflectionForFunction(const NS::String*)`
274    pub fn reflection_for_function(&self, name: &str) -> Option<FunctionReflection> {
275        let ns_name = mtl_foundation::String::from_str(name)?;
276        unsafe {
277            let ptr: *mut c_void = msg_send_1(
278                self.as_ptr(),
279                sel!(reflectionForFunctionWithName:),
280                ns_name.as_ptr(),
281            );
282            if ptr.is_null() {
283                return None;
284            }
285            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
286            FunctionReflection::from_raw(ptr)
287        }
288    }
289
290    // =========================================================================
291    // Async Function Creation
292    // =========================================================================
293
294    /// Create a function with constant values asynchronously.
295    ///
296    /// C++ equivalent: `void newFunction(const NS::String*, const FunctionConstantValues*, void (^)(Function*, Error*))`
297    pub fn new_function_with_name_and_constants_async<F>(
298        &self,
299        name: &str,
300        constant_values: &FunctionConstantValues,
301        completion_handler: F,
302    ) where
303        F: Fn(Option<Function>, Option<mtl_foundation::Error>) + Send + 'static,
304    {
305        let Some(ns_name) = mtl_foundation::String::from_str(name) else {
306            completion_handler(
307                None,
308                mtl_foundation::Error::error(std::ptr::null_mut(), -1, std::ptr::null_mut()),
309            );
310            return;
311        };
312
313        let block =
314            mtl_sys::TwoArgBlock::from_fn(move |fn_ptr: *mut c_void, err_ptr: *mut c_void| {
315                let function = if fn_ptr.is_null() {
316                    None
317                } else {
318                    unsafe { Function::from_raw(fn_ptr) }
319                };
320
321                let error = if err_ptr.is_null() {
322                    None
323                } else {
324                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
325                };
326
327                completion_handler(function, error);
328            });
329
330        unsafe {
331            mtl_sys::msg_send_3::<(), *const c_void, *const c_void, *const c_void>(
332                self.as_ptr(),
333                sel!(newFunctionWithName:constantValues:completionHandler:),
334                ns_name.as_ptr(),
335                constant_values.as_ptr(),
336                block.as_ptr(),
337            );
338        }
339
340        std::mem::forget(block);
341    }
342
343    /// Create a function with a descriptor asynchronously.
344    ///
345    /// C++ equivalent: `void newFunction(const FunctionDescriptor*, void (^)(Function*, Error*))`
346    pub fn new_function_with_descriptor_async<F>(
347        &self,
348        descriptor: &FunctionDescriptor,
349        completion_handler: F,
350    ) where
351        F: Fn(Option<Function>, Option<mtl_foundation::Error>) + Send + 'static,
352    {
353        let block =
354            mtl_sys::TwoArgBlock::from_fn(move |fn_ptr: *mut c_void, err_ptr: *mut c_void| {
355                let function = if fn_ptr.is_null() {
356                    None
357                } else {
358                    unsafe { Function::from_raw(fn_ptr) }
359                };
360
361                let error = if err_ptr.is_null() {
362                    None
363                } else {
364                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
365                };
366
367                completion_handler(function, error);
368            });
369
370        unsafe {
371            mtl_sys::msg_send_2::<(), *const c_void, *const c_void>(
372                self.as_ptr(),
373                sel!(newFunctionWithDescriptor:completionHandler:),
374                descriptor.as_ptr(),
375                block.as_ptr(),
376            );
377        }
378
379        std::mem::forget(block);
380    }
381
382    /// Create an intersection function with a descriptor asynchronously.
383    ///
384    /// C++ equivalent: `void newIntersectionFunction(const IntersectionFunctionDescriptor*, void (^)(Function*, Error*))`
385    pub fn new_intersection_function_async<F>(
386        &self,
387        descriptor: &IntersectionFunctionDescriptor,
388        completion_handler: F,
389    ) where
390        F: Fn(Option<Function>, Option<mtl_foundation::Error>) + Send + 'static,
391    {
392        let block =
393            mtl_sys::TwoArgBlock::from_fn(move |fn_ptr: *mut c_void, err_ptr: *mut c_void| {
394                let function = if fn_ptr.is_null() {
395                    None
396                } else {
397                    unsafe { Function::from_raw(fn_ptr) }
398                };
399
400                let error = if err_ptr.is_null() {
401                    None
402                } else {
403                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
404                };
405
406                completion_handler(function, error);
407            });
408
409        unsafe {
410            mtl_sys::msg_send_2::<(), *const c_void, *const c_void>(
411                self.as_ptr(),
412                sel!(newIntersectionFunctionWithDescriptor:completionHandler:),
413                descriptor.as_ptr(),
414                block.as_ptr(),
415            );
416        }
417
418        std::mem::forget(block);
419    }
420}
421
422impl Clone for Library {
423    fn clone(&self) -> Self {
424        unsafe {
425            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
426        }
427        Self(self.0)
428    }
429}
430
431impl Drop for Library {
432    fn drop(&mut self) {
433        unsafe {
434            msg_send_0::<()>(self.as_ptr(), sel!(release));
435        }
436    }
437}
438
439impl Referencing for Library {
440    #[inline]
441    fn as_ptr(&self) -> *const c_void {
442        self.0.as_ptr()
443    }
444}
445
// SAFETY: `Library` is just a pointer to a reference-counted Objective-C object.
// Presumably retain/release are thread-safe in the Objective-C runtime.
// NOTE(review): `Sync` additionally assumes the library's methods may be called
// concurrently from multiple threads — confirm against Apple's Metal threading
// guidance.
unsafe impl Send for Library {}
unsafe impl Sync for Library {}
448
449impl std::fmt::Debug for Library {
450    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
451        f.debug_struct("Library")
452            .field("label", &self.label())
453            .field("library_type", &self.library_type())
454            .finish()
455    }
456}