Skip to main content

mtl_gpu/encoder/compute_encoder/mod.rs

1//! Compute command encoder.
2//!
3//! Corresponds to `Metal/MTLComputeCommandEncoder.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11mod acceleration;
12mod binding;
13mod dispatch;
14mod indirect;
15mod memory;
16mod pipeline;
17
/// Argument record consumed by an indirect "dispatch threadgroups" command.
///
/// Mirrors the C++ type `MTL::DispatchThreadgroupsIndirectArguments`. The
/// `repr(C, packed)` layout is three tightly packed `u32`s (12 bytes, alignment 1),
/// so values can be written directly into an indirect-argument buffer.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct DispatchThreadgroupsIndirectArguments {
    /// Threadgroup count for each of the three grid dimensions.
    pub threadgroups_per_grid: [u32; 3],
}
26
/// Argument record consumed by an indirect "dispatch threads" command.
///
/// Mirrors the C++ type `MTL::DispatchThreadsIndirectArguments`. The
/// `repr(C, packed)` layout is six tightly packed `u32`s (24 bytes, alignment 1).
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct DispatchThreadsIndirectArguments {
    /// Total thread count for each of the three grid dimensions.
    pub threads_per_grid: [u32; 3],
    /// Threadgroup extent for each of the three dimensions.
    pub threads_per_threadgroup: [u32; 3],
}
36
/// Argument record describing a stage-in region for indirect dispatch.
///
/// Mirrors the C++ type `MTL::StageInRegionIndirectArguments`. The
/// `repr(C, packed)` layout is six tightly packed `u32`s (24 bytes, alignment 1).
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct StageInRegionIndirectArguments {
    /// Origin of the stage-in region, one component per dimension.
    pub stage_in_origin: [u32; 3],
    /// Size of the stage-in region, one component per dimension.
    pub stage_in_size: [u32; 3],
}
46
/// Owning handle to a Metal compute command encoder.
///
/// Mirrors the C++ type `MTL::ComputeCommandEncoder`. A compute encoder records
/// compute dispatches, and the resource bindings they use, into its command buffer.
///
/// The handle owns one reference to the underlying object and releases it when
/// dropped. Being `repr(transparent)` over a `NonNull` pointer, it is exactly
/// pointer-sized, and `Option<ComputeCommandEncoder>` is pointer-sized as well.
#[repr(transparent)]
pub struct ComputeCommandEncoder(pub(crate) NonNull<c_void>);
55
56impl ComputeCommandEncoder {
57    /// Create a ComputeCommandEncoder from a raw pointer.
58    ///
59    /// # Safety
60    ///
61    /// The pointer must be a valid Metal compute command encoder object.
62    #[inline]
63    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
64        NonNull::new(ptr).map(Self)
65    }
66
67    /// Get the raw pointer to the encoder.
68    #[inline]
69    pub fn as_raw(&self) -> *mut c_void {
70        self.0.as_ptr()
71    }
72
73    // =========================================================================
74    // CommandEncoder base methods
75    // =========================================================================
76
77    /// Get the device that created this encoder.
78    ///
79    /// C++ equivalent: `Device* device() const`
80    pub fn device(&self) -> crate::Device {
81        unsafe {
82            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
83            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
84            crate::Device::from_raw(ptr).expect("encoder has no device")
85        }
86    }
87
88    /// Get the command buffer that this encoder is encoding into.
89    ///
90    /// C++ equivalent: `CommandBuffer* commandBuffer() const`
91    pub fn command_buffer(&self) -> crate::CommandBuffer {
92        unsafe {
93            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(commandBuffer));
94            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
95            crate::CommandBuffer::from_raw(ptr).expect("encoder has no command buffer")
96        }
97    }
98
99    /// Get the label for this encoder.
100    ///
101    /// C++ equivalent: `NS::String* label() const`
102    pub fn label(&self) -> Option<String> {
103        unsafe {
104            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
105            if ptr.is_null() {
106                return None;
107            }
108            let utf8_ptr: *const std::ffi::c_char =
109                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
110            if utf8_ptr.is_null() {
111                return None;
112            }
113            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
114            Some(c_str.to_string_lossy().into_owned())
115        }
116    }
117
118    /// Set the label for this encoder.
119    ///
120    /// C++ equivalent: `void setLabel(const NS::String*)`
121    pub fn set_label(&self, label: &str) {
122        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
123            unsafe {
124                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
125            }
126        }
127    }
128
129    /// End encoding commands with this encoder.
130    ///
131    /// C++ equivalent: `void endEncoding()`
132    #[inline]
133    pub fn end_encoding(&self) {
134        unsafe {
135            msg_send_0::<()>(self.as_ptr(), sel!(endEncoding));
136        }
137    }
138
139    /// Insert a debug signpost.
140    ///
141    /// C++ equivalent: `void insertDebugSignpost(const NS::String*)`
142    pub fn insert_debug_signpost(&self, string: &str) {
143        if let Some(ns_string) = mtl_foundation::String::from_str(string) {
144            unsafe {
145                msg_send_1::<(), *const c_void>(
146                    self.as_ptr(),
147                    sel!(insertDebugSignpost:),
148                    ns_string.as_ptr(),
149                );
150            }
151        }
152    }
153
154    /// Push a debug group.
155    ///
156    /// C++ equivalent: `void pushDebugGroup(const NS::String*)`
157    pub fn push_debug_group(&self, string: &str) {
158        if let Some(ns_string) = mtl_foundation::String::from_str(string) {
159            unsafe {
160                msg_send_1::<(), *const c_void>(
161                    self.as_ptr(),
162                    sel!(pushDebugGroup:),
163                    ns_string.as_ptr(),
164                );
165            }
166        }
167    }
168
169    /// Pop the current debug group.
170    ///
171    /// C++ equivalent: `void popDebugGroup()`
172    #[inline]
173    pub fn pop_debug_group(&self) {
174        unsafe {
175            msg_send_0::<()>(self.as_ptr(), sel!(popDebugGroup));
176        }
177    }
178
179    /// Insert a barrier to synchronize queue stages.
180    ///
181    /// C++ equivalent: `void barrierAfterQueueStages(Stages, Stages)`
182    #[inline]
183    pub fn barrier_after_queue_stages(
184        &self,
185        after_stages: crate::enums::Stages,
186        before_stages: crate::enums::Stages,
187    ) {
188        unsafe {
189            mtl_sys::msg_send_2::<(), crate::enums::Stages, crate::enums::Stages>(
190                self.as_ptr(),
191                sel!(barrierAfterQueueStages:beforeQueueStages:),
192                after_stages,
193                before_stages,
194            );
195        }
196    }
197}
198
199impl Clone for ComputeCommandEncoder {
200    fn clone(&self) -> Self {
201        unsafe {
202            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
203        }
204        Self(self.0)
205    }
206}
207
208impl Drop for ComputeCommandEncoder {
209    fn drop(&mut self) {
210        unsafe {
211            msg_send_0::<()>(self.as_ptr(), sel!(release));
212        }
213    }
214}
215
216impl Referencing for ComputeCommandEncoder {
217    #[inline]
218    fn as_ptr(&self) -> *const c_void {
219        self.0.as_ptr()
220    }
221}
222
// SAFETY: the handle is just an owned, retained Objective-C object pointer;
// moving it to another thread moves that ownership. This assumes the
// runtime's retain/release are safe to call from any thread — confirm.
// NOTE(review): `Sync` lets `&ComputeCommandEncoder` be shared across threads,
// and every encoding method on this type takes `&self`. Metal treats command
// encoders as single-threaded objects, so concurrent encoding through shared
// references is likely unsound — verify callers never do this, or consider
// removing the `Sync` impl in a breaking release.
unsafe impl Send for ComputeCommandEncoder {}
unsafe impl Sync for ComputeCommandEncoder {}
225
226impl std::fmt::Debug for ComputeCommandEncoder {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        f.debug_struct("ComputeCommandEncoder")
229            .field("label", &self.label())
230            .finish()
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_encoder_size() {
        assert_eq!(
            std::mem::size_of::<ComputeCommandEncoder>(),
            std::mem::size_of::<*mut c_void>()
        );
        // `NonNull` gives `Option` a niche, so the nullable form costs nothing extra.
        assert_eq!(
            std::mem::size_of::<Option<ComputeCommandEncoder>>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    #[test]
    fn test_dispatch_threadgroups_indirect_arguments_size() {
        assert_eq!(
            std::mem::size_of::<DispatchThreadgroupsIndirectArguments>(),
            12
        ); // 3 * 4 bytes
        // `packed` must keep alignment at 1 to match the C++ layout.
        assert_eq!(
            std::mem::align_of::<DispatchThreadgroupsIndirectArguments>(),
            1
        );
    }

    #[test]
    fn test_dispatch_threads_indirect_arguments_size() {
        assert_eq!(std::mem::size_of::<DispatchThreadsIndirectArguments>(), 24); // 6 * 4 bytes
        assert_eq!(std::mem::align_of::<DispatchThreadsIndirectArguments>(), 1);
    }

    #[test]
    fn test_stage_in_region_indirect_arguments_size() {
        assert_eq!(std::mem::size_of::<StageInRegionIndirectArguments>(), 24); // 6 * 4 bytes
        assert_eq!(std::mem::align_of::<StageInRegionIndirectArguments>(), 1);
    }

    #[test]
    fn test_indirect_arguments_default_is_zeroed() {
        assert_eq!(
            DispatchThreadgroupsIndirectArguments::default(),
            DispatchThreadgroupsIndirectArguments { threadgroups_per_grid: [0; 3] }
        );
        assert_eq!(
            DispatchThreadsIndirectArguments::default(),
            DispatchThreadsIndirectArguments {
                threads_per_grid: [0; 3],
                threads_per_threadgroup: [0; 3],
            }
        );
        assert_eq!(
            StageInRegionIndirectArguments::default(),
            StageInRegionIndirectArguments { stage_in_origin: [0; 3], stage_in_size: [0; 3] }
        );
    }
}