Skip to main content

mtl_gpu/io/
command_buffer.rs

1//! IO command buffer for Metal.
2
3use std::ffi::c_void;
4use std::ptr::NonNull;
5
6use mtl_foundation::{Referencing, UInteger};
7use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_5, msg_send_9, sel};
8
9use crate::buffer::Buffer;
10use crate::enums::IOStatus;
11use crate::sync::SharedEvent;
12use crate::texture::Texture;
13use crate::types::{Origin, Size};
14
15use super::IOFileHandle;
16
/// Handle to a Metal IO command buffer object.
///
/// The wrapper stores a non-null pointer to the underlying Objective-C
/// object. Cloning a handle sends `retain` to the object and dropping a
/// handle sends `release`, so every live value accounts for one reference.
///
/// C++ equivalent: `MTL::IOCommandBuffer`
#[repr(transparent)]
pub struct IOCommandBuffer(pub(crate) NonNull<c_void>);
22
23impl IOCommandBuffer {
24    /// Create from a raw pointer.
25    ///
26    /// # Safety
27    ///
28    /// The pointer must be a valid Metal IO command buffer.
29    #[inline]
30    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
31        NonNull::new(ptr).map(Self)
32    }
33
34    /// Get the raw pointer.
35    #[inline]
36    pub fn as_raw(&self) -> *mut c_void {
37        self.0.as_ptr()
38    }
39
40    /// Add a barrier.
41    ///
42    /// C++ equivalent: `void addBarrier()`
43    #[inline]
44    pub fn add_barrier(&self) {
45        unsafe {
46            msg_send_0::<()>(self.as_ptr(), sel!(addBarrier));
47        }
48    }
49
50    /// Commit the command buffer.
51    ///
52    /// C++ equivalent: `void commit()`
53    #[inline]
54    pub fn commit(&self) {
55        unsafe {
56            msg_send_0::<()>(self.as_ptr(), sel!(commit));
57        }
58    }
59
60    /// Copy the status to a buffer.
61    ///
62    /// C++ equivalent: `void copyStatusToBuffer(const Buffer*, NS::UInteger)`
63    pub fn copy_status_to_buffer(&self, buffer: &Buffer, offset: UInteger) {
64        unsafe {
65            msg_send_2::<(), *const c_void, UInteger>(
66                self.as_ptr(),
67                sel!(copyStatusToBuffer:offset:),
68                buffer.as_ptr(),
69                offset,
70            );
71        }
72    }
73
74    /// Enqueue the command buffer.
75    ///
76    /// C++ equivalent: `void enqueue()`
77    #[inline]
78    pub fn enqueue(&self) {
79        unsafe {
80            msg_send_0::<()>(self.as_ptr(), sel!(enqueue));
81        }
82    }
83
84    /// Get the error, if any.
85    ///
86    /// C++ equivalent: `NS::Error* error() const`
87    pub fn error(&self) -> Option<mtl_foundation::Error> {
88        unsafe {
89            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(error));
90            mtl_foundation::Error::from_ptr(ptr)
91        }
92    }
93
94    /// Get the label.
95    ///
96    /// C++ equivalent: `NS::String* label() const`
97    pub fn label(&self) -> Option<String> {
98        unsafe {
99            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
100            if ptr.is_null() {
101                return None;
102            }
103            let utf8_ptr: *const std::ffi::c_char =
104                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
105            if utf8_ptr.is_null() {
106                return None;
107            }
108            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
109            Some(c_str.to_string_lossy().into_owned())
110        }
111    }
112
113    /// Set the label.
114    ///
115    /// C++ equivalent: `void setLabel(const NS::String*)`
116    pub fn set_label(&self, label: &str) {
117        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
118            unsafe {
119                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
120            }
121        }
122    }
123
124    /// Load data from a file into a buffer.
125    ///
126    /// C++ equivalent: `void loadBuffer(const Buffer*, NS::UInteger, NS::UInteger, const IOFileHandle*, NS::UInteger)`
127    pub fn load_buffer(
128        &self,
129        buffer: &Buffer,
130        offset: UInteger,
131        size: UInteger,
132        source_handle: &IOFileHandle,
133        source_handle_offset: UInteger,
134    ) {
135        unsafe {
136            msg_send_5::<(), *const c_void, UInteger, UInteger, *const c_void, UInteger>(
137                self.as_ptr(),
138                sel!(loadBuffer:offset:size:sourceHandle:sourceHandleOffset:),
139                buffer.as_ptr(),
140                offset,
141                size,
142                source_handle.as_ptr(),
143                source_handle_offset,
144            );
145        }
146    }
147
148    /// Load data from a file into memory.
149    ///
150    /// # Safety
151    ///
152    /// The pointer must point to valid memory of at least `size` bytes.
153    ///
154    /// C++ equivalent: `void loadBytes(const void*, NS::UInteger, const IOFileHandle*, NS::UInteger)`
155    pub unsafe fn load_bytes(
156        &self,
157        pointer: *mut c_void,
158        size: UInteger,
159        source_handle: &IOFileHandle,
160        source_handle_offset: UInteger,
161    ) {
162        unsafe {
163            mtl_sys::msg_send_4::<(), *mut c_void, UInteger, *const c_void, UInteger>(
164                self.as_ptr(),
165                sel!(loadBytes:size:sourceHandle:sourceHandleOffset:),
166                pointer,
167                size,
168                source_handle.as_ptr(),
169                source_handle_offset,
170            );
171        }
172    }
173
174    /// Load data from a file into a texture.
175    ///
176    /// C++ equivalent: `void loadTexture(const Texture*, NS::UInteger, NS::UInteger, Size, NS::UInteger, NS::UInteger, Origin, const IOFileHandle*, NS::UInteger)`
177    #[allow(clippy::too_many_arguments)]
178    pub fn load_texture(
179        &self,
180        texture: &Texture,
181        slice: UInteger,
182        level: UInteger,
183        size: Size,
184        source_bytes_per_row: UInteger,
185        source_bytes_per_image: UInteger,
186        destination_origin: Origin,
187        source_handle: &IOFileHandle,
188        source_handle_offset: UInteger,
189    ) {
190        unsafe {
191            msg_send_9::<
192                (),
193                *const c_void,
194                UInteger,
195                UInteger,
196                Size,
197                UInteger,
198                UInteger,
199                Origin,
200                *const c_void,
201                UInteger,
202            >(
203                self.as_ptr(),
204                sel!(loadTexture:slice:level:size:sourceBytesPerRow:sourceBytesPerImage:destinationOrigin:sourceHandle:sourceHandleOffset:),
205                texture.as_ptr(),
206                slice,
207                level,
208                size,
209                source_bytes_per_row,
210                source_bytes_per_image,
211                destination_origin,
212                source_handle.as_ptr(),
213                source_handle_offset,
214            );
215        }
216    }
217
218    /// Pop a debug group.
219    ///
220    /// C++ equivalent: `void popDebugGroup()`
221    #[inline]
222    pub fn pop_debug_group(&self) {
223        unsafe {
224            msg_send_0::<()>(self.as_ptr(), sel!(popDebugGroup));
225        }
226    }
227
228    /// Push a debug group.
229    ///
230    /// C++ equivalent: `void pushDebugGroup(const NS::String*)`
231    pub fn push_debug_group(&self, name: &str) {
232        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
233            unsafe {
234                msg_send_1::<(), *const c_void>(
235                    self.as_ptr(),
236                    sel!(pushDebugGroup:),
237                    ns_name.as_ptr(),
238                );
239            }
240        }
241    }
242
243    /// Signal a shared event.
244    ///
245    /// C++ equivalent: `void signalEvent(const SharedEvent*, uint64_t)`
246    pub fn signal_event(&self, event: &SharedEvent, value: u64) {
247        unsafe {
248            msg_send_2::<(), *const c_void, u64>(
249                self.as_ptr(),
250                sel!(signalEvent:value:),
251                event.as_ptr(),
252                value,
253            );
254        }
255    }
256
257    /// Get the status.
258    ///
259    /// C++ equivalent: `IOStatus status() const`
260    #[inline]
261    pub fn status(&self) -> IOStatus {
262        unsafe { msg_send_0(self.as_ptr(), sel!(status)) }
263    }
264
265    /// Try to cancel the command buffer.
266    ///
267    /// C++ equivalent: `void tryCancel()`
268    #[inline]
269    pub fn try_cancel(&self) {
270        unsafe {
271            msg_send_0::<()>(self.as_ptr(), sel!(tryCancel));
272        }
273    }
274
275    /// Wait for a shared event to reach a value.
276    ///
277    /// C++ equivalent: `void wait(const SharedEvent*, uint64_t)`
278    pub fn wait(&self, event: &SharedEvent, value: u64) {
279        unsafe {
280            msg_send_2::<(), *const c_void, u64>(
281                self.as_ptr(),
282                sel!(waitForEvent:value:),
283                event.as_ptr(),
284                value,
285            );
286        }
287    }
288
289    /// Wait until the command buffer completes.
290    ///
291    /// C++ equivalent: `void waitUntilCompleted()`
292    #[inline]
293    pub fn wait_until_completed(&self) {
294        unsafe {
295            msg_send_0::<()>(self.as_ptr(), sel!(waitUntilCompleted));
296        }
297    }
298
299    /// Add a completed handler (raw block pointer).
300    ///
301    /// # Safety
302    ///
303    /// The block pointer must be a valid Objective-C block.
304    ///
305    /// C++ equivalent: `void addCompletedHandler(const IOCommandBufferHandler)`
306    pub unsafe fn add_completed_handler_ptr(&self, block: *const c_void) {
307        unsafe {
308            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(addCompletedHandler:), block);
309        }
310    }
311
312    /// Add a handler to be called when the IO command buffer completes.
313    ///
314    /// C++ equivalent: `void addCompletedHandler(const IOCommandBufferHandler)`
315    ///
316    /// The handler is called with a reference to the completed IO command buffer.
317    pub fn add_completed_handler<F>(&self, handler: F)
318    where
319        F: Fn(&IOCommandBuffer) + Send + 'static,
320    {
321        let block = mtl_sys::OneArgBlock::from_fn(move |cmd_buf: *mut c_void| {
322            unsafe {
323                if let Some(buf) = IOCommandBuffer::from_raw(cmd_buf) {
324                    handler(&buf);
325                    // Don't drop - Metal owns this reference
326                    std::mem::forget(buf);
327                }
328            }
329        });
330
331        unsafe {
332            msg_send_1::<(), *const c_void>(
333                self.as_ptr(),
334                sel!(addCompletedHandler:),
335                block.as_ptr(),
336            );
337        }
338
339        // The block is retained by Metal
340        std::mem::forget(block);
341    }
342}
343
344impl Clone for IOCommandBuffer {
345    fn clone(&self) -> Self {
346        unsafe {
347            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
348        }
349        Self(self.0)
350    }
351}
352
353impl Drop for IOCommandBuffer {
354    fn drop(&mut self) {
355        unsafe {
356            msg_send_0::<()>(self.as_ptr(), sel!(release));
357        }
358    }
359}
360
361impl Referencing for IOCommandBuffer {
362    #[inline]
363    fn as_ptr(&self) -> *const c_void {
364        self.0.as_ptr()
365    }
366}
367
368unsafe impl Send for IOCommandBuffer {}
369unsafe impl Sync for IOCommandBuffer {}