Skip to main content

mtl_gpu/mtl4/
pipeline_state.rs

1//! MTL4 PipelineState implementation.
2//!
3//! Corresponds to `Metal/MTL4PipelineState.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11use super::enums::ShaderReflection;
12use crate::ShaderValidation;
13
14// ============================================================
15// PipelineOptions
16// ============================================================
17
/// Configuration object passed to MTL4 pipeline creation.
///
/// C++ equivalent: `MTL4::PipelineOptions`
///
/// Exposes the shader reflection and shader validation settings that are
/// applied when a pipeline is compiled.
///
/// The wrapper stores a single non-null pointer to the underlying
/// Objective-C object and is `#[repr(transparent)]`, so it has the same
/// layout as that pointer.
#[repr(transparent)]
pub struct PipelineOptions(NonNull<c_void>);
26
27impl PipelineOptions {
28    /// Create a PipelineOptions from a raw pointer.
29    #[inline]
30    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
31        NonNull::new(ptr).map(Self)
32    }
33
34    /// Get the raw pointer.
35    #[inline]
36    pub fn as_raw(&self) -> *mut c_void {
37        self.0.as_ptr()
38    }
39
40    /// Create new pipeline options.
41    pub fn new() -> Option<Self> {
42        unsafe {
43            let class = mtl_sys::Class::get("MTL4PipelineOptions")?;
44            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
45            if ptr.is_null() {
46                return None;
47            }
48            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
49            Self::from_raw(ptr)
50        }
51    }
52
53    /// Get the shader reflection options.
54    ///
55    /// C++ equivalent: `ShaderReflection shaderReflection() const`
56    pub fn shader_reflection(&self) -> ShaderReflection {
57        unsafe { msg_send_0(self.as_ptr(), sel!(shaderReflection)) }
58    }
59
60    /// Set the shader reflection options.
61    ///
62    /// C++ equivalent: `void setShaderReflection(MTL4::ShaderReflection)`
63    pub fn set_shader_reflection(&self, reflection: ShaderReflection) {
64        unsafe {
65            let _: () = msg_send_1(self.as_ptr(), sel!(setShaderReflection:), reflection);
66        }
67    }
68
69    /// Get the shader validation setting.
70    ///
71    /// C++ equivalent: `MTL::ShaderValidation shaderValidation() const`
72    pub fn shader_validation(&self) -> ShaderValidation {
73        unsafe { msg_send_0(self.as_ptr(), sel!(shaderValidation)) }
74    }
75
76    /// Set the shader validation setting.
77    ///
78    /// C++ equivalent: `void setShaderValidation(MTL::ShaderValidation)`
79    pub fn set_shader_validation(&self, validation: ShaderValidation) {
80        unsafe {
81            let _: () = msg_send_1(self.as_ptr(), sel!(setShaderValidation:), validation);
82        }
83    }
84}
85
86impl Clone for PipelineOptions {
87    fn clone(&self) -> Self {
88        unsafe {
89            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
90        }
91        Self(self.0)
92    }
93}
94
95impl Drop for PipelineOptions {
96    fn drop(&mut self) {
97        unsafe {
98            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
99        }
100    }
101}
102
103impl Referencing for PipelineOptions {
104    #[inline]
105    fn as_ptr(&self) -> *const c_void {
106        self.0.as_ptr()
107    }
108}
109
110unsafe impl Send for PipelineOptions {}
111unsafe impl Sync for PipelineOptions {}
112
113impl std::fmt::Debug for PipelineOptions {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("PipelineOptions")
116            .field("shader_reflection", &self.shader_reflection())
117            .finish()
118    }
119}
120
121// ============================================================
122// PipelineDescriptor
123// ============================================================
124
/// Common base descriptor for MTL4 pipelines.
///
/// C++ equivalent: `MTL4::PipelineDescriptor`
///
/// All MTL4 pipeline descriptors derive from this base class.
///
/// The wrapper stores a single non-null pointer to the underlying
/// Objective-C object and is `#[repr(transparent)]`, so it has the same
/// layout as that pointer.
#[repr(transparent)]
pub struct PipelineDescriptor(NonNull<c_void>);
132
133impl PipelineDescriptor {
134    /// Create a PipelineDescriptor from a raw pointer.
135    #[inline]
136    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
137        NonNull::new(ptr).map(Self)
138    }
139
140    /// Get the raw pointer.
141    #[inline]
142    pub fn as_raw(&self) -> *mut c_void {
143        self.0.as_ptr()
144    }
145
146    /// Create a new pipeline descriptor.
147    pub fn new() -> Option<Self> {
148        unsafe {
149            let class = mtl_sys::Class::get("MTL4PipelineDescriptor")?;
150            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
151            if ptr.is_null() {
152                return None;
153            }
154            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
155            Self::from_raw(ptr)
156        }
157    }
158
159    /// Get the label.
160    ///
161    /// C++ equivalent: `NS::String* label() const`
162    pub fn label(&self) -> Option<String> {
163        unsafe {
164            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
165            if ns_string.is_null() {
166                return None;
167            }
168            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
169            if c_str.is_null() {
170                return None;
171            }
172            Some(
173                std::ffi::CStr::from_ptr(c_str)
174                    .to_string_lossy()
175                    .into_owned(),
176            )
177        }
178    }
179
180    /// Set the label.
181    ///
182    /// C++ equivalent: `void setLabel(const NS::String*)`
183    pub fn set_label(&self, label: &str) {
184        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
185            unsafe {
186                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
187            }
188        }
189    }
190
191    /// Get the pipeline options.
192    ///
193    /// C++ equivalent: `PipelineOptions* options() const`
194    pub fn options(&self) -> Option<PipelineOptions> {
195        unsafe {
196            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(options));
197            PipelineOptions::from_raw(ptr)
198        }
199    }
200
201    /// Set the pipeline options.
202    ///
203    /// C++ equivalent: `void setOptions(const MTL4::PipelineOptions*)`
204    pub fn set_options(&self, options: &PipelineOptions) {
205        unsafe {
206            let _: () = msg_send_1(self.as_ptr(), sel!(setOptions:), options.as_ptr());
207        }
208    }
209}
210
211impl Clone for PipelineDescriptor {
212    fn clone(&self) -> Self {
213        unsafe {
214            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
215        }
216        Self(self.0)
217    }
218}
219
220impl Drop for PipelineDescriptor {
221    fn drop(&mut self) {
222        unsafe {
223            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
224        }
225    }
226}
227
228impl Referencing for PipelineDescriptor {
229    #[inline]
230    fn as_ptr(&self) -> *const c_void {
231        self.0.as_ptr()
232    }
233}
234
235unsafe impl Send for PipelineDescriptor {}
236unsafe impl Sync for PipelineDescriptor {}
237
238impl std::fmt::Debug for PipelineDescriptor {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        f.debug_struct("PipelineDescriptor")
241            .field("label", &self.label())
242            .finish()
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    /// `PipelineOptions` must stay exactly one pointer wide.
    #[test]
    fn test_pipeline_options_size() {
        let pointer_size = size_of::<*mut c_void>();
        assert_eq!(size_of::<PipelineOptions>(), pointer_size);
    }

    /// `PipelineDescriptor` must stay exactly one pointer wide.
    #[test]
    fn test_pipeline_descriptor_size() {
        let pointer_size = size_of::<*mut c_void>();
        assert_eq!(size_of::<PipelineDescriptor>(), pointer_size);
    }
}