Skip to main content

mtl_gpu/mtl4/
command_buffer.rs

1//! MTL4 CommandBuffer implementation.
2//!
3//! Corresponds to `Metal/MTL4CommandBuffer.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_5, sel};
10
11use super::acceleration_structure::BufferRange;
12
13use super::CommandAllocator;
14use crate::{Device, ResidencySet};
15
16// ============================================================
17// CommandBufferOptions
18// ============================================================
19
/// Configuration object supplied when beginning MTL4 command buffer recording.
///
/// Holds a non-null pointer to an Objective-C `MTL4CommandBufferOptions` instance.
/// The wrapper owns one strong reference, which is released when it is dropped.
///
/// C++ equivalent: `MTL4::CommandBufferOptions`
#[repr(transparent)]
pub struct CommandBufferOptions(std::ptr::NonNull<std::ffi::c_void>);
25
26impl CommandBufferOptions {
27    /// Create a CommandBufferOptions from a raw pointer.
28    #[inline]
29    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
30        NonNull::new(ptr).map(Self)
31    }
32
33    /// Get the raw pointer.
34    #[inline]
35    pub fn as_raw(&self) -> *mut c_void {
36        self.0.as_ptr()
37    }
38
39    /// Create new command buffer options.
40    pub fn new() -> Option<Self> {
41        unsafe {
42            let class = mtl_sys::Class::get("MTL4CommandBufferOptions")?;
43            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
44            if ptr.is_null() {
45                return None;
46            }
47            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
48            Self::from_raw(ptr)
49        }
50    }
51
52    /// Get the log state.
53    ///
54    /// C++ equivalent: `MTL::LogState* logState() const`
55    pub fn log_state(&self) -> Option<crate::log_state::LogState> {
56        unsafe {
57            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(logState));
58            crate::log_state::LogState::from_raw(ptr)
59        }
60    }
61
62    /// Set the log state.
63    ///
64    /// C++ equivalent: `void setLogState(MTL::LogState*)`
65    pub fn set_log_state(&self, log_state: &crate::log_state::LogState) {
66        unsafe {
67            let _: () = msg_send_1(self.as_ptr(), sel!(setLogState:), log_state.as_ptr());
68        }
69    }
70}
71
72impl Clone for CommandBufferOptions {
73    fn clone(&self) -> Self {
74        unsafe {
75            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
76        }
77        Self(self.0)
78    }
79}
80
81impl Drop for CommandBufferOptions {
82    fn drop(&mut self) {
83        unsafe {
84            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
85        }
86    }
87}
88
89impl Referencing for CommandBufferOptions {
90    #[inline]
91    fn as_ptr(&self) -> *const c_void {
92        self.0.as_ptr()
93    }
94}
95
96unsafe impl Send for CommandBufferOptions {}
97unsafe impl Sync for CommandBufferOptions {}
98
99// ============================================================
100// CommandBuffer
101// ============================================================
102
/// Metal 4 command buffer used to record GPU work.
///
/// Recording is bracketed by `begin_command_buffer*` / `end_command_buffer`, with
/// command memory supplied by a [`CommandAllocator`]. The wrapper owns one strong
/// reference to the underlying Objective-C object and releases it on drop.
///
/// C++ equivalent: `MTL4::CommandBuffer`
#[repr(transparent)]
pub struct CommandBuffer(std::ptr::NonNull<std::ffi::c_void>);
111
112impl CommandBuffer {
113    /// Create a CommandBuffer from a raw pointer.
114    #[inline]
115    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
116        NonNull::new(ptr).map(Self)
117    }
118
119    /// Get the raw pointer.
120    #[inline]
121    pub fn as_raw(&self) -> *mut c_void {
122        self.0.as_ptr()
123    }
124
125    /// Get the device.
126    ///
127    /// C++ equivalent: `MTL::Device* device() const`
128    pub fn device(&self) -> Option<Device> {
129        unsafe {
130            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
131            Device::from_raw(ptr)
132        }
133    }
134
135    /// Get the label.
136    ///
137    /// C++ equivalent: `NS::String* label() const`
138    pub fn label(&self) -> Option<String> {
139        unsafe {
140            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
141            if ns_string.is_null() {
142                return None;
143            }
144            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
145            if c_str.is_null() {
146                return None;
147            }
148            Some(
149                std::ffi::CStr::from_ptr(c_str)
150                    .to_string_lossy()
151                    .into_owned(),
152            )
153        }
154    }
155
156    /// Set the label.
157    ///
158    /// C++ equivalent: `void setLabel(const NS::String* label)`
159    pub fn set_label(&self, label: &str) {
160        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
161            unsafe {
162                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
163            }
164        }
165    }
166
167    /// Begin recording commands with an allocator.
168    ///
169    /// C++ equivalent: `void beginCommandBuffer(const MTL4::CommandAllocator* allocator)`
170    pub fn begin_command_buffer(&self, allocator: &CommandAllocator) {
171        unsafe {
172            let _: () = msg_send_1(
173                self.as_ptr(),
174                sel!(beginCommandBufferWithAllocator:),
175                allocator.as_ptr(),
176            );
177        }
178    }
179
180    /// Begin recording commands with an allocator and options.
181    ///
182    /// C++ equivalent: `void beginCommandBuffer(const MTL4::CommandAllocator*, const MTL4::CommandBufferOptions*)`
183    pub fn begin_command_buffer_with_options(
184        &self,
185        allocator: &CommandAllocator,
186        options: &CommandBufferOptions,
187    ) {
188        unsafe {
189            let _: () = msg_send_2(
190                self.as_ptr(),
191                sel!(beginCommandBufferWithAllocator:options:),
192                allocator.as_ptr(),
193                options.as_ptr(),
194            );
195        }
196    }
197
198    /// End recording commands.
199    ///
200    /// C++ equivalent: `void endCommandBuffer()`
201    pub fn end_command_buffer(&self) {
202        unsafe {
203            let _: () = msg_send_0(self.as_ptr(), sel!(endCommandBuffer));
204        }
205    }
206
207    /// Use a residency set.
208    ///
209    /// C++ equivalent: `void useResidencySet(const MTL::ResidencySet* residencySet)`
210    pub fn use_residency_set(&self, residency_set: &ResidencySet) {
211        unsafe {
212            let _: () = msg_send_1(
213                self.as_ptr(),
214                sel!(useResidencySet:),
215                residency_set.as_ptr(),
216            );
217        }
218    }
219
220    /// Use multiple residency sets.
221    ///
222    /// C++ equivalent: `void useResidencySets(const MTL::ResidencySet* const[], NS::UInteger count)`
223    pub fn use_residency_sets(&self, residency_sets: &[&ResidencySet]) {
224        let ptrs: Vec<*const c_void> = residency_sets.iter().map(|r| r.as_ptr()).collect();
225        unsafe {
226            let _: () = msg_send_2(
227                self.as_ptr(),
228                sel!(useResidencySets:count:),
229                ptrs.as_ptr(),
230                ptrs.len() as UInteger,
231            );
232        }
233    }
234
235    /// Push a debug group.
236    ///
237    /// C++ equivalent: `void pushDebugGroup(const NS::String* string)`
238    pub fn push_debug_group(&self, name: &str) {
239        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
240            unsafe {
241                let _: () = msg_send_1(self.as_ptr(), sel!(pushDebugGroup:), ns_name.as_ptr());
242            }
243        }
244    }
245
246    /// Pop the current debug group.
247    ///
248    /// C++ equivalent: `void popDebugGroup()`
249    pub fn pop_debug_group(&self) {
250        unsafe {
251            let _: () = msg_send_0(self.as_ptr(), sel!(popDebugGroup));
252        }
253    }
254
255    // =========================================================================
256    // Command Encoder Creation
257    // =========================================================================
258
259    /// Create a compute command encoder.
260    ///
261    /// C++ equivalent: `ComputeCommandEncoder* computeCommandEncoder()`
262    pub fn compute_command_encoder(&self) -> Option<super::ComputeCommandEncoder> {
263        unsafe {
264            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(computeCommandEncoder));
265            if ptr.is_null() {
266                None
267            } else {
268                super::ComputeCommandEncoder::from_raw(ptr)
269            }
270        }
271    }
272
273    /// Create a render command encoder with the specified render pass descriptor.
274    ///
275    /// C++ equivalent: `RenderCommandEncoder* renderCommandEncoder(const RenderPassDescriptor*)`
276    pub fn render_command_encoder(
277        &self,
278        descriptor: &super::RenderPassDescriptor,
279    ) -> Option<super::RenderCommandEncoder> {
280        unsafe {
281            let ptr: *mut c_void = msg_send_1(
282                self.as_ptr(),
283                sel!(renderCommandEncoderWithDescriptor:),
284                descriptor.as_ptr(),
285            );
286            if ptr.is_null() {
287                None
288            } else {
289                super::RenderCommandEncoder::from_raw(ptr)
290            }
291        }
292    }
293
294    /// Create a machine learning command encoder.
295    ///
296    /// C++ equivalent: `MachineLearningCommandEncoder* machineLearningCommandEncoder()`
297    pub fn machine_learning_command_encoder(&self) -> Option<super::MachineLearningCommandEncoder> {
298        unsafe {
299            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(machineLearningCommandEncoder));
300            if ptr.is_null() {
301                None
302            } else {
303                super::MachineLearningCommandEncoder::from_raw(ptr)
304            }
305        }
306    }
307
308    // =========================================================================
309    // Counter/Timestamp Methods
310    // =========================================================================
311
312    /// Resolve counter heap data into a buffer.
313    ///
314    /// C++ equivalent: `void resolveCounterHeap(const MTL4::CounterHeap*, NS::Range, const MTL4::BufferRange, const MTL::Fence*, const MTL::Fence*)`
315    pub fn resolve_counter_heap(
316        &self,
317        counter_heap: *const c_void,
318        range_location: UInteger,
319        range_length: UInteger,
320        buffer_range: BufferRange,
321        fence_to_wait: *const c_void,
322        fence_to_update: *const c_void,
323    ) {
324        unsafe {
325            let range = (range_location, range_length);
326            let _: () = msg_send_5(
327                self.as_ptr(),
328                sel!(resolveCounterHeap:withRange:intoBuffer:waitFence:updateFence:),
329                counter_heap,
330                range,
331                buffer_range,
332                fence_to_wait,
333                fence_to_update,
334            );
335        }
336    }
337
338    /// Write a timestamp into a counter heap.
339    ///
340    /// C++ equivalent: `void writeTimestampIntoHeap(const MTL4::CounterHeap*, NS::UInteger)`
341    pub fn write_timestamp_into_heap(&self, counter_heap: *const c_void, index: UInteger) {
342        unsafe {
343            let _: () = msg_send_2(
344                self.as_ptr(),
345                sel!(writeTimestampIntoHeap:atIndex:),
346                counter_heap,
347                index,
348            );
349        }
350    }
351}
352
353impl Clone for CommandBuffer {
354    fn clone(&self) -> Self {
355        unsafe {
356            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
357        }
358        Self(self.0)
359    }
360}
361
362impl Drop for CommandBuffer {
363    fn drop(&mut self) {
364        unsafe {
365            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
366        }
367    }
368}
369
370impl Referencing for CommandBuffer {
371    #[inline]
372    fn as_ptr(&self) -> *const c_void {
373        self.0.as_ptr()
374    }
375}
376
377unsafe impl Send for CommandBuffer {}
378unsafe impl Sync for CommandBuffer {}
379
380impl std::fmt::Debug for CommandBuffer {
381    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
382        f.debug_struct("CommandBuffer")
383            .field("label", &self.label())
384            .finish()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Size of a bare Objective-C object pointer.
    const OBJC_PTR_SIZE: usize = std::mem::size_of::<*mut c_void>();

    /// `#[repr(transparent)]` must keep the options wrapper exactly pointer-sized.
    #[test]
    fn test_command_buffer_options_size() {
        assert_eq!(std::mem::size_of::<CommandBufferOptions>(), OBJC_PTR_SIZE);
    }

    /// `#[repr(transparent)]` must keep the command buffer wrapper exactly pointer-sized.
    #[test]
    fn test_command_buffer_size() {
        assert_eq!(std::mem::size_of::<CommandBuffer>(), OBJC_PTR_SIZE);
    }
}