1use std::ffi::c_void;
6
7use mtl_foundation::Referencing;
8use mtl_sys::{msg_send_0, msg_send_2, msg_send_3, sel};
9
10use super::Device;
11use crate::error::ValidationError;
12use crate::library::Function;
13use crate::pipeline::{
14 ComputePipelineDescriptor, ComputePipelineReflection, ComputePipelineState,
15 MeshRenderPipelineDescriptor, RenderPipelineDescriptor, RenderPipelineReflection,
16 RenderPipelineState, TileRenderPipelineDescriptor,
17};
18
19impl Device {
20 pub unsafe fn new_render_pipeline_state(
32 &self,
33 descriptor: *const c_void,
34 ) -> Result<RenderPipelineState, mtl_foundation::Error> {
35 let mut error: *mut c_void = std::ptr::null_mut();
36 unsafe {
37 let ptr: *mut c_void = msg_send_2(
38 self.as_ptr(),
39 sel!(newRenderPipelineStateWithDescriptor: error:),
40 descriptor,
41 &mut error as *mut _,
42 );
43
44 if ptr.is_null() {
45 if !error.is_null() {
46 let _: *mut c_void = msg_send_0(error, sel!(retain));
47 return Err(mtl_foundation::Error::from_ptr(error)
48 .expect("error pointer should be valid"));
49 }
50 return Err(mtl_foundation::Error::error(
51 std::ptr::null_mut(),
52 -1,
53 std::ptr::null_mut(),
54 )
55 .expect("failed to create error object"));
56 }
57
58 Ok(RenderPipelineState::from_raw(ptr).expect("render pipeline state should be valid"))
59 }
60 }
61
62 pub fn new_render_pipeline_state_with_descriptor(
85 &self,
86 descriptor: &RenderPipelineDescriptor,
87 ) -> Result<RenderPipelineState, ValidationError> {
88 if descriptor.vertex_function().is_none() {
90 return Err(ValidationError::MissingVertexFunction);
91 }
92
93 let sample_count = descriptor.raster_sample_count();
95 if sample_count > 1 && !self.supports_texture_sample_count(sample_count) {
96 return Err(ValidationError::UnsupportedRasterSampleCount(sample_count));
97 }
98
99 unsafe {
101 self.new_render_pipeline_state(descriptor.as_ptr())
102 .map_err(ValidationError::from)
103 }
104 }
105
106 pub unsafe fn new_render_pipeline_state_with_reflection(
114 &self,
115 descriptor: *const c_void,
116 options: crate::enums::PipelineOption,
117 reflection: *mut *mut c_void,
118 ) -> Result<RenderPipelineState, mtl_foundation::Error> {
119 let mut error: *mut c_void = std::ptr::null_mut();
120 unsafe {
121 let ptr: *mut c_void = mtl_sys::msg_send_4(
122 self.as_ptr(),
123 sel!(newRenderPipelineStateWithDescriptor: options: reflection: error:),
124 descriptor,
125 options,
126 reflection,
127 &mut error as *mut _,
128 );
129
130 if ptr.is_null() {
131 if !error.is_null() {
132 let _: *mut c_void = msg_send_0(error, sel!(retain));
133 return Err(mtl_foundation::Error::from_ptr(error)
134 .expect("error pointer should be valid"));
135 }
136 return Err(mtl_foundation::Error::error(
137 std::ptr::null_mut(),
138 -1,
139 std::ptr::null_mut(),
140 )
141 .expect("failed to create error object"));
142 }
143
144 Ok(RenderPipelineState::from_raw(ptr).expect("render pipeline state should be valid"))
145 }
146 }
147
148 pub fn new_compute_pipeline_state_with_function(
156 &self,
157 function: &Function,
158 ) -> Result<ComputePipelineState, mtl_foundation::Error> {
159 let mut error: *mut c_void = std::ptr::null_mut();
160 unsafe {
161 let ptr: *mut c_void = msg_send_2(
162 self.as_ptr(),
163 sel!(newComputePipelineStateWithFunction: error:),
164 function.as_ptr(),
165 &mut error as *mut _,
166 );
167
168 if ptr.is_null() {
169 if !error.is_null() {
170 let _: *mut c_void = msg_send_0(error, sel!(retain));
171 return Err(mtl_foundation::Error::from_ptr(error)
172 .expect("error pointer should be valid"));
173 }
174 return Err(mtl_foundation::Error::error(
175 std::ptr::null_mut(),
176 -1,
177 std::ptr::null_mut(),
178 )
179 .expect("failed to create error object"));
180 }
181
182 Ok(
183 ComputePipelineState::from_raw(ptr)
184 .expect("compute pipeline state should be valid"),
185 )
186 }
187 }
188
189 pub unsafe fn new_compute_pipeline_state_with_function_and_reflection(
197 &self,
198 function: &Function,
199 options: crate::enums::PipelineOption,
200 reflection: *mut *mut c_void,
201 ) -> Result<ComputePipelineState, mtl_foundation::Error> {
202 let mut error: *mut c_void = std::ptr::null_mut();
203 unsafe {
204 let ptr: *mut c_void = mtl_sys::msg_send_4(
205 self.as_ptr(),
206 sel!(newComputePipelineStateWithFunction: options: reflection: error:),
207 function.as_ptr(),
208 options,
209 reflection,
210 &mut error as *mut _,
211 );
212
213 if ptr.is_null() {
214 if !error.is_null() {
215 let _: *mut c_void = msg_send_0(error, sel!(retain));
216 return Err(mtl_foundation::Error::from_ptr(error)
217 .expect("error pointer should be valid"));
218 }
219 return Err(mtl_foundation::Error::error(
220 std::ptr::null_mut(),
221 -1,
222 std::ptr::null_mut(),
223 )
224 .expect("failed to create error object"));
225 }
226
227 Ok(
228 ComputePipelineState::from_raw(ptr)
229 .expect("compute pipeline state should be valid"),
230 )
231 }
232 }
233
234 pub fn new_compute_pipeline_state_validated(
255 &self,
256 descriptor: &ComputePipelineDescriptor,
257 ) -> Result<ComputePipelineState, ValidationError> {
258 if descriptor.compute_function().is_none() {
260 return Err(ValidationError::MissingComputeFunction);
261 }
262
263 unsafe {
265 self.new_compute_pipeline_state_with_descriptor(
266 descriptor.as_ptr(),
267 crate::enums::PipelineOption::NONE,
268 std::ptr::null_mut(),
269 )
270 .map_err(ValidationError::from)
271 }
272 }
273
274 pub unsafe fn new_compute_pipeline_state_with_descriptor(
282 &self,
283 descriptor: *const c_void,
284 options: crate::enums::PipelineOption,
285 reflection: *mut *mut c_void,
286 ) -> Result<ComputePipelineState, mtl_foundation::Error> {
287 let mut error: *mut c_void = std::ptr::null_mut();
288 unsafe {
289 let ptr: *mut c_void = mtl_sys::msg_send_4(
290 self.as_ptr(),
291 sel!(newComputePipelineStateWithDescriptor: options: reflection: error:),
292 descriptor,
293 options,
294 reflection,
295 &mut error as *mut _,
296 );
297
298 if ptr.is_null() {
299 if !error.is_null() {
300 let _: *mut c_void = msg_send_0(error, sel!(retain));
301 return Err(mtl_foundation::Error::from_ptr(error)
302 .expect("error pointer should be valid"));
303 }
304 return Err(mtl_foundation::Error::error(
305 std::ptr::null_mut(),
306 -1,
307 std::ptr::null_mut(),
308 )
309 .expect("failed to create error object"));
310 }
311
312 Ok(
313 ComputePipelineState::from_raw(ptr)
314 .expect("compute pipeline state should be valid"),
315 )
316 }
317 }
318
319 pub fn new_render_pipeline_state_async<F>(
329 &self,
330 descriptor: &RenderPipelineDescriptor,
331 completion_handler: F,
332 ) where
333 F: Fn(Option<RenderPipelineState>, Option<mtl_foundation::Error>) + Send + 'static,
334 {
335 let block =
336 mtl_sys::TwoArgBlock::from_fn(move |state_ptr: *mut c_void, err_ptr: *mut c_void| {
337 let state = if state_ptr.is_null() {
338 None
339 } else {
340 unsafe { RenderPipelineState::from_raw(state_ptr) }
341 };
342
343 let error = if err_ptr.is_null() {
344 None
345 } else {
346 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
347 };
348
349 completion_handler(state, error);
350 });
351
352 unsafe {
353 msg_send_2::<(), *const c_void, *const c_void>(
354 self.as_ptr(),
355 sel!(newRenderPipelineStateWithDescriptor:completionHandler:),
356 descriptor.as_ptr(),
357 block.as_ptr(),
358 );
359 }
360
361 std::mem::forget(block);
362 }
363
364 pub fn new_render_pipeline_state_with_reflection_async<F>(
370 &self,
371 descriptor: &RenderPipelineDescriptor,
372 options: crate::enums::PipelineOption,
373 completion_handler: F,
374 ) where
375 F: Fn(
376 Option<RenderPipelineState>,
377 Option<RenderPipelineReflection>,
378 Option<mtl_foundation::Error>,
379 ) + Send
380 + 'static,
381 {
382 let block = mtl_sys::ThreeArgBlock::from_fn(
383 move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
384 let state = if state_ptr.is_null() {
385 None
386 } else {
387 unsafe { RenderPipelineState::from_raw(state_ptr) }
388 };
389
390 let reflection = if reflection_ptr.is_null() {
391 None
392 } else {
393 unsafe { RenderPipelineReflection::from_raw(reflection_ptr) }
394 };
395
396 let error = if err_ptr.is_null() {
397 None
398 } else {
399 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
400 };
401
402 completion_handler(state, reflection, error);
403 },
404 );
405
406 unsafe {
407 msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
408 self.as_ptr(),
409 sel!(newRenderPipelineStateWithDescriptor:options:completionHandler:),
410 descriptor.as_ptr(),
411 options,
412 block.as_ptr(),
413 );
414 }
415
416 std::mem::forget(block);
417 }
418
419 pub fn new_tile_render_pipeline_state_with_reflection_async<F>(
423 &self,
424 descriptor: &TileRenderPipelineDescriptor,
425 options: crate::enums::PipelineOption,
426 completion_handler: F,
427 ) where
428 F: Fn(
429 Option<RenderPipelineState>,
430 Option<RenderPipelineReflection>,
431 Option<mtl_foundation::Error>,
432 ) + Send
433 + 'static,
434 {
435 let block = mtl_sys::ThreeArgBlock::from_fn(
436 move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
437 let state = if state_ptr.is_null() {
438 None
439 } else {
440 unsafe { RenderPipelineState::from_raw(state_ptr) }
441 };
442
443 let reflection = if reflection_ptr.is_null() {
444 None
445 } else {
446 unsafe { RenderPipelineReflection::from_raw(reflection_ptr) }
447 };
448
449 let error = if err_ptr.is_null() {
450 None
451 } else {
452 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
453 };
454
455 completion_handler(state, reflection, error);
456 },
457 );
458
459 unsafe {
460 msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
461 self.as_ptr(),
462 sel!(newRenderPipelineStateWithTileDescriptor:options:completionHandler:),
463 descriptor.as_ptr(),
464 options,
465 block.as_ptr(),
466 );
467 }
468
469 std::mem::forget(block);
470 }
471
472 pub fn new_mesh_render_pipeline_state_with_reflection_async<F>(
476 &self,
477 descriptor: &MeshRenderPipelineDescriptor,
478 options: crate::enums::PipelineOption,
479 completion_handler: F,
480 ) where
481 F: Fn(
482 Option<RenderPipelineState>,
483 Option<RenderPipelineReflection>,
484 Option<mtl_foundation::Error>,
485 ) + Send
486 + 'static,
487 {
488 let block = mtl_sys::ThreeArgBlock::from_fn(
489 move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
490 let state = if state_ptr.is_null() {
491 None
492 } else {
493 unsafe { RenderPipelineState::from_raw(state_ptr) }
494 };
495
496 let reflection = if reflection_ptr.is_null() {
497 None
498 } else {
499 unsafe { RenderPipelineReflection::from_raw(reflection_ptr) }
500 };
501
502 let error = if err_ptr.is_null() {
503 None
504 } else {
505 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
506 };
507
508 completion_handler(state, reflection, error);
509 },
510 );
511
512 unsafe {
513 msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
514 self.as_ptr(),
515 sel!(newRenderPipelineStateWithMeshDescriptor:options:completionHandler:),
516 descriptor.as_ptr(),
517 options,
518 block.as_ptr(),
519 );
520 }
521
522 std::mem::forget(block);
523 }
524
525 pub fn new_compute_pipeline_state_with_function_async<F>(
533 &self,
534 function: &Function,
535 completion_handler: F,
536 ) where
537 F: Fn(Option<ComputePipelineState>, Option<mtl_foundation::Error>) + Send + 'static,
538 {
539 let block =
540 mtl_sys::TwoArgBlock::from_fn(move |state_ptr: *mut c_void, err_ptr: *mut c_void| {
541 let state = if state_ptr.is_null() {
542 None
543 } else {
544 unsafe { ComputePipelineState::from_raw(state_ptr) }
545 };
546
547 let error = if err_ptr.is_null() {
548 None
549 } else {
550 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
551 };
552
553 completion_handler(state, error);
554 });
555
556 unsafe {
557 msg_send_2::<(), *const c_void, *const c_void>(
558 self.as_ptr(),
559 sel!(newComputePipelineStateWithFunction:completionHandler:),
560 function.as_ptr(),
561 block.as_ptr(),
562 );
563 }
564
565 std::mem::forget(block);
566 }
567
568 pub fn new_compute_pipeline_state_with_function_and_reflection_async<F>(
572 &self,
573 function: &Function,
574 options: crate::enums::PipelineOption,
575 completion_handler: F,
576 ) where
577 F: Fn(
578 Option<ComputePipelineState>,
579 Option<ComputePipelineReflection>,
580 Option<mtl_foundation::Error>,
581 ) + Send
582 + 'static,
583 {
584 let block = mtl_sys::ThreeArgBlock::from_fn(
585 move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
586 let state = if state_ptr.is_null() {
587 None
588 } else {
589 unsafe { ComputePipelineState::from_raw(state_ptr) }
590 };
591
592 let reflection = if reflection_ptr.is_null() {
593 None
594 } else {
595 unsafe { ComputePipelineReflection::from_raw(reflection_ptr) }
596 };
597
598 let error = if err_ptr.is_null() {
599 None
600 } else {
601 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
602 };
603
604 completion_handler(state, reflection, error);
605 },
606 );
607
608 unsafe {
609 msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
610 self.as_ptr(),
611 sel!(newComputePipelineStateWithFunction:options:completionHandler:),
612 function.as_ptr(),
613 options,
614 block.as_ptr(),
615 );
616 }
617
618 std::mem::forget(block);
619 }
620
621 pub fn new_compute_pipeline_state_with_descriptor_async<F>(
625 &self,
626 descriptor: &ComputePipelineDescriptor,
627 options: crate::enums::PipelineOption,
628 completion_handler: F,
629 ) where
630 F: Fn(
631 Option<ComputePipelineState>,
632 Option<ComputePipelineReflection>,
633 Option<mtl_foundation::Error>,
634 ) + Send
635 + 'static,
636 {
637 let block = mtl_sys::ThreeArgBlock::from_fn(
638 move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
639 let state = if state_ptr.is_null() {
640 None
641 } else {
642 unsafe { ComputePipelineState::from_raw(state_ptr) }
643 };
644
645 let reflection = if reflection_ptr.is_null() {
646 None
647 } else {
648 unsafe { ComputePipelineReflection::from_raw(reflection_ptr) }
649 };
650
651 let error = if err_ptr.is_null() {
652 None
653 } else {
654 unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
655 };
656
657 completion_handler(state, reflection, error);
658 },
659 );
660
661 unsafe {
662 msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
663 self.as_ptr(),
664 sel!(newComputePipelineStateWithDescriptor:options:completionHandler:),
665 descriptor.as_ptr(),
666 options,
667 block.as_ptr(),
668 );
669 }
670
671 std::mem::forget(block);
672 }
673}
674
#[cfg(test)]
mod tests {
    use crate::device::system_default;

    /// Compiles a trivial kernel and checks that the compute pipeline built from it reports
    /// non-zero threadgroup limits. Requires a Metal device.
    #[test]
    fn test_new_compute_pipeline_state() {
        let device = system_default().expect("no Metal device");

        let source = r#"
            #include <metal_stdlib>
            using namespace metal;

            kernel void test_kernel(device float* data [[buffer(0)]],
                                    uint id [[thread_position_in_grid]]) {
                data[id] = data[id] * 2.0;
            }
        "#;

        let library = device
            .new_library_with_source(source, None)
            .expect("failed to compile shader");

        let function = library
            .new_function_with_name("test_kernel")
            .expect("function not found");

        // `expect` (unlike `assert!(is_ok())`) prints Metal's error if creation fails.
        let pipeline = device
            .new_compute_pipeline_state_with_function(&function)
            .expect("failed to create compute pipeline state");

        assert!(pipeline.max_total_threads_per_threadgroup() > 0);
        assert!(pipeline.thread_execution_width() > 0);
    }
}