Skip to main content

mtl_gpu/encoder/compute_encoder/acceleration.rs

1//! Acceleration structure and function table methods for ComputeCommandEncoder.
2
3use std::ffi::c_void;
4
5use mtl_foundation::{Referencing, UInteger};
6use mtl_sys::sel;
7
8use super::ComputeCommandEncoder;
9
10impl ComputeCommandEncoder {
11    // =========================================================================
12    // Acceleration Structure Bindings
13    // =========================================================================
14
15    /// Set an acceleration structure at a buffer index (raw pointer version).
16    ///
17    /// C++ equivalent: `void setAccelerationStructure(const AccelerationStructure*, NS::UInteger)`
18    ///
19    /// # Safety
20    ///
21    /// The acceleration structure pointer must be a valid Metal acceleration structure object.
22    #[inline]
23    pub unsafe fn set_acceleration_structure_ptr(
24        &self,
25        acceleration_structure: *const c_void,
26        buffer_index: UInteger,
27    ) {
28        unsafe {
29            mtl_sys::msg_send_2::<(), *const c_void, UInteger>(
30                self.as_ptr(),
31                sel!(setAccelerationStructure: atBufferIndex:),
32                acceleration_structure,
33                buffer_index,
34            );
35        }
36    }
37
38    /// Set an acceleration structure at a buffer index.
39    ///
40    /// C++ equivalent: `void setAccelerationStructure(const AccelerationStructure*, NS::UInteger)`
41    #[inline]
42    pub fn set_acceleration_structure(
43        &self,
44        acceleration_structure: &crate::AccelerationStructure,
45        buffer_index: UInteger,
46    ) {
47        unsafe {
48            self.set_acceleration_structure_ptr(acceleration_structure.as_ptr(), buffer_index)
49        };
50    }
51
52    // =========================================================================
53    // Function Tables (Ray Tracing)
54    // =========================================================================
55
56    /// Set a visible function table at a buffer index (raw pointer version).
57    ///
58    /// C++ equivalent: `void setVisibleFunctionTable(const VisibleFunctionTable*, NS::UInteger)`
59    ///
60    /// # Safety
61    ///
62    /// The visible function table pointer must be valid.
63    #[inline]
64    pub unsafe fn set_visible_function_table_ptr(
65        &self,
66        visible_function_table: *const c_void,
67        buffer_index: UInteger,
68    ) {
69        unsafe {
70            mtl_sys::msg_send_2::<(), *const c_void, UInteger>(
71                self.as_ptr(),
72                sel!(setVisibleFunctionTable: atBufferIndex:),
73                visible_function_table,
74                buffer_index,
75            );
76        }
77    }
78
79    /// Set an intersection function table at a buffer index (raw pointer version).
80    ///
81    /// C++ equivalent: `void setIntersectionFunctionTable(const IntersectionFunctionTable*, NS::UInteger)`
82    ///
83    /// # Safety
84    ///
85    /// The intersection function table pointer must be valid.
86    #[inline]
87    pub unsafe fn set_intersection_function_table_ptr(
88        &self,
89        intersection_function_table: *const c_void,
90        buffer_index: UInteger,
91    ) {
92        unsafe {
93            mtl_sys::msg_send_2::<(), *const c_void, UInteger>(
94                self.as_ptr(),
95                sel!(setIntersectionFunctionTable: atBufferIndex:),
96                intersection_function_table,
97                buffer_index,
98            );
99        }
100    }
101
102    /// Set multiple visible function tables at a range of buffer indices (raw pointer version).
103    ///
104    /// C++ equivalent: `void setVisibleFunctionTables(const VisibleFunctionTable* const*, NS::Range)`
105    ///
106    /// # Safety
107    ///
108    /// The visible function tables pointer must be a valid array with at least `range.length` elements.
109    #[inline]
110    pub unsafe fn set_visible_function_tables_ptr(
111        &self,
112        visible_function_tables: *const *const c_void,
113        range_location: UInteger,
114        range_length: UInteger,
115    ) {
116        let range = mtl_foundation::Range::new(range_location, range_length);
117        unsafe {
118            mtl_sys::msg_send_2::<(), *const *const c_void, mtl_foundation::Range>(
119                self.as_ptr(),
120                sel!(setVisibleFunctionTables: withBufferRange:),
121                visible_function_tables,
122                range,
123            );
124        }
125    }
126
127    /// Set multiple intersection function tables at a range of buffer indices (raw pointer version).
128    ///
129    /// C++ equivalent: `void setIntersectionFunctionTables(const IntersectionFunctionTable* const*, NS::Range)`
130    ///
131    /// # Safety
132    ///
133    /// The intersection function tables pointer must be a valid array with at least `range.length` elements.
134    #[inline]
135    pub unsafe fn set_intersection_function_tables_ptr(
136        &self,
137        intersection_function_tables: *const *const c_void,
138        range_location: UInteger,
139        range_length: UInteger,
140    ) {
141        let range = mtl_foundation::Range::new(range_location, range_length);
142        unsafe {
143            mtl_sys::msg_send_2::<(), *const *const c_void, mtl_foundation::Range>(
144                self.as_ptr(),
145                sel!(setIntersectionFunctionTables: withBufferRange:),
146                intersection_function_tables,
147                range,
148            );
149        }
150    }
151}