1use std::ffi::c_void;
6
7use mtl_foundation::Referencing;
8use mtl_sys::{msg_send_0, msg_send_2, msg_send_3, sel};
9
10use super::Device;
11use crate::library::{CompileOptions, Library};
12
13impl Device {
14 pub fn new_default_library(&self) -> Option<Library> {
22 unsafe {
23 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(newDefaultLibrary));
24 Library::from_raw(ptr)
25 }
26 }
27
28 pub unsafe fn new_default_library_with_bundle(
36 &self,
37 bundle: *const c_void,
38 ) -> Result<Library, mtl_foundation::Error> {
39 let mut error: *mut c_void = std::ptr::null_mut();
40 unsafe {
41 let ptr: *mut c_void = msg_send_2(
42 self.as_ptr(),
43 sel!(newDefaultLibraryWithBundle: error:),
44 bundle,
45 &mut error as *mut _,
46 );
47
48 if ptr.is_null() {
49 if !error.is_null() {
50 let _: *mut c_void = msg_send_0(error, sel!(retain));
51 return Err(mtl_foundation::Error::from_ptr(error)
52 .expect("error pointer should be valid"));
53 }
54 return Err(mtl_foundation::Error::error(
55 std::ptr::null_mut(),
56 -1,
57 std::ptr::null_mut(),
58 )
59 .expect("failed to create error object"));
60 }
61
62 Ok(Library::from_raw(ptr).expect("library should be valid"))
63 }
64 }
65
66 pub fn new_library_with_source(
70 &self,
71 source: &str,
72 options: Option<&CompileOptions>,
73 ) -> Result<Library, mtl_foundation::Error> {
74 let ns_source = mtl_foundation::String::from_str(source).ok_or_else(|| {
75 mtl_foundation::Error::error(std::ptr::null_mut(), -1, std::ptr::null_mut())
76 .expect("failed to create error for invalid string")
77 })?;
78
79 let mut error: *mut c_void = std::ptr::null_mut();
80 unsafe {
81 let ptr: *mut c_void = msg_send_3(
82 self.as_ptr(),
83 sel!(newLibraryWithSource: options: error:),
84 ns_source.as_ptr(),
85 options.map_or(std::ptr::null(), |o| o.as_ptr()),
86 &mut error as *mut _,
87 );
88
89 if ptr.is_null() {
90 if !error.is_null() {
91 let _: *mut c_void = msg_send_0(error, sel!(retain));
92 return Err(mtl_foundation::Error::from_ptr(error)
93 .expect("error pointer should be valid"));
94 }
95 return Err(mtl_foundation::Error::error(
96 std::ptr::null_mut(),
97 -1,
98 std::ptr::null_mut(),
99 )
100 .expect("failed to create error object"));
101 }
102
103 Ok(Library::from_raw(ptr).expect("library should be valid"))
104 }
105 }
106
107 pub unsafe fn new_library_with_data(
115 &self,
116 data: *const c_void,
117 ) -> Result<Library, mtl_foundation::Error> {
118 let mut error: *mut c_void = std::ptr::null_mut();
119 unsafe {
120 let ptr: *mut c_void = msg_send_2(
121 self.as_ptr(),
122 sel!(newLibraryWithData: error:),
123 data,
124 &mut error as *mut _,
125 );
126
127 if ptr.is_null() {
128 if !error.is_null() {
129 let _: *mut c_void = msg_send_0(error, sel!(retain));
130 return Err(mtl_foundation::Error::from_ptr(error)
131 .expect("error pointer should be valid"));
132 }
133 return Err(mtl_foundation::Error::error(
134 std::ptr::null_mut(),
135 -1,
136 std::ptr::null_mut(),
137 )
138 .expect("failed to create error object"));
139 }
140
141 Ok(Library::from_raw(ptr).expect("library should be valid"))
142 }
143 }
144
145 pub unsafe fn new_library_with_url(
153 &self,
154 url: *const c_void,
155 ) -> Result<Library, mtl_foundation::Error> {
156 let mut error: *mut c_void = std::ptr::null_mut();
157 unsafe {
158 let ptr: *mut c_void = msg_send_2(
159 self.as_ptr(),
160 sel!(newLibraryWithURL: error:),
161 url,
162 &mut error as *mut _,
163 );
164
165 if ptr.is_null() {
166 if !error.is_null() {
167 let _: *mut c_void = msg_send_0(error, sel!(retain));
168 return Err(mtl_foundation::Error::from_ptr(error)
169 .expect("error pointer should be valid"));
170 }
171 return Err(mtl_foundation::Error::error(
172 std::ptr::null_mut(),
173 -1,
174 std::ptr::null_mut(),
175 )
176 .expect("failed to create error object"));
177 }
178
179 Ok(Library::from_raw(ptr).expect("library should be valid"))
180 }
181 }
182
183 pub unsafe fn new_library_with_stitched_descriptor(
191 &self,
192 descriptor: *const c_void,
193 ) -> Result<Library, mtl_foundation::Error> {
194 let mut error: *mut c_void = std::ptr::null_mut();
195 unsafe {
196 let ptr: *mut c_void = msg_send_2(
197 self.as_ptr(),
198 sel!(newLibraryWithStitchedDescriptor: error:),
199 descriptor,
200 &mut error as *mut _,
201 );
202
203 if ptr.is_null() {
204 if !error.is_null() {
205 let _: *mut c_void = msg_send_0(error, sel!(retain));
206 return Err(mtl_foundation::Error::from_ptr(error)
207 .expect("error pointer should be valid"));
208 }
209 return Err(mtl_foundation::Error::error(
210 std::ptr::null_mut(),
211 -1,
212 std::ptr::null_mut(),
213 )
214 .expect("failed to create error object"));
215 }
216
217 Ok(Library::from_raw(ptr).expect("library should be valid"))
218 }
219 }
220
221 pub fn new_library_with_source_async<F>(
231 &self,
232 source: &str,
233 options: Option<&CompileOptions>,
234 completion_handler: F,
235 ) where
236 F: Fn(Option<Library>, Option<mtl_foundation::Error>) + Send + 'static,
237 {
238 let Some(ns_source) = mtl_foundation::String::from_str(source) else {
239 completion_handler(
241 None,
242 mtl_foundation::Error::error(std::ptr::null_mut(), -1, std::ptr::null_mut()),
243 );
244 return;
245 };
246
247 let block =
248 mtl_sys::TwoArgBlock::from_fn(move |lib_ptr: *mut c_void, err_ptr: *mut c_void| {
249 let library = if lib_ptr.is_null() {
250 None
251 } else {
252 unsafe { Library::from_raw(lib_ptr) }
253 };
254
255 let error = if err_ptr.is_null() {
256 None
257 } else {
258 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
259 };
260
261 completion_handler(library, error);
262 });
263
264 unsafe {
265 msg_send_3::<(), *const c_void, *const c_void, *const c_void>(
266 self.as_ptr(),
267 sel!(newLibraryWithSource:options:completionHandler:),
268 ns_source.as_ptr(),
269 options.map_or(std::ptr::null(), |o| o.as_ptr()),
270 block.as_ptr(),
271 );
272 }
273
274 std::mem::forget(block);
275 }
276
277 pub unsafe fn new_library_with_stitched_descriptor_async<F>(
285 &self,
286 descriptor: *const c_void,
287 completion_handler: F,
288 ) where
289 F: Fn(Option<Library>, Option<mtl_foundation::Error>) + Send + 'static,
290 {
291 let block =
292 mtl_sys::TwoArgBlock::from_fn(move |lib_ptr: *mut c_void, err_ptr: *mut c_void| {
293 let library = if lib_ptr.is_null() {
294 None
295 } else {
296 unsafe { Library::from_raw(lib_ptr) }
297 };
298
299 let error = if err_ptr.is_null() {
300 None
301 } else {
302 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
303 };
304
305 completion_handler(library, error);
306 });
307
308 unsafe {
309 msg_send_2::<(), *const c_void, *const c_void>(
310 self.as_ptr(),
311 sel!(newLibraryWithStitchedDescriptor:completionHandler:),
312 descriptor,
313 block.as_ptr(),
314 );
315 }
316
317 std::mem::forget(block);
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::system_default;

    /// A valid compute kernel compiles, and its function shows up in the library.
    #[test]
    fn test_new_library_with_source() {
        let device = system_default().expect("no Metal device");

        let source = r#"
            #include <metal_stdlib>
            using namespace metal;

            kernel void test_kernel(device float* data [[buffer(0)]],
                                    uint id [[thread_position_in_grid]]) {
                data[id] = data[id] * 2.0;
            }
        "#;

        let library = device
            .new_library_with_source(source, None)
            .expect("Failed to compile shader");

        let names = library.function_names();
        assert!(
            names.iter().any(|name| name == "test_kernel"),
            "test_kernel missing from {names:?}"
        );
    }

    /// Compilation succeeds with explicit compile options; any compiler error is reported.
    #[test]
    fn test_new_library_with_options() {
        let device = system_default().expect("no Metal device");

        let source = r#"
            #include <metal_stdlib>
            using namespace metal;

            vertex float4 vertex_main(uint vid [[vertex_id]]) {
                return float4(0.0);
            }
        "#;

        let options = CompileOptions::new().expect("failed to create options");
        options.set_fast_math_enabled(true);

        // `expect` prints the compiler's error on failure instead of a bare assertion message.
        let _library = device
            .new_library_with_source(source, Some(&options))
            .expect("Failed to compile shader with options");
    }

    /// Invalid source exercises the error path and yields `Err` rather than panicking.
    #[test]
    fn test_new_library_with_invalid_source_returns_error() {
        let device = system_default().expect("no Metal device");

        let result = device.new_library_with_source("this is not valid metal", None);
        assert!(result.is_err(), "invalid shader source unexpectedly compiled");
    }
}
371}