//! mtl_gpu/encoder/compute_encoder/mod.rs
use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11mod acceleration;
12mod binding;
13mod dispatch;
14mod indirect;
15mod memory;
16mod pipeline;
17
/// Argument layout consumed by an indirect `dispatchThreadgroups` call:
/// three `u32` threadgroup counts laid out back-to-back with no padding
/// (`repr(C, packed)`), giving a fixed 12-byte footprint.
#[repr(C, packed)]
#[derive(Copy, Clone)]
#[derive(Debug, Default, PartialEq, Eq, Hash)]
pub struct DispatchThreadgroupsIndirectArguments {
    pub threadgroups_per_grid: [u32; 3],
}
26
/// Argument layout consumed by an indirect `dispatchThreads` call:
/// a grid size followed by a threadgroup size, three `u32`s each, with no
/// padding between the fields (`repr(C, packed)`), 24 bytes total.
#[repr(C, packed)]
#[derive(Copy, Clone)]
#[derive(Debug, Default, PartialEq, Eq, Hash)]
pub struct DispatchThreadsIndirectArguments {
    pub threads_per_grid: [u32; 3],
    pub threads_per_threadgroup: [u32; 3],
}
36
/// Argument layout for an indirectly specified stage-in region:
/// an origin and a size, three `u32`s each, with no padding
/// (`repr(C, packed)`), 24 bytes total.
#[repr(C, packed)]
#[derive(Copy, Clone)]
#[derive(Debug, Default, PartialEq, Eq, Hash)]
pub struct StageInRegionIndirectArguments {
    pub stage_in_origin: [u32; 3],
    pub stage_in_size: [u32; 3],
}
46
/// Owning wrapper around a non-null compute-command-encoder object pointer.
///
/// `repr(transparent)` over `NonNull` keeps the wrapper pointer-sized and
/// gives `Option<ComputeCommandEncoder>` the null-pointer niche. Ownership
/// of one retain count is assumed (balanced by `release` in `Drop`).
#[repr(transparent)]
pub struct ComputeCommandEncoder(pub(crate) NonNull<c_void>);
55
impl ComputeCommandEncoder {
    /// Wraps a raw encoder pointer, returning `None` when `ptr` is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or a valid compute command encoder object. The
    /// wrapper takes over one reference: no `retain` is sent here, and
    /// `Drop` sends a balancing `release`.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Returns the underlying raw pointer without changing its retain count.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Returns the device associated with this encoder (sends `device`).
    ///
    /// The result is retained before being handed to
    /// `crate::Device::from_raw`, so the returned wrapper owns its own
    /// reference.
    ///
    /// # Panics
    ///
    /// Panics if the `device` message answers nil.
    pub fn device(&self) -> crate::Device {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
            // Retain on behalf of the new wrapper; an Objective-C message
            // to nil is a no-op, so this is harmless even if `ptr` is null.
            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
            crate::Device::from_raw(ptr).expect("encoder has no device")
        }
    }

    /// Returns the command buffer this encoder belongs to
    /// (sends `commandBuffer`).
    ///
    /// # Panics
    ///
    /// Panics if the `commandBuffer` message answers nil.
    pub fn command_buffer(&self) -> crate::CommandBuffer {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(commandBuffer));
            // Retained for the `CommandBuffer` wrapper, mirroring `device()`.
            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
            crate::CommandBuffer::from_raw(ptr).expect("encoder has no command buffer")
        }
    }

    /// Returns the encoder's debug label (sends `label`), or `None` when no
    /// label is set or its UTF-8 buffer is unavailable.
    pub fn label(&self) -> Option<String> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
            if ptr.is_null() {
                return None;
            }
            // `UTF8String` yields a C string owned by the label object; it
            // is copied into an owned `String` immediately below.
            let utf8_ptr: *const std::ffi::c_char =
                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
            if utf8_ptr.is_null() {
                return None;
            }
            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
            // Lossy conversion: any invalid UTF-8 becomes U+FFFD rather
            // than failing.
            Some(c_str.to_string_lossy().into_owned())
        }
    }

    /// Sets the encoder's debug label (sends `setLabel:`).
    ///
    /// Silently a no-op when `label` cannot be converted to a foundation
    /// string.
    pub fn set_label(&self, label: &str) {
        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
            unsafe {
                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
            }
        }
    }

    /// Ends command encoding for this encoder (sends `endEncoding`).
    #[inline]
    pub fn end_encoding(&self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(endEncoding));
        }
    }

    /// Inserts a debug signpost string into the command stream
    /// (sends `insertDebugSignpost:`); no-op if the string conversion fails.
    pub fn insert_debug_signpost(&self, string: &str) {
        if let Some(ns_string) = mtl_foundation::String::from_str(string) {
            unsafe {
                msg_send_1::<(), *const c_void>(
                    self.as_ptr(),
                    sel!(insertDebugSignpost:),
                    ns_string.as_ptr(),
                );
            }
        }
    }

    /// Opens a named debug group (sends `pushDebugGroup:`); pair with
    /// [`Self::pop_debug_group`]. No-op if the string conversion fails.
    pub fn push_debug_group(&self, string: &str) {
        if let Some(ns_string) = mtl_foundation::String::from_str(string) {
            unsafe {
                msg_send_1::<(), *const c_void>(
                    self.as_ptr(),
                    sel!(pushDebugGroup:),
                    ns_string.as_ptr(),
                );
            }
        }
    }

    /// Closes the most recently pushed debug group (sends `popDebugGroup`).
    #[inline]
    pub fn pop_debug_group(&self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(popDebugGroup));
        }
    }

    /// Encodes a queue-stage barrier
    /// (sends `barrierAfterQueueStages:beforeQueueStages:`).
    ///
    /// NOTE(review): per the selector naming this should order work in
    /// `after_stages` before subsequent work in `before_stages` — confirm
    /// against the Metal documentation.
    #[inline]
    pub fn barrier_after_queue_stages(
        &self,
        after_stages: crate::enums::Stages,
        before_stages: crate::enums::Stages,
    ) {
        unsafe {
            mtl_sys::msg_send_2::<(), crate::enums::Stages, crate::enums::Stages>(
                self.as_ptr(),
                sel!(barrierAfterQueueStages:beforeQueueStages:),
                after_stages,
                before_stages,
            );
        }
    }
}
198
impl Clone for ComputeCommandEncoder {
    /// Bumps the retain count and returns a second owning wrapper for the
    /// same encoder; each clone is balanced by the `release` in `Drop`.
    fn clone(&self) -> Self {
        unsafe {
            // SAFETY: `self.0` is non-null by construction, so `retain` is
            // sent to an object this wrapper already holds a reference to.
            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
        }
        Self(self.0)
    }
}
207
impl Drop for ComputeCommandEncoder {
    /// Releases the wrapper's reference, balancing the retain taken when it
    /// was created (`from_raw` ownership transfer) or cloned.
    fn drop(&mut self) {
        unsafe {
            // SAFETY: `self.0` is non-null; this is the wrapper's own
            // reference, released exactly once.
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}
215
impl Referencing for ComputeCommandEncoder {
    /// Exposes the wrapped pointer to the `mtl_foundation` reference
    /// machinery; does not touch the retain count.
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}
222
// SAFETY: the wrapper is only a pointer; moving it to another thread moves
// ownership of that single reference along with it.
// NOTE(review): Metal command encoders are generally not meant to be driven
// from multiple threads at once — confirm that `Sync` (concurrent `&self`
// use across threads) is actually sound for this type before relying on it.
unsafe impl Send for ComputeCommandEncoder {}
unsafe impl Sync for ComputeCommandEncoder {}
225
226impl std::fmt::Debug for ComputeCommandEncoder {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("ComputeCommandEncoder")
229 .field("label", &self.label())
230 .finish()
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::{align_of, size_of};

    /// `repr(transparent)` over `NonNull` must keep the wrapper
    /// pointer-sized.
    #[test]
    fn test_compute_encoder_size() {
        assert_eq!(
            size_of::<ComputeCommandEncoder>(),
            size_of::<*mut c_void>()
        );
    }

    /// Three packed `u32` threadgroup counts: 12 bytes, alignment 1.
    /// The alignment check pins down the `packed` repr, not just the size.
    #[test]
    fn test_dispatch_threadgroups_indirect_arguments_size() {
        assert_eq!(size_of::<DispatchThreadgroupsIndirectArguments>(), 12);
        assert_eq!(align_of::<DispatchThreadgroupsIndirectArguments>(), 1);
    }

    /// Two packed `[u32; 3]` fields: 24 bytes, alignment 1.
    #[test]
    fn test_dispatch_threads_indirect_arguments_size() {
        assert_eq!(size_of::<DispatchThreadsIndirectArguments>(), 24);
        assert_eq!(align_of::<DispatchThreadsIndirectArguments>(), 1);
    }

    /// Packed origin plus size: 24 bytes, alignment 1.
    #[test]
    fn test_stage_in_region_indirect_arguments_size() {
        assert_eq!(size_of::<StageInRegionIndirectArguments>(), 24);
        assert_eq!(align_of::<StageInRegionIndirectArguments>(), 1);
    }
}