Skip to main content

mtl_gpu/mtl4/
command_encoder.rs

//! MTL4 CommandEncoder implementation.
//!
//! Corresponds to `Metal/MTL4CommandEncoder.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_3, sel};
10
11use super::enums::VisibilityOptions;
12use crate::Device;
13
// ============================================================
// CommandEncoder
// ============================================================
17
/// Shared base type for every Metal 4 command encoder.
///
/// C++ equivalent: `MTL4::CommandEncoder`
///
/// Exposes the operations common to all MTL4 encoders: labelling, resource
/// barriers, fence updates and waits, debug groups and signposts, and ending
/// the encoding pass. The handle itself is a single non-null pointer to the
/// underlying Objective-C object.
#[repr(transparent)]
pub struct CommandEncoder(NonNull<c_void>);
26
27impl CommandEncoder {
28    /// Create a CommandEncoder from a raw pointer.
29    #[inline]
30    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
31        NonNull::new(ptr).map(Self)
32    }
33
34    /// Get the raw pointer.
35    #[inline]
36    pub fn as_raw(&self) -> *mut c_void {
37        self.0.as_ptr()
38    }
39
40    /// Get the device.
41    ///
42    /// C++ equivalent: `MTL::Device* device() const`
43    pub fn device(&self) -> Option<Device> {
44        unsafe {
45            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
46            Device::from_raw(ptr)
47        }
48    }
49
50    /// Get the label.
51    ///
52    /// C++ equivalent: `NS::String* label() const`
53    pub fn label(&self) -> Option<String> {
54        unsafe {
55            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
56            if ns_string.is_null() {
57                return None;
58            }
59            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
60            if c_str.is_null() {
61                return None;
62            }
63            Some(
64                std::ffi::CStr::from_ptr(c_str)
65                    .to_string_lossy()
66                    .into_owned(),
67            )
68        }
69    }
70
71    /// Set the label.
72    ///
73    /// C++ equivalent: `void setLabel(const NS::String*)`
74    pub fn set_label(&self, label: &str) {
75        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
76            unsafe {
77                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
78            }
79        }
80    }
81
82    // ========== Barrier Methods ==========
83
84    /// Insert a barrier for all resources.
85    ///
86    /// C++ equivalent: `void barrier()`
87    pub fn barrier(&self) {
88        unsafe {
89            let _: () = msg_send_0(self.as_ptr(), sel!(barrier));
90        }
91    }
92
93    /// Insert a barrier for a specific buffer.
94    ///
95    /// C++ equivalent: `void barrier(const MTL::Buffer*, MTL4::VisibilityOptions)`
96    pub fn barrier_buffer(&self, buffer: *const c_void, visibility: VisibilityOptions) {
97        unsafe {
98            let _: () = msg_send_2(
99                self.as_ptr(),
100                sel!(barrierWithBuffer:visibilityOptions:),
101                buffer,
102                visibility.0,
103            );
104        }
105    }
106
107    /// Insert a barrier for multiple buffers.
108    ///
109    /// C++ equivalent: `void barrier(const MTL::Buffer* const*, NS::UInteger, MTL4::VisibilityOptions)`
110    pub fn barrier_buffers(
111        &self,
112        buffers: *const *const c_void,
113        count: UInteger,
114        visibility: VisibilityOptions,
115    ) {
116        unsafe {
117            let _: () = msg_send_3(
118                self.as_ptr(),
119                sel!(barrierWithBuffers:count:visibilityOptions:),
120                buffers,
121                count,
122                visibility.0,
123            );
124        }
125    }
126
127    /// Insert a barrier for a specific texture.
128    ///
129    /// C++ equivalent: `void barrier(const MTL::Texture*, MTL4::VisibilityOptions)`
130    pub fn barrier_texture(&self, texture: *const c_void, visibility: VisibilityOptions) {
131        unsafe {
132            let _: () = msg_send_2(
133                self.as_ptr(),
134                sel!(barrierWithTexture:visibilityOptions:),
135                texture,
136                visibility.0,
137            );
138        }
139    }
140
141    /// Insert a barrier for multiple textures.
142    ///
143    /// C++ equivalent: `void barrier(const MTL::Texture* const*, NS::UInteger, MTL4::VisibilityOptions)`
144    pub fn barrier_textures(
145        &self,
146        textures: *const *const c_void,
147        count: UInteger,
148        visibility: VisibilityOptions,
149    ) {
150        unsafe {
151            let _: () = msg_send_3(
152                self.as_ptr(),
153                sel!(barrierWithTextures:count:visibilityOptions:),
154                textures,
155                count,
156                visibility.0,
157            );
158        }
159    }
160
161    // ========== Fence Methods ==========
162
163    /// Update a fence.
164    ///
165    /// C++ equivalent: `void updateFence(const MTL::Fence*)`
166    pub fn update_fence(&self, fence: *const c_void) {
167        unsafe {
168            let _: () = msg_send_1(self.as_ptr(), sel!(updateFence:), fence);
169        }
170    }
171
172    /// Wait for a fence.
173    ///
174    /// C++ equivalent: `void waitForFence(const MTL::Fence*)`
175    pub fn wait_for_fence(&self, fence: *const c_void) {
176        unsafe {
177            let _: () = msg_send_1(self.as_ptr(), sel!(waitForFence:), fence);
178        }
179    }
180
181    // ========== Debug Methods ==========
182
183    /// Push a debug group.
184    ///
185    /// C++ equivalent: `void pushDebugGroup(const NS::String*)`
186    pub fn push_debug_group(&self, name: &str) {
187        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
188            unsafe {
189                let _: () = msg_send_1(self.as_ptr(), sel!(pushDebugGroup:), ns_name.as_ptr());
190            }
191        }
192    }
193
194    /// Pop a debug group.
195    ///
196    /// C++ equivalent: `void popDebugGroup()`
197    pub fn pop_debug_group(&self) {
198        unsafe {
199            let _: () = msg_send_0(self.as_ptr(), sel!(popDebugGroup));
200        }
201    }
202
203    /// Insert a debug signpost.
204    ///
205    /// C++ equivalent: `void insertDebugSignpost(const NS::String*)`
206    pub fn insert_debug_signpost(&self, name: &str) {
207        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
208            unsafe {
209                let _: () = msg_send_1(self.as_ptr(), sel!(insertDebugSignpost:), ns_name.as_ptr());
210            }
211        }
212    }
213
214    // ========== Encoding ==========
215
216    /// End encoding.
217    ///
218    /// C++ equivalent: `void endEncoding()`
219    pub fn end_encoding(&self) {
220        unsafe {
221            let _: () = msg_send_0(self.as_ptr(), sel!(endEncoding));
222        }
223    }
224}
225
226impl Clone for CommandEncoder {
227    fn clone(&self) -> Self {
228        unsafe {
229            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
230        }
231        Self(self.0)
232    }
233}
234
235impl Drop for CommandEncoder {
236    fn drop(&mut self) {
237        unsafe {
238            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
239        }
240    }
241}
242
243impl Referencing for CommandEncoder {
244    #[inline]
245    fn as_ptr(&self) -> *const c_void {
246        self.0.as_ptr()
247    }
248}
249
250unsafe impl Send for CommandEncoder {}
251unsafe impl Sync for CommandEncoder {}
252
253impl std::fmt::Debug for CommandEncoder {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        f.debug_struct("CommandEncoder")
256            .field("label", &self.label())
257            .finish()
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    /// `#[repr(transparent)]` must keep the wrapper exactly as wide as the raw
    /// object pointer it wraps.
    #[test]
    fn test_command_encoder_size() {
        let wrapper_width = size_of::<CommandEncoder>();
        let pointer_width = size_of::<*mut c_void>();
        assert_eq!(wrapper_width, pointer_width);
    }
}