Skip to main content

mtl_gpu/acceleration/
encoder.rs

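//! Acceleration structure command encoder.
//!
//! Provides [`AccelerationStructureCommandEncoder`], a wrapper around Metal's
//! acceleration structure command encoder for building, refitting, copying and
//! compacting acceleration structures.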
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{
10    msg_send_0, msg_send_1, msg_send_2, msg_send_3, msg_send_4, msg_send_5, msg_send_6, sel,
11};
12
13use crate::enums::{AccelerationStructureRefitOptions, DataType, ResourceUsage};
14use crate::{Buffer, Fence, Heap};
15
16use super::{AccelerationStructure, AccelerationStructureDescriptor};
17
/// An owned handle to a Metal acceleration structure command encoder
/// (`MTLAccelerationStructureCommandEncoder`).
///
/// The handle holds one Objective-C reference to the encoder: cloning sends
/// `retain` and dropping sends `release`.
///
/// `#[repr(transparent)]` guarantees the handle has exactly the layout of the
/// wrapped non-null object pointer, so `Option<Self>` is also pointer-sized.
#[repr(transparent)]
pub struct AccelerationStructureCommandEncoder(pub(crate) NonNull<c_void>);
19
20impl AccelerationStructureCommandEncoder {
21    /// Create from a raw pointer.
22    ///
23    /// # Safety
24    ///
25    /// The pointer must be a valid Metal acceleration structure command encoder.
26    #[inline]
27    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
28        NonNull::new(ptr).map(Self)
29    }
30
31    /// Get the raw pointer.
32    #[inline]
33    pub fn as_raw(&self) -> *mut c_void {
34        self.0.as_ptr()
35    }
36
37    // CommandEncoder base methods
38
39    /// Get the label.
40    ///
41    /// C++ equivalent: `NS::String* label() const`
42    pub fn label(&self) -> Option<String> {
43        unsafe {
44            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
45            if ptr.is_null() {
46                return None;
47            }
48            let utf8_ptr: *const std::ffi::c_char =
49                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
50            if utf8_ptr.is_null() {
51                return None;
52            }
53            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
54            Some(c_str.to_string_lossy().into_owned())
55        }
56    }
57
58    /// Set the label.
59    ///
60    /// C++ equivalent: `void setLabel(const NS::String*)`
61    pub fn set_label(&self, label: &str) {
62        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
63            unsafe {
64                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
65            }
66        }
67    }
68
69    /// End encoding commands.
70    ///
71    /// C++ equivalent: `void endEncoding()`
72    #[inline]
73    pub fn end_encoding(&self) {
74        unsafe {
75            msg_send_0::<()>(self.as_ptr(), sel!(endEncoding));
76        }
77    }
78
79    /// Insert a debug signpost.
80    ///
81    /// C++ equivalent: `void insertDebugSignpost(const NS::String*)`
82    pub fn insert_debug_signpost(&self, name: &str) {
83        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
84            unsafe {
85                msg_send_1::<(), *const c_void>(
86                    self.as_ptr(),
87                    sel!(insertDebugSignpost:),
88                    ns_name.as_ptr(),
89                );
90            }
91        }
92    }
93
94    /// Push a debug group.
95    ///
96    /// C++ equivalent: `void pushDebugGroup(const NS::String*)`
97    pub fn push_debug_group(&self, name: &str) {
98        if let Some(ns_name) = mtl_foundation::String::from_str(name) {
99            unsafe {
100                msg_send_1::<(), *const c_void>(
101                    self.as_ptr(),
102                    sel!(pushDebugGroup:),
103                    ns_name.as_ptr(),
104                );
105            }
106        }
107    }
108
109    /// Pop a debug group.
110    ///
111    /// C++ equivalent: `void popDebugGroup()`
112    #[inline]
113    pub fn pop_debug_group(&self) {
114        unsafe {
115            msg_send_0::<()>(self.as_ptr(), sel!(popDebugGroup));
116        }
117    }
118
119    // Acceleration structure methods
120
121    /// Build an acceleration structure.
122    ///
123    /// C++ equivalent: `void buildAccelerationStructure(AccelerationStructure*, AccelerationStructureDescriptor*, Buffer*, NS::UInteger)`
124    pub fn build_acceleration_structure(
125        &self,
126        acceleration_structure: &AccelerationStructure,
127        descriptor: &AccelerationStructureDescriptor,
128        scratch_buffer: &Buffer,
129        scratch_buffer_offset: UInteger,
130    ) {
131        unsafe {
132            msg_send_4::<(), *const c_void, *const c_void, *const c_void, UInteger>(
133                self.as_ptr(),
134                sel!(buildAccelerationStructure:descriptor:scratchBuffer:scratchBufferOffset:),
135                acceleration_structure.as_ptr(),
136                descriptor.as_ptr(),
137                scratch_buffer.as_ptr(),
138                scratch_buffer_offset,
139            );
140        }
141    }
142
143    /// Copy an acceleration structure.
144    ///
145    /// C++ equivalent: `void copyAccelerationStructure(AccelerationStructure*, AccelerationStructure*)`
146    pub fn copy_acceleration_structure(
147        &self,
148        source: &AccelerationStructure,
149        destination: &AccelerationStructure,
150    ) {
151        unsafe {
152            msg_send_2::<(), *const c_void, *const c_void>(
153                self.as_ptr(),
154                sel!(copyAccelerationStructure:toAccelerationStructure:),
155                source.as_ptr(),
156                destination.as_ptr(),
157            );
158        }
159    }
160
161    /// Copy and compact an acceleration structure.
162    ///
163    /// C++ equivalent: `void copyAndCompactAccelerationStructure(AccelerationStructure*, AccelerationStructure*)`
164    pub fn copy_and_compact_acceleration_structure(
165        &self,
166        source: &AccelerationStructure,
167        destination: &AccelerationStructure,
168    ) {
169        unsafe {
170            msg_send_2::<(), *const c_void, *const c_void>(
171                self.as_ptr(),
172                sel!(copyAndCompactAccelerationStructure:toAccelerationStructure:),
173                source.as_ptr(),
174                destination.as_ptr(),
175            );
176        }
177    }
178
179    /// Refit an acceleration structure.
180    ///
181    /// C++ equivalent: `void refitAccelerationStructure(AccelerationStructure*, AccelerationStructureDescriptor*, AccelerationStructure*, Buffer*, NS::UInteger)`
182    pub fn refit_acceleration_structure(
183        &self,
184        source: &AccelerationStructure,
185        descriptor: &AccelerationStructureDescriptor,
186        destination: &AccelerationStructure,
187        scratch_buffer: &Buffer,
188        scratch_buffer_offset: UInteger,
189    ) {
190        unsafe {
191            msg_send_5::<(), *const c_void, *const c_void, *const c_void, *const c_void, UInteger>(
192                self.as_ptr(),
193                sel!(refitAccelerationStructure:descriptor:destination:scratchBuffer:scratchBufferOffset:),
194                source.as_ptr(),
195                descriptor.as_ptr(),
196                destination.as_ptr(),
197                scratch_buffer.as_ptr(),
198                scratch_buffer_offset,
199            );
200        }
201    }
202
203    /// Refit an acceleration structure with options.
204    ///
205    /// C++ equivalent: `void refitAccelerationStructure(AccelerationStructure*, AccelerationStructureDescriptor*, AccelerationStructure*, Buffer*, NS::UInteger, AccelerationStructureRefitOptions)`
206    pub fn refit_acceleration_structure_with_options(
207        &self,
208        source: &AccelerationStructure,
209        descriptor: &AccelerationStructureDescriptor,
210        destination: &AccelerationStructure,
211        scratch_buffer: &Buffer,
212        scratch_buffer_offset: UInteger,
213        options: AccelerationStructureRefitOptions,
214    ) {
215        unsafe {
216            msg_send_6::<
217                (),
218                *const c_void,
219                *const c_void,
220                *const c_void,
221                *const c_void,
222                UInteger,
223                AccelerationStructureRefitOptions,
224            >(
225                self.as_ptr(),
226                sel!(refitAccelerationStructure:descriptor:destination:scratchBuffer:scratchBufferOffset:options:),
227                source.as_ptr(),
228                descriptor.as_ptr(),
229                destination.as_ptr(),
230                scratch_buffer.as_ptr(),
231                scratch_buffer_offset,
232                options,
233            );
234        }
235    }
236
237    /// Write the compacted size of an acceleration structure.
238    ///
239    /// C++ equivalent: `void writeCompactedAccelerationStructureSize(AccelerationStructure*, Buffer*, NS::UInteger)`
240    pub fn write_compacted_acceleration_structure_size(
241        &self,
242        acceleration_structure: &AccelerationStructure,
243        buffer: &Buffer,
244        offset: UInteger,
245    ) {
246        unsafe {
247            msg_send_3::<(), *const c_void, *const c_void, UInteger>(
248                self.as_ptr(),
249                sel!(writeCompactedAccelerationStructureSize:toBuffer:offset:),
250                acceleration_structure.as_ptr(),
251                buffer.as_ptr(),
252                offset,
253            );
254        }
255    }
256
257    /// Write the compacted size of an acceleration structure with size data type.
258    ///
259    /// C++ equivalent: `void writeCompactedAccelerationStructureSize(AccelerationStructure*, Buffer*, NS::UInteger, DataType)`
260    pub fn write_compacted_acceleration_structure_size_with_type(
261        &self,
262        acceleration_structure: &AccelerationStructure,
263        buffer: &Buffer,
264        offset: UInteger,
265        size_data_type: DataType,
266    ) {
267        unsafe {
268            msg_send_4::<(), *const c_void, *const c_void, UInteger, DataType>(
269                self.as_ptr(),
270                sel!(writeCompactedAccelerationStructureSize:toBuffer:offset:sizeDataType:),
271                acceleration_structure.as_ptr(),
272                buffer.as_ptr(),
273                offset,
274                size_data_type,
275            );
276        }
277    }
278
279    // Counter sampling methods
280
281    /// Sample counters in a buffer.
282    ///
283    /// # Safety
284    ///
285    /// The sample_buffer pointer must be valid.
286    ///
287    /// C++ equivalent: `void sampleCountersInBuffer(CounterSampleBuffer*, NS::UInteger, bool)`
288    pub unsafe fn sample_counters_in_buffer_ptr(
289        &self,
290        sample_buffer: *const c_void,
291        sample_index: UInteger,
292        barrier: bool,
293    ) {
294        unsafe {
295            msg_send_3::<(), *const c_void, UInteger, bool>(
296                self.as_ptr(),
297                sel!(sampleCountersInBuffer:atSampleIndex:withBarrier:),
298                sample_buffer,
299                sample_index,
300                barrier,
301            );
302        }
303    }
304
305    // Fence methods
306
307    /// Wait for a fence.
308    ///
309    /// C++ equivalent: `void waitForFence(Fence*)`
310    pub fn wait_for_fence(&self, fence: &Fence) {
311        unsafe {
312            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(waitForFence:), fence.as_ptr());
313        }
314    }
315
316    /// Update a fence.
317    ///
318    /// C++ equivalent: `void updateFence(Fence*)`
319    pub fn update_fence(&self, fence: &Fence) {
320        unsafe {
321            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(updateFence:), fence.as_ptr());
322        }
323    }
324
325    // Resource methods
326
327    /// Use a resource.
328    ///
329    /// C++ equivalent: `void useResource(Resource*, ResourceUsage)`
330    pub fn use_resource<R: Referencing>(&self, resource: &R, usage: ResourceUsage) {
331        unsafe {
332            msg_send_2::<(), *const c_void, ResourceUsage>(
333                self.as_ptr(),
334                sel!(useResource:usage:),
335                resource.as_ptr(),
336                usage,
337            );
338        }
339    }
340
341    /// Use multiple resources.
342    ///
343    /// # Safety
344    ///
345    /// The resources pointer must be valid for the given count.
346    pub unsafe fn use_resources_ptr(
347        &self,
348        resources: *const *const c_void,
349        count: UInteger,
350        usage: ResourceUsage,
351    ) {
352        unsafe {
353            msg_send_3::<(), *const *const c_void, UInteger, ResourceUsage>(
354                self.as_ptr(),
355                sel!(useResources:count:usage:),
356                resources,
357                count,
358                usage,
359            );
360        }
361    }
362
363    /// Use a heap.
364    ///
365    /// C++ equivalent: `void useHeap(Heap*)`
366    pub fn use_heap(&self, heap: &Heap) {
367        unsafe {
368            msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(useHeap:), heap.as_ptr());
369        }
370    }
371
372    /// Use multiple heaps.
373    ///
374    /// # Safety
375    ///
376    /// The heaps pointer must be valid for the given count.
377    pub unsafe fn use_heaps_ptr(&self, heaps: *const *const c_void, count: UInteger) {
378        unsafe {
379            msg_send_2::<(), *const *const c_void, UInteger>(
380                self.as_ptr(),
381                sel!(useHeaps:count:),
382                heaps,
383                count,
384            );
385        }
386    }
387}
388
389impl Clone for AccelerationStructureCommandEncoder {
390    fn clone(&self) -> Self {
391        unsafe {
392            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
393        }
394        Self(self.0)
395    }
396}
397
398impl Drop for AccelerationStructureCommandEncoder {
399    fn drop(&mut self) {
400        unsafe {
401            msg_send_0::<()>(self.as_ptr(), sel!(release));
402        }
403    }
404}
405
406impl Referencing for AccelerationStructureCommandEncoder {
407    #[inline]
408    fn as_ptr(&self) -> *const c_void {
409        self.0.as_ptr()
410    }
411}
412
413unsafe impl Send for AccelerationStructureCommandEncoder {}
414unsafe impl Sync for AccelerationStructureCommandEncoder {}
415
416impl std::fmt::Debug for AccelerationStructureCommandEncoder {
417    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418        f.debug_struct("AccelerationStructureCommandEncoder")
419            .field("label", &self.label())
420            .finish()
421    }
422}