Skip to main content

mtl_gpu/device/
library.rs

//! Device library creation methods.
//!
//! Corresponds to the library creation methods in `Metal/MTLDevice.hpp`.

5use std::ffi::c_void;
6
7use mtl_foundation::Referencing;
8use mtl_sys::{msg_send_0, msg_send_2, msg_send_3, sel};
9
10use super::Device;
11use crate::library::{CompileOptions, Library};
12
13impl Device {
14    // =========================================================================
15    // Library Creation
16    // =========================================================================
17
18    /// Create the default library from the app's main bundle.
19    ///
20    /// C++ equivalent: `Library* newDefaultLibrary()`
21    pub fn new_default_library(&self) -> Option<Library> {
22        unsafe {
23            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(newDefaultLibrary));
24            Library::from_raw(ptr)
25        }
26    }
27
28    /// Create the default library from a bundle.
29    ///
30    /// C++ equivalent: `Library* newDefaultLibrary(NS::Bundle*, NS::Error**)`
31    ///
32    /// # Safety
33    ///
34    /// The bundle pointer must be valid.
35    pub unsafe fn new_default_library_with_bundle(
36        &self,
37        bundle: *const c_void,
38    ) -> Result<Library, mtl_foundation::Error> {
39        let mut error: *mut c_void = std::ptr::null_mut();
40        unsafe {
41            let ptr: *mut c_void = msg_send_2(
42                self.as_ptr(),
43                sel!(newDefaultLibraryWithBundle: error:),
44                bundle,
45                &mut error as *mut _,
46            );
47
48            if ptr.is_null() {
49                if !error.is_null() {
50                    let _: *mut c_void = msg_send_0(error, sel!(retain));
51                    return Err(mtl_foundation::Error::from_ptr(error)
52                        .expect("error pointer should be valid"));
53                }
54                return Err(mtl_foundation::Error::error(
55                    std::ptr::null_mut(),
56                    -1,
57                    std::ptr::null_mut(),
58                )
59                .expect("failed to create error object"));
60            }
61
62            Ok(Library::from_raw(ptr).expect("library should be valid"))
63        }
64    }
65
66    /// Create a library from source code.
67    ///
68    /// C++ equivalent: `Library* newLibrary(const NS::String* source, const CompileOptions* options, NS::Error** error)`
69    pub fn new_library_with_source(
70        &self,
71        source: &str,
72        options: Option<&CompileOptions>,
73    ) -> Result<Library, mtl_foundation::Error> {
74        let ns_source = mtl_foundation::String::from_str(source).ok_or_else(|| {
75            mtl_foundation::Error::error(std::ptr::null_mut(), -1, std::ptr::null_mut())
76                .expect("failed to create error for invalid string")
77        })?;
78
79        let mut error: *mut c_void = std::ptr::null_mut();
80        unsafe {
81            let ptr: *mut c_void = msg_send_3(
82                self.as_ptr(),
83                sel!(newLibraryWithSource: options: error:),
84                ns_source.as_ptr(),
85                options.map_or(std::ptr::null(), |o| o.as_ptr()),
86                &mut error as *mut _,
87            );
88
89            if ptr.is_null() {
90                if !error.is_null() {
91                    let _: *mut c_void = msg_send_0(error, sel!(retain));
92                    return Err(mtl_foundation::Error::from_ptr(error)
93                        .expect("error pointer should be valid"));
94                }
95                return Err(mtl_foundation::Error::error(
96                    std::ptr::null_mut(),
97                    -1,
98                    std::ptr::null_mut(),
99                )
100                .expect("failed to create error object"));
101            }
102
103            Ok(Library::from_raw(ptr).expect("library should be valid"))
104        }
105    }
106
107    /// Create a library from pre-compiled binary data.
108    ///
109    /// C++ equivalent: `Library* newLibrary(dispatch_data_t data, NS::Error** error)`
110    ///
111    /// # Safety
112    ///
113    /// The data pointer must be valid dispatch_data_t.
114    pub unsafe fn new_library_with_data(
115        &self,
116        data: *const c_void,
117    ) -> Result<Library, mtl_foundation::Error> {
118        let mut error: *mut c_void = std::ptr::null_mut();
119        unsafe {
120            let ptr: *mut c_void = msg_send_2(
121                self.as_ptr(),
122                sel!(newLibraryWithData: error:),
123                data,
124                &mut error as *mut _,
125            );
126
127            if ptr.is_null() {
128                if !error.is_null() {
129                    let _: *mut c_void = msg_send_0(error, sel!(retain));
130                    return Err(mtl_foundation::Error::from_ptr(error)
131                        .expect("error pointer should be valid"));
132                }
133                return Err(mtl_foundation::Error::error(
134                    std::ptr::null_mut(),
135                    -1,
136                    std::ptr::null_mut(),
137                )
138                .expect("failed to create error object"));
139            }
140
141            Ok(Library::from_raw(ptr).expect("library should be valid"))
142        }
143    }
144
145    /// Create a library from a file URL.
146    ///
147    /// C++ equivalent: `Library* newLibrary(const NS::URL* url, NS::Error** error)`
148    ///
149    /// # Safety
150    ///
151    /// The URL pointer must be valid.
152    pub unsafe fn new_library_with_url(
153        &self,
154        url: *const c_void,
155    ) -> Result<Library, mtl_foundation::Error> {
156        let mut error: *mut c_void = std::ptr::null_mut();
157        unsafe {
158            let ptr: *mut c_void = msg_send_2(
159                self.as_ptr(),
160                sel!(newLibraryWithURL: error:),
161                url,
162                &mut error as *mut _,
163            );
164
165            if ptr.is_null() {
166                if !error.is_null() {
167                    let _: *mut c_void = msg_send_0(error, sel!(retain));
168                    return Err(mtl_foundation::Error::from_ptr(error)
169                        .expect("error pointer should be valid"));
170                }
171                return Err(mtl_foundation::Error::error(
172                    std::ptr::null_mut(),
173                    -1,
174                    std::ptr::null_mut(),
175                )
176                .expect("failed to create error object"));
177            }
178
179            Ok(Library::from_raw(ptr).expect("library should be valid"))
180        }
181    }
182
183    /// Create a library from a stitched library descriptor.
184    ///
185    /// C++ equivalent: `Library* newLibrary(const StitchedLibraryDescriptor*, NS::Error**)`
186    ///
187    /// # Safety
188    ///
189    /// The descriptor pointer must be valid.
190    pub unsafe fn new_library_with_stitched_descriptor(
191        &self,
192        descriptor: *const c_void,
193    ) -> Result<Library, mtl_foundation::Error> {
194        let mut error: *mut c_void = std::ptr::null_mut();
195        unsafe {
196            let ptr: *mut c_void = msg_send_2(
197                self.as_ptr(),
198                sel!(newLibraryWithStitchedDescriptor: error:),
199                descriptor,
200                &mut error as *mut _,
201            );
202
203            if ptr.is_null() {
204                if !error.is_null() {
205                    let _: *mut c_void = msg_send_0(error, sel!(retain));
206                    return Err(mtl_foundation::Error::from_ptr(error)
207                        .expect("error pointer should be valid"));
208                }
209                return Err(mtl_foundation::Error::error(
210                    std::ptr::null_mut(),
211                    -1,
212                    std::ptr::null_mut(),
213                )
214                .expect("failed to create error object"));
215            }
216
217            Ok(Library::from_raw(ptr).expect("library should be valid"))
218        }
219    }
220
221    // =========================================================================
222    // Async Library Creation
223    // =========================================================================
224
225    /// Create a library from source code asynchronously.
226    ///
227    /// C++ equivalent: `void newLibrary(const NS::String* source, const CompileOptions* options, NewLibraryCompletionHandler)`
228    ///
229    /// The completion handler is called with the library and any error that occurred.
230    pub fn new_library_with_source_async<F>(
231        &self,
232        source: &str,
233        options: Option<&CompileOptions>,
234        completion_handler: F,
235    ) where
236        F: Fn(Option<Library>, Option<mtl_foundation::Error>) + Send + 'static,
237    {
238        let Some(ns_source) = mtl_foundation::String::from_str(source) else {
239            // Call completion handler with error
240            completion_handler(
241                None,
242                mtl_foundation::Error::error(std::ptr::null_mut(), -1, std::ptr::null_mut()),
243            );
244            return;
245        };
246
247        let block =
248            mtl_sys::TwoArgBlock::from_fn(move |lib_ptr: *mut c_void, err_ptr: *mut c_void| {
249                let library = if lib_ptr.is_null() {
250                    None
251                } else {
252                    unsafe { Library::from_raw(lib_ptr) }
253                };
254
255                let error = if err_ptr.is_null() {
256                    None
257                } else {
258                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
259                };
260
261                completion_handler(library, error);
262            });
263
264        unsafe {
265            msg_send_3::<(), *const c_void, *const c_void, *const c_void>(
266                self.as_ptr(),
267                sel!(newLibraryWithSource:options:completionHandler:),
268                ns_source.as_ptr(),
269                options.map_or(std::ptr::null(), |o| o.as_ptr()),
270                block.as_ptr(),
271            );
272        }
273
274        std::mem::forget(block);
275    }
276
277    /// Create a library from a stitched library descriptor asynchronously.
278    ///
279    /// C++ equivalent: `void newLibrary(const StitchedLibraryDescriptor*, NewLibraryCompletionHandler)`
280    ///
281    /// # Safety
282    ///
283    /// The descriptor pointer must be valid.
284    pub unsafe fn new_library_with_stitched_descriptor_async<F>(
285        &self,
286        descriptor: *const c_void,
287        completion_handler: F,
288    ) where
289        F: Fn(Option<Library>, Option<mtl_foundation::Error>) + Send + 'static,
290    {
291        let block =
292            mtl_sys::TwoArgBlock::from_fn(move |lib_ptr: *mut c_void, err_ptr: *mut c_void| {
293                let library = if lib_ptr.is_null() {
294                    None
295                } else {
296                    unsafe { Library::from_raw(lib_ptr) }
297                };
298
299                let error = if err_ptr.is_null() {
300                    None
301                } else {
302                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
303                };
304
305                completion_handler(library, error);
306            });
307
308        unsafe {
309            msg_send_2::<(), *const c_void, *const c_void>(
310                self.as_ptr(),
311                sel!(newLibraryWithStitchedDescriptor:completionHandler:),
312                descriptor,
313                block.as_ptr(),
314            );
315        }
316
317        std::mem::forget(block);
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::system_default;

    /// A minimal compute kernel that compiles cleanly; it defines `test_kernel`.
    const KERNEL_SOURCE: &str = r#"
            #include <metal_stdlib>
            using namespace metal;

            kernel void test_kernel(device float* data [[buffer(0)]],
                                   uint id [[thread_position_in_grid]]) {
                data[id] = data[id] * 2.0;
            }
        "#;

    #[test]
    fn test_new_library_with_source() {
        let device = system_default().expect("no Metal device");

        let result = device.new_library_with_source(KERNEL_SOURCE, None);
        assert!(
            result.is_ok(),
            "Failed to compile shader: {:?}",
            result.err()
        );

        let library = result.unwrap();
        let names = library.function_names();
        assert!(names.contains(&"test_kernel".to_string()));
    }

    #[test]
    fn test_new_library_with_options() {
        let device = system_default().expect("no Metal device");

        let source = r#"
            #include <metal_stdlib>
            using namespace metal;

            vertex float4 vertex_main(uint vid [[vertex_id]]) {
                return float4(0.0);
            }
        "#;

        let options = CompileOptions::new().expect("failed to create options");
        options.set_fast_math_enabled(true);

        let result = device.new_library_with_source(source, Some(&options));
        assert!(
            result.is_ok(),
            "Failed to compile shader with options: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_new_library_with_invalid_source_fails() {
        let device = system_default().expect("no Metal device");

        // Not valid Metal Shading Language, so compilation must report an error.
        let result = device.new_library_with_source("this is not valid metal", None);
        assert!(result.is_err(), "invalid shader source unexpectedly compiled");
    }

    #[test]
    fn test_new_library_with_source_async() {
        let device = system_default().expect("no Metal device");
        let (tx, rx) = std::sync::mpsc::channel();

        device.new_library_with_source_async(KERNEL_SOURCE, None, move |library, error| {
            // Both wrappers are dropped at the end of this closure. This guards against
            // over-releasing objects Metal passes to the handler.
            let _ = tx.send((library.is_some(), format!("{error:?}")));
        });

        let (got_library, error_text) = rx
            .recv_timeout(std::time::Duration::from_secs(30))
            .expect("completion handler was not called");
        assert!(got_library, "async compile failed: {error_text}");
    }
}