Skip to main content

mtl_gpu/indirect/
compute_command.rs

1//! A compute command within an indirect command buffer.
2
3use std::ffi::c_void;
4use std::ptr::NonNull;
5
6use mtl_foundation::{Referencing, UInteger};
7use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
8
9use crate::types::{Region, Size};
10use crate::{Buffer, ComputePipelineState};
11
/// A compute command within an indirect command buffer.
///
/// C++ equivalent: `MTL::IndirectComputeCommand`
///
/// Indirect compute commands can encode dispatch calls and state changes
/// that will be executed when the indirect command buffer is executed.
///
/// # Ownership
///
/// This is a non-owning view: it holds only the raw command pointer and is not
/// reference counted. It is only meaningful while the indirect command buffer it
/// was obtained from is alive.
///
/// `#[repr(transparent)]` gives this wrapper exactly the layout of the raw
/// pointer. Because the pointer is stored as `NonNull`, `Option<IndirectComputeCommand>`
/// is also pointer-sized.
#[repr(transparent)]
pub struct IndirectComputeCommand(NonNull<c_void>);
20
21impl IndirectComputeCommand {
22    /// Create from a raw pointer.
23    ///
24    /// # Safety
25    ///
26    /// The pointer must be a valid Metal indirect compute command.
27    #[inline]
28    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
29        NonNull::new(ptr).map(Self)
30    }
31
32    /// Get the raw pointer.
33    #[inline]
34    pub fn as_raw(&self) -> *mut c_void {
35        self.0.as_ptr()
36    }
37
38    /// Reset this command.
39    ///
40    /// C++ equivalent: `void reset()`
41    #[inline]
42    pub fn reset(&self) {
43        unsafe {
44            msg_send_0::<()>(self.as_ptr(), sel!(reset));
45        }
46    }
47
48    /// Set a barrier for this command.
49    ///
50    /// C++ equivalent: `void setBarrier()`
51    #[inline]
52    pub fn set_barrier(&self) {
53        unsafe {
54            msg_send_0::<()>(self.as_ptr(), sel!(setBarrier));
55        }
56    }
57
58    /// Clear a barrier for this command.
59    ///
60    /// C++ equivalent: `void clearBarrier()`
61    #[inline]
62    pub fn clear_barrier(&self) {
63        unsafe {
64            msg_send_0::<()>(self.as_ptr(), sel!(clearBarrier));
65        }
66    }
67
68    /// Set the compute pipeline state.
69    ///
70    /// C++ equivalent: `void setComputePipelineState(const ComputePipelineState*)`
71    pub fn set_compute_pipeline_state(&self, state: &ComputePipelineState) {
72        unsafe {
73            msg_send_1::<(), *const c_void>(
74                self.as_ptr(),
75                sel!(setComputePipelineState:),
76                state.as_ptr(),
77            );
78        }
79    }
80
81    /// Set a kernel buffer.
82    ///
83    /// C++ equivalent: `void setKernelBuffer(const Buffer*, NS::UInteger offset, NS::UInteger index)`
84    pub fn set_kernel_buffer(&self, buffer: &Buffer, offset: UInteger, index: UInteger) {
85        unsafe {
86            mtl_sys::msg_send_3::<(), *const c_void, UInteger, UInteger>(
87                self.as_ptr(),
88                sel!(setKernelBuffer: offset: atIndex:),
89                buffer.as_ptr(),
90                offset,
91                index,
92            );
93        }
94    }
95
96    /// Set a kernel buffer with stride.
97    ///
98    /// C++ equivalent: `void setKernelBuffer(const Buffer*, NS::UInteger, NS::UInteger, NS::UInteger)`
99    pub fn set_kernel_buffer_with_stride(
100        &self,
101        buffer: &Buffer,
102        offset: UInteger,
103        stride: UInteger,
104        index: UInteger,
105    ) {
106        unsafe {
107            mtl_sys::msg_send_4::<(), *const c_void, UInteger, UInteger, UInteger>(
108                self.as_ptr(),
109                sel!(setKernelBuffer: offset: attributeStride: atIndex:),
110                buffer.as_ptr(),
111                offset,
112                stride,
113                index,
114            );
115        }
116    }
117
118    /// Set threadgroup memory length.
119    ///
120    /// C++ equivalent: `void setThreadgroupMemoryLength(NS::UInteger, NS::UInteger)`
121    pub fn set_threadgroup_memory_length(&self, length: UInteger, index: UInteger) {
122        unsafe {
123            msg_send_2::<(), UInteger, UInteger>(
124                self.as_ptr(),
125                sel!(setThreadgroupMemoryLength: atIndex:),
126                length,
127                index,
128            );
129        }
130    }
131
132    /// Set the stage-in region.
133    ///
134    /// C++ equivalent: `void setStageInRegion(Region)`
135    #[inline]
136    pub fn set_stage_in_region(&self, region: Region) {
137        unsafe {
138            msg_send_1::<(), Region>(self.as_ptr(), sel!(setStageInRegion:), region);
139        }
140    }
141
142    /// Set the imageblock dimensions.
143    ///
144    /// C++ equivalent: `void setImageblockWidth(NS::UInteger, NS::UInteger)`
145    pub fn set_imageblock_width(&self, width: UInteger, height: UInteger) {
146        unsafe {
147            msg_send_2::<(), UInteger, UInteger>(
148                self.as_ptr(),
149                sel!(setImageblockWidth: height:),
150                width,
151                height,
152            );
153        }
154    }
155
156    /// Dispatch threadgroups concurrently.
157    ///
158    /// C++ equivalent: `void concurrentDispatchThreadgroups(...)`
159    pub fn concurrent_dispatch_threadgroups(
160        &self,
161        threadgroups_per_grid: Size,
162        threads_per_threadgroup: Size,
163    ) {
164        unsafe {
165            msg_send_2::<(), Size, Size>(
166                self.as_ptr(),
167                sel!(concurrentDispatchThreadgroups: threadsPerThreadgroup:),
168                threadgroups_per_grid,
169                threads_per_threadgroup,
170            );
171        }
172    }
173
174    /// Dispatch threads concurrently.
175    ///
176    /// C++ equivalent: `void concurrentDispatchThreads(...)`
177    pub fn concurrent_dispatch_threads(
178        &self,
179        threads_per_grid: Size,
180        threads_per_threadgroup: Size,
181    ) {
182        unsafe {
183            msg_send_2::<(), Size, Size>(
184                self.as_ptr(),
185                sel!(concurrentDispatchThreads: threadsPerThreadgroup:),
186                threads_per_grid,
187                threads_per_threadgroup,
188            );
189        }
190    }
191}
192
193impl Referencing for IndirectComputeCommand {
194    #[inline]
195    fn as_ptr(&self) -> *const c_void {
196        self.0.as_ptr()
197    }
198}
199
// Note: IndirectComputeCommand is not reference counted. It is a non-owning view
// into its indirect command buffer (ICB) and is only valid while that buffer is alive.
201
202impl std::fmt::Debug for IndirectComputeCommand {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.debug_struct("IndirectComputeCommand").finish()
205    }
206}