Skip to main content

mtl_gpu/device/
pipeline.rs

1//! Device pipeline state creation methods.
2//!
3//! Corresponds to pipeline state creation methods in `Metal/MTLDevice.hpp`.
4
5use std::ffi::c_void;
6
7use mtl_foundation::Referencing;
8use mtl_sys::{msg_send_0, msg_send_2, msg_send_3, sel};
9
10use super::Device;
11use crate::error::ValidationError;
12use crate::library::Function;
13use crate::pipeline::{
14    ComputePipelineDescriptor, ComputePipelineReflection, ComputePipelineState,
15    MeshRenderPipelineDescriptor, RenderPipelineDescriptor, RenderPipelineReflection,
16    RenderPipelineState, TileRenderPipelineDescriptor,
17};
18
19impl Device {
20    // =========================================================================
21    // Render Pipeline State Creation
22    // =========================================================================
23
24    /// Create a render pipeline state from a descriptor.
25    ///
26    /// C++ equivalent: `RenderPipelineState* newRenderPipelineState(const RenderPipelineDescriptor*, NS::Error**)`
27    ///
28    /// # Safety
29    ///
30    /// The descriptor pointer must be valid.
31    pub unsafe fn new_render_pipeline_state(
32        &self,
33        descriptor: *const c_void,
34    ) -> Result<RenderPipelineState, mtl_foundation::Error> {
35        let mut error: *mut c_void = std::ptr::null_mut();
36        unsafe {
37            let ptr: *mut c_void = msg_send_2(
38                self.as_ptr(),
39                sel!(newRenderPipelineStateWithDescriptor: error:),
40                descriptor,
41                &mut error as *mut _,
42            );
43
44            if ptr.is_null() {
45                if !error.is_null() {
46                    let _: *mut c_void = msg_send_0(error, sel!(retain));
47                    return Err(mtl_foundation::Error::from_ptr(error)
48                        .expect("error pointer should be valid"));
49                }
50                return Err(mtl_foundation::Error::error(
51                    std::ptr::null_mut(),
52                    -1,
53                    std::ptr::null_mut(),
54                )
55                .expect("failed to create error object"));
56            }
57
58            Ok(RenderPipelineState::from_raw(ptr).expect("render pipeline state should be valid"))
59        }
60    }
61
62    /// Create a render pipeline state with validation.
63    ///
64    /// This safe method validates the descriptor before calling Metal APIs:
65    /// - Ensures a vertex function is set (required)
66    /// - Validates raster sample count is supported by the device
67    ///
68    /// Use this method instead of `new_render_pipeline_state` to avoid process
69    /// aborts from Metal's validation layer.
70    ///
71    /// # Example
72    ///
73    /// ```ignore
74    /// let desc = RenderPipelineDescriptor::new().unwrap();
75    /// desc.set_vertex_function(Some(&vertex_fn));
76    /// desc.set_fragment_function(Some(&fragment_fn));
77    ///
78    /// match device.new_render_pipeline_state_with_descriptor(&desc) {
79    ///     Ok(pipeline) => { /* use pipeline */ }
80    ///     Err(ValidationError::MissingVertexFunction) => { /* handle error */ }
81    ///     Err(e) => { /* handle other errors */ }
82    /// }
83    /// ```
84    pub fn new_render_pipeline_state_with_descriptor(
85        &self,
86        descriptor: &RenderPipelineDescriptor,
87    ) -> Result<RenderPipelineState, ValidationError> {
88        // Validate vertex function is set
89        if descriptor.vertex_function().is_none() {
90            return Err(ValidationError::MissingVertexFunction);
91        }
92
93        // Validate raster sample count if > 1
94        let sample_count = descriptor.raster_sample_count();
95        if sample_count > 1 && !self.supports_texture_sample_count(sample_count) {
96            return Err(ValidationError::UnsupportedRasterSampleCount(sample_count));
97        }
98
99        // Call unsafe implementation
100        unsafe {
101            self.new_render_pipeline_state(descriptor.as_ptr())
102                .map_err(ValidationError::from)
103        }
104    }
105
106    /// Create a render pipeline state with options.
107    ///
108    /// C++ equivalent: `RenderPipelineState* newRenderPipelineState(const RenderPipelineDescriptor*, PipelineOption, RenderPipelineReflection**, NS::Error**)`
109    ///
110    /// # Safety
111    ///
112    /// The descriptor and reflection pointers must be valid.
113    pub unsafe fn new_render_pipeline_state_with_reflection(
114        &self,
115        descriptor: *const c_void,
116        options: crate::enums::PipelineOption,
117        reflection: *mut *mut c_void,
118    ) -> Result<RenderPipelineState, mtl_foundation::Error> {
119        let mut error: *mut c_void = std::ptr::null_mut();
120        unsafe {
121            let ptr: *mut c_void = mtl_sys::msg_send_4(
122                self.as_ptr(),
123                sel!(newRenderPipelineStateWithDescriptor: options: reflection: error:),
124                descriptor,
125                options,
126                reflection,
127                &mut error as *mut _,
128            );
129
130            if ptr.is_null() {
131                if !error.is_null() {
132                    let _: *mut c_void = msg_send_0(error, sel!(retain));
133                    return Err(mtl_foundation::Error::from_ptr(error)
134                        .expect("error pointer should be valid"));
135                }
136                return Err(mtl_foundation::Error::error(
137                    std::ptr::null_mut(),
138                    -1,
139                    std::ptr::null_mut(),
140                )
141                .expect("failed to create error object"));
142            }
143
144            Ok(RenderPipelineState::from_raw(ptr).expect("render pipeline state should be valid"))
145        }
146    }
147
148    // =========================================================================
149    // Compute Pipeline State Creation
150    // =========================================================================
151
152    /// Create a compute pipeline state from a function.
153    ///
154    /// C++ equivalent: `ComputePipelineState* newComputePipelineState(const Function*, NS::Error**)`
155    pub fn new_compute_pipeline_state_with_function(
156        &self,
157        function: &Function,
158    ) -> Result<ComputePipelineState, mtl_foundation::Error> {
159        let mut error: *mut c_void = std::ptr::null_mut();
160        unsafe {
161            let ptr: *mut c_void = msg_send_2(
162                self.as_ptr(),
163                sel!(newComputePipelineStateWithFunction: error:),
164                function.as_ptr(),
165                &mut error as *mut _,
166            );
167
168            if ptr.is_null() {
169                if !error.is_null() {
170                    let _: *mut c_void = msg_send_0(error, sel!(retain));
171                    return Err(mtl_foundation::Error::from_ptr(error)
172                        .expect("error pointer should be valid"));
173                }
174                return Err(mtl_foundation::Error::error(
175                    std::ptr::null_mut(),
176                    -1,
177                    std::ptr::null_mut(),
178                )
179                .expect("failed to create error object"));
180            }
181
182            Ok(
183                ComputePipelineState::from_raw(ptr)
184                    .expect("compute pipeline state should be valid"),
185            )
186        }
187    }
188
189    /// Create a compute pipeline state with options.
190    ///
191    /// C++ equivalent: `ComputePipelineState* newComputePipelineState(const Function*, PipelineOption, ComputePipelineReflection**, NS::Error**)`
192    ///
193    /// # Safety
194    ///
195    /// The reflection pointer must be valid if not null.
196    pub unsafe fn new_compute_pipeline_state_with_function_and_reflection(
197        &self,
198        function: &Function,
199        options: crate::enums::PipelineOption,
200        reflection: *mut *mut c_void,
201    ) -> Result<ComputePipelineState, mtl_foundation::Error> {
202        let mut error: *mut c_void = std::ptr::null_mut();
203        unsafe {
204            let ptr: *mut c_void = mtl_sys::msg_send_4(
205                self.as_ptr(),
206                sel!(newComputePipelineStateWithFunction: options: reflection: error:),
207                function.as_ptr(),
208                options,
209                reflection,
210                &mut error as *mut _,
211            );
212
213            if ptr.is_null() {
214                if !error.is_null() {
215                    let _: *mut c_void = msg_send_0(error, sel!(retain));
216                    return Err(mtl_foundation::Error::from_ptr(error)
217                        .expect("error pointer should be valid"));
218                }
219                return Err(mtl_foundation::Error::error(
220                    std::ptr::null_mut(),
221                    -1,
222                    std::ptr::null_mut(),
223                )
224                .expect("failed to create error object"));
225            }
226
227            Ok(
228                ComputePipelineState::from_raw(ptr)
229                    .expect("compute pipeline state should be valid"),
230            )
231        }
232    }
233
234    /// Create a compute pipeline state with validation.
235    ///
236    /// This safe method validates the descriptor before calling Metal APIs:
237    /// - Ensures a compute function is set (required)
238    ///
239    /// Use this method instead of the unsafe `new_compute_pipeline_state_with_descriptor`
240    /// to avoid process aborts from Metal's validation layer.
241    ///
242    /// # Example
243    ///
244    /// ```ignore
245    /// let desc = ComputePipelineDescriptor::new().unwrap();
246    /// desc.set_compute_function(Some(&compute_fn));
247    ///
248    /// match device.new_compute_pipeline_state_validated(&desc) {
249    ///     Ok(pipeline) => { /* use pipeline */ }
250    ///     Err(ValidationError::MissingComputeFunction) => { /* handle error */ }
251    ///     Err(e) => { /* handle other errors */ }
252    /// }
253    /// ```
254    pub fn new_compute_pipeline_state_validated(
255        &self,
256        descriptor: &ComputePipelineDescriptor,
257    ) -> Result<ComputePipelineState, ValidationError> {
258        // Validate compute function is set
259        if descriptor.compute_function().is_none() {
260            return Err(ValidationError::MissingComputeFunction);
261        }
262
263        // Call the unsafe implementation with default options and no reflection
264        unsafe {
265            self.new_compute_pipeline_state_with_descriptor(
266                descriptor.as_ptr(),
267                crate::enums::PipelineOption::NONE,
268                std::ptr::null_mut(),
269            )
270            .map_err(ValidationError::from)
271        }
272    }
273
274    /// Create a compute pipeline state from a descriptor.
275    ///
276    /// C++ equivalent: `ComputePipelineState* newComputePipelineState(const ComputePipelineDescriptor*, PipelineOption, ComputePipelineReflection**, NS::Error**)`
277    ///
278    /// # Safety
279    ///
280    /// The descriptor and reflection pointers must be valid.
281    pub unsafe fn new_compute_pipeline_state_with_descriptor(
282        &self,
283        descriptor: *const c_void,
284        options: crate::enums::PipelineOption,
285        reflection: *mut *mut c_void,
286    ) -> Result<ComputePipelineState, mtl_foundation::Error> {
287        let mut error: *mut c_void = std::ptr::null_mut();
288        unsafe {
289            let ptr: *mut c_void = mtl_sys::msg_send_4(
290                self.as_ptr(),
291                sel!(newComputePipelineStateWithDescriptor: options: reflection: error:),
292                descriptor,
293                options,
294                reflection,
295                &mut error as *mut _,
296            );
297
298            if ptr.is_null() {
299                if !error.is_null() {
300                    let _: *mut c_void = msg_send_0(error, sel!(retain));
301                    return Err(mtl_foundation::Error::from_ptr(error)
302                        .expect("error pointer should be valid"));
303                }
304                return Err(mtl_foundation::Error::error(
305                    std::ptr::null_mut(),
306                    -1,
307                    std::ptr::null_mut(),
308                )
309                .expect("failed to create error object"));
310            }
311
312            Ok(
313                ComputePipelineState::from_raw(ptr)
314                    .expect("compute pipeline state should be valid"),
315            )
316        }
317    }
318
319    // =========================================================================
320    // Async Render Pipeline State Creation
321    // =========================================================================
322
323    /// Create a render pipeline state asynchronously.
324    ///
325    /// C++ equivalent: `void newRenderPipelineState(const RenderPipelineDescriptor*, NewRenderPipelineStateCompletionHandler)`
326    ///
327    /// The completion handler is called with the pipeline state and any error that occurred.
328    pub fn new_render_pipeline_state_async<F>(
329        &self,
330        descriptor: &RenderPipelineDescriptor,
331        completion_handler: F,
332    ) where
333        F: Fn(Option<RenderPipelineState>, Option<mtl_foundation::Error>) + Send + 'static,
334    {
335        let block =
336            mtl_sys::TwoArgBlock::from_fn(move |state_ptr: *mut c_void, err_ptr: *mut c_void| {
337                let state = if state_ptr.is_null() {
338                    None
339                } else {
340                    unsafe { RenderPipelineState::from_raw(state_ptr) }
341                };
342
343                let error = if err_ptr.is_null() {
344                    None
345                } else {
346                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
347                };
348
349                completion_handler(state, error);
350            });
351
352        unsafe {
353            msg_send_2::<(), *const c_void, *const c_void>(
354                self.as_ptr(),
355                sel!(newRenderPipelineStateWithDescriptor:completionHandler:),
356                descriptor.as_ptr(),
357                block.as_ptr(),
358            );
359        }
360
361        std::mem::forget(block);
362    }
363
364    /// Create a render pipeline state with reflection asynchronously.
365    ///
366    /// C++ equivalent: `void newRenderPipelineState(const RenderPipelineDescriptor*, PipelineOption, NewRenderPipelineStateWithReflectionCompletionHandler)`
367    ///
368    /// The completion handler is called with the pipeline state, reflection data, and any error.
369    pub fn new_render_pipeline_state_with_reflection_async<F>(
370        &self,
371        descriptor: &RenderPipelineDescriptor,
372        options: crate::enums::PipelineOption,
373        completion_handler: F,
374    ) where
375        F: Fn(
376                Option<RenderPipelineState>,
377                Option<RenderPipelineReflection>,
378                Option<mtl_foundation::Error>,
379            ) + Send
380            + 'static,
381    {
382        let block = mtl_sys::ThreeArgBlock::from_fn(
383            move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
384                let state = if state_ptr.is_null() {
385                    None
386                } else {
387                    unsafe { RenderPipelineState::from_raw(state_ptr) }
388                };
389
390                let reflection = if reflection_ptr.is_null() {
391                    None
392                } else {
393                    unsafe { RenderPipelineReflection::from_raw(reflection_ptr) }
394                };
395
396                let error = if err_ptr.is_null() {
397                    None
398                } else {
399                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
400                };
401
402                completion_handler(state, reflection, error);
403            },
404        );
405
406        unsafe {
407            msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
408                self.as_ptr(),
409                sel!(newRenderPipelineStateWithDescriptor:options:completionHandler:),
410                descriptor.as_ptr(),
411                options,
412                block.as_ptr(),
413            );
414        }
415
416        std::mem::forget(block);
417    }
418
419    /// Create a tile render pipeline state with reflection asynchronously.
420    ///
421    /// C++ equivalent: `void newRenderPipelineState(const TileRenderPipelineDescriptor*, PipelineOption, NewRenderPipelineStateWithReflectionCompletionHandler)`
422    pub fn new_tile_render_pipeline_state_with_reflection_async<F>(
423        &self,
424        descriptor: &TileRenderPipelineDescriptor,
425        options: crate::enums::PipelineOption,
426        completion_handler: F,
427    ) where
428        F: Fn(
429                Option<RenderPipelineState>,
430                Option<RenderPipelineReflection>,
431                Option<mtl_foundation::Error>,
432            ) + Send
433            + 'static,
434    {
435        let block = mtl_sys::ThreeArgBlock::from_fn(
436            move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
437                let state = if state_ptr.is_null() {
438                    None
439                } else {
440                    unsafe { RenderPipelineState::from_raw(state_ptr) }
441                };
442
443                let reflection = if reflection_ptr.is_null() {
444                    None
445                } else {
446                    unsafe { RenderPipelineReflection::from_raw(reflection_ptr) }
447                };
448
449                let error = if err_ptr.is_null() {
450                    None
451                } else {
452                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
453                };
454
455                completion_handler(state, reflection, error);
456            },
457        );
458
459        unsafe {
460            msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
461                self.as_ptr(),
462                sel!(newRenderPipelineStateWithTileDescriptor:options:completionHandler:),
463                descriptor.as_ptr(),
464                options,
465                block.as_ptr(),
466            );
467        }
468
469        std::mem::forget(block);
470    }
471
472    /// Create a mesh render pipeline state with reflection asynchronously.
473    ///
474    /// C++ equivalent: `void newRenderPipelineState(const MeshRenderPipelineDescriptor*, PipelineOption, NewRenderPipelineStateWithReflectionCompletionHandler)`
475    pub fn new_mesh_render_pipeline_state_with_reflection_async<F>(
476        &self,
477        descriptor: &MeshRenderPipelineDescriptor,
478        options: crate::enums::PipelineOption,
479        completion_handler: F,
480    ) where
481        F: Fn(
482                Option<RenderPipelineState>,
483                Option<RenderPipelineReflection>,
484                Option<mtl_foundation::Error>,
485            ) + Send
486            + 'static,
487    {
488        let block = mtl_sys::ThreeArgBlock::from_fn(
489            move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
490                let state = if state_ptr.is_null() {
491                    None
492                } else {
493                    unsafe { RenderPipelineState::from_raw(state_ptr) }
494                };
495
496                let reflection = if reflection_ptr.is_null() {
497                    None
498                } else {
499                    unsafe { RenderPipelineReflection::from_raw(reflection_ptr) }
500                };
501
502                let error = if err_ptr.is_null() {
503                    None
504                } else {
505                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
506                };
507
508                completion_handler(state, reflection, error);
509            },
510        );
511
512        unsafe {
513            msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
514                self.as_ptr(),
515                sel!(newRenderPipelineStateWithMeshDescriptor:options:completionHandler:),
516                descriptor.as_ptr(),
517                options,
518                block.as_ptr(),
519            );
520        }
521
522        std::mem::forget(block);
523    }
524
525    // =========================================================================
526    // Async Compute Pipeline State Creation
527    // =========================================================================
528
529    /// Create a compute pipeline state from a function asynchronously.
530    ///
531    /// C++ equivalent: `void newComputePipelineState(const Function*, NewComputePipelineStateCompletionHandler)`
532    pub fn new_compute_pipeline_state_with_function_async<F>(
533        &self,
534        function: &Function,
535        completion_handler: F,
536    ) where
537        F: Fn(Option<ComputePipelineState>, Option<mtl_foundation::Error>) + Send + 'static,
538    {
539        let block =
540            mtl_sys::TwoArgBlock::from_fn(move |state_ptr: *mut c_void, err_ptr: *mut c_void| {
541                let state = if state_ptr.is_null() {
542                    None
543                } else {
544                    unsafe { ComputePipelineState::from_raw(state_ptr) }
545                };
546
547                let error = if err_ptr.is_null() {
548                    None
549                } else {
550                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
551                };
552
553                completion_handler(state, error);
554            });
555
556        unsafe {
557            msg_send_2::<(), *const c_void, *const c_void>(
558                self.as_ptr(),
559                sel!(newComputePipelineStateWithFunction:completionHandler:),
560                function.as_ptr(),
561                block.as_ptr(),
562            );
563        }
564
565        std::mem::forget(block);
566    }
567
568    /// Create a compute pipeline state with reflection asynchronously.
569    ///
570    /// C++ equivalent: `void newComputePipelineState(const Function*, PipelineOption, NewComputePipelineStateWithReflectionCompletionHandler)`
571    pub fn new_compute_pipeline_state_with_function_and_reflection_async<F>(
572        &self,
573        function: &Function,
574        options: crate::enums::PipelineOption,
575        completion_handler: F,
576    ) where
577        F: Fn(
578                Option<ComputePipelineState>,
579                Option<ComputePipelineReflection>,
580                Option<mtl_foundation::Error>,
581            ) + Send
582            + 'static,
583    {
584        let block = mtl_sys::ThreeArgBlock::from_fn(
585            move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
586                let state = if state_ptr.is_null() {
587                    None
588                } else {
589                    unsafe { ComputePipelineState::from_raw(state_ptr) }
590                };
591
592                let reflection = if reflection_ptr.is_null() {
593                    None
594                } else {
595                    unsafe { ComputePipelineReflection::from_raw(reflection_ptr) }
596                };
597
598                let error = if err_ptr.is_null() {
599                    None
600                } else {
601                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
602                };
603
604                completion_handler(state, reflection, error);
605            },
606        );
607
608        unsafe {
609            msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
610                self.as_ptr(),
611                sel!(newComputePipelineStateWithFunction:options:completionHandler:),
612                function.as_ptr(),
613                options,
614                block.as_ptr(),
615            );
616        }
617
618        std::mem::forget(block);
619    }
620
621    /// Create a compute pipeline state from a descriptor asynchronously.
622    ///
623    /// C++ equivalent: `void newComputePipelineState(const ComputePipelineDescriptor*, PipelineOption, NewComputePipelineStateWithReflectionCompletionHandler)`
624    pub fn new_compute_pipeline_state_with_descriptor_async<F>(
625        &self,
626        descriptor: &ComputePipelineDescriptor,
627        options: crate::enums::PipelineOption,
628        completion_handler: F,
629    ) where
630        F: Fn(
631                Option<ComputePipelineState>,
632                Option<ComputePipelineReflection>,
633                Option<mtl_foundation::Error>,
634            ) + Send
635            + 'static,
636    {
637        let block = mtl_sys::ThreeArgBlock::from_fn(
638            move |state_ptr: *mut c_void, reflection_ptr: *mut c_void, err_ptr: *mut c_void| {
639                let state = if state_ptr.is_null() {
640                    None
641                } else {
642                    unsafe { ComputePipelineState::from_raw(state_ptr) }
643                };
644
645                let reflection = if reflection_ptr.is_null() {
646                    None
647                } else {
648                    unsafe { ComputePipelineReflection::from_raw(reflection_ptr) }
649                };
650
651                let error = if err_ptr.is_null() {
652                    None
653                } else {
654                    unsafe { mtl_foundation::Error::from_ptr(err_ptr) }
655                };
656
657                completion_handler(state, reflection, error);
658            },
659        );
660
661        unsafe {
662            msg_send_3::<(), *const c_void, crate::enums::PipelineOption, *const c_void>(
663                self.as_ptr(),
664                sel!(newComputePipelineStateWithDescriptor:options:completionHandler:),
665                descriptor.as_ptr(),
666                options,
667                block.as_ptr(),
668            );
669        }
670
671        std::mem::forget(block);
672    }
673}
674
#[cfg(test)]
mod tests {
    use crate::device::system_default;

    /// Compiles a trivial kernel, builds a compute pipeline from it, and checks
    /// that the pipeline reports non-zero threadgroup limits.
    #[test]
    fn test_new_compute_pipeline_state() {
        let device = system_default().expect("no Metal device");

        let source = r#"
            #include <metal_stdlib>
            using namespace metal;

            kernel void test_kernel(device float* data [[buffer(0)]],
                                   uint id [[thread_position_in_grid]]) {
                data[id] = data[id] * 2.0;
            }
        "#;

        let library = device
            .new_library_with_source(source, None)
            .expect("failed to compile shader");

        let function = library
            .new_function_with_name("test_kernel")
            .expect("function not found");

        // `expect` keeps Metal's error text in the failure output, which
        // `assert!(result.is_ok())` would discard.
        let pipeline = device
            .new_compute_pipeline_state_with_function(&function)
            .expect("failed to create compute pipeline state");

        assert!(pipeline.max_total_threads_per_threadgroup() > 0);
        assert!(pipeline.thread_execution_width() > 0);
    }
}