Skip to main content

mtl_gpu/argument/encoder.rs

1//! An encoder for encoding resources into argument buffers.
2
3use std::ffi::c_void;
4use std::ptr::NonNull;
5
6use mtl_foundation::{Referencing, UInteger};
7use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_3, sel};
8
9use crate::Buffer;
10
/// Handle to a Metal argument encoder, the object that writes resource
/// references (buffers, textures, samplers, pipeline states, …) into
/// argument buffers.
///
/// The handle owns one reference to the underlying Objective-C object; that
/// reference is given back when the handle is dropped.
///
/// C++ equivalent: `MTL::ArgumentEncoder`
#[repr(transparent)]
pub struct ArgumentEncoder(
    // Never null; using `NonNull` keeps `Option<ArgumentEncoder>` pointer-sized.
    pub(crate) NonNull<c_void>,
);
16
17impl ArgumentEncoder {
18    /// Create from a raw pointer.
19    ///
20    /// # Safety
21    ///
22    /// The pointer must be a valid Metal ArgumentEncoder.
23    #[inline]
24    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
25        NonNull::new(ptr).map(Self)
26    }
27
28    /// Get the raw pointer.
29    #[inline]
30    pub fn as_raw(&self) -> *mut c_void {
31        self.0.as_ptr()
32    }
33
34    /// Get the alignment of the encoded data.
35    ///
36    /// C++ equivalent: `NS::UInteger alignment() const`
37    #[inline]
38    pub fn alignment(&self) -> UInteger {
39        unsafe { msg_send_0(self.as_ptr(), sel!(alignment)) }
40    }
41
42    /// Get a pointer to the constant data at the given index.
43    ///
44    /// C++ equivalent: `void* constantData(NS::UInteger index)`
45    #[inline]
46    pub fn constant_data(&self, index: UInteger) -> *mut c_void {
47        unsafe { msg_send_1(self.as_ptr(), sel!(constantDataAtIndex:), index) }
48    }
49
50    /// Get the device that created this encoder.
51    ///
52    /// C++ equivalent: `Device* device() const`
53    pub fn device(&self) -> crate::Device {
54        unsafe {
55            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
56            // Device is retained by the encoder, retain it for our reference
57            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
58            crate::Device::from_raw(ptr).expect("device should be valid")
59        }
60    }
61
62    /// Get the encoded length in bytes.
63    ///
64    /// C++ equivalent: `NS::UInteger encodedLength() const`
65    #[inline]
66    pub fn encoded_length(&self) -> UInteger {
67        unsafe { msg_send_0(self.as_ptr(), sel!(encodedLength)) }
68    }
69
70    /// Get the label.
71    ///
72    /// C++ equivalent: `NS::String* label() const`
73    pub fn label(&self) -> Option<String> {
74        unsafe {
75            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
76            if ptr.is_null() {
77                return None;
78            }
79            let utf8_ptr: *const std::ffi::c_char =
80                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
81            if utf8_ptr.is_null() {
82                return None;
83            }
84            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
85            Some(c_str.to_string_lossy().into_owned())
86        }
87    }
88
89    /// Set the label.
90    ///
91    /// C++ equivalent: `void setLabel(const NS::String* label)`
92    pub fn set_label(&self, label: &str) {
93        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
94            unsafe {
95                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
96            }
97        }
98    }
99
100    /// Create a new argument encoder for a nested buffer at the given index.
101    ///
102    /// C++ equivalent: `ArgumentEncoder* newArgumentEncoder(NS::UInteger index)`
103    pub fn new_argument_encoder(&self, index: UInteger) -> Option<ArgumentEncoder> {
104        unsafe {
105            let ptr: *mut c_void = msg_send_1(
106                self.as_ptr(),
107                sel!(newArgumentEncoderForBufferAtIndex:),
108                index,
109            );
110            ArgumentEncoder::from_raw(ptr)
111        }
112    }
113
114    /// Set the argument buffer to encode into.
115    ///
116    /// C++ equivalent: `void setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger offset)`
117    pub fn set_argument_buffer(&self, buffer: &Buffer, offset: UInteger) {
118        unsafe {
119            let _: () = msg_send_2(
120                self.as_ptr(),
121                sel!(setArgumentBuffer:offset:),
122                buffer.as_ptr(),
123                offset,
124            );
125        }
126    }
127
128    /// Set the argument buffer with array element.
129    ///
130    /// C++ equivalent: `void setArgumentBuffer(const MTL::Buffer*, NS::UInteger, NS::UInteger)`
131    pub fn set_argument_buffer_with_array_element(
132        &self,
133        buffer: &Buffer,
134        start_offset: UInteger,
135        array_element: UInteger,
136    ) {
137        unsafe {
138            let _: () = msg_send_3(
139                self.as_ptr(),
140                sel!(setArgumentBuffer:startOffset:arrayElement:),
141                buffer.as_ptr(),
142                start_offset,
143                array_element,
144            );
145        }
146    }
147
148    /// Set a buffer at the given index.
149    ///
150    /// C++ equivalent: `void setBuffer(const MTL::Buffer*, NS::UInteger, NS::UInteger)`
151    pub fn set_buffer(&self, buffer: &Buffer, offset: UInteger, index: UInteger) {
152        unsafe {
153            let _: () = msg_send_3(
154                self.as_ptr(),
155                sel!(setBuffer:offset:atIndex:),
156                buffer.as_ptr(),
157                offset,
158                index,
159            );
160        }
161    }
162
163    /// Set a texture at the given index.
164    ///
165    /// C++ equivalent: `void setTexture(const MTL::Texture*, NS::UInteger)`
166    pub fn set_texture(&self, texture: &crate::Texture, index: UInteger) {
167        unsafe {
168            let _: () = msg_send_2(
169                self.as_ptr(),
170                sel!(setTexture:atIndex:),
171                texture.as_ptr(),
172                index,
173            );
174        }
175    }
176
177    /// Set a sampler state at the given index.
178    ///
179    /// C++ equivalent: `void setSamplerState(const MTL::SamplerState*, NS::UInteger)`
180    pub fn set_sampler_state(&self, sampler: &crate::SamplerState, index: UInteger) {
181        unsafe {
182            let _: () = msg_send_2(
183                self.as_ptr(),
184                sel!(setSamplerState:atIndex:),
185                sampler.as_ptr(),
186                index,
187            );
188        }
189    }
190
191    /// Set a render pipeline state at the given index.
192    ///
193    /// C++ equivalent: `void setRenderPipelineState(const MTL::RenderPipelineState*, NS::UInteger)`
194    pub fn set_render_pipeline_state(
195        &self,
196        pipeline: &crate::RenderPipelineState,
197        index: UInteger,
198    ) {
199        unsafe {
200            let _: () = msg_send_2(
201                self.as_ptr(),
202                sel!(setRenderPipelineState:atIndex:),
203                pipeline.as_ptr(),
204                index,
205            );
206        }
207    }
208
209    /// Set a compute pipeline state at the given index.
210    ///
211    /// C++ equivalent: `void setComputePipelineState(const MTL::ComputePipelineState*, NS::UInteger)`
212    pub fn set_compute_pipeline_state(
213        &self,
214        pipeline: &crate::ComputePipelineState,
215        index: UInteger,
216    ) {
217        unsafe {
218            let _: () = msg_send_2(
219                self.as_ptr(),
220                sel!(setComputePipelineState:atIndex:),
221                pipeline.as_ptr(),
222                index,
223            );
224        }
225    }
226
227    /// Set a depth stencil state at the given index.
228    ///
229    /// C++ equivalent: `void setDepthStencilState(const MTL::DepthStencilState*, NS::UInteger)`
230    pub fn set_depth_stencil_state(&self, state: &crate::DepthStencilState, index: UInteger) {
231        unsafe {
232            let _: () = msg_send_2(
233                self.as_ptr(),
234                sel!(setDepthStencilState:atIndex:),
235                state.as_ptr(),
236                index,
237            );
238        }
239    }
240
241    /// Set an acceleration structure at the given index.
242    ///
243    /// C++ equivalent: `void setAccelerationStructure(const MTL::AccelerationStructure*, NS::UInteger)`
244    pub fn set_acceleration_structure(
245        &self,
246        acceleration_structure: &crate::AccelerationStructure,
247        index: UInteger,
248    ) {
249        unsafe {
250            let _: () = msg_send_2(
251                self.as_ptr(),
252                sel!(setAccelerationStructure:atIndex:),
253                acceleration_structure.as_ptr(),
254                index,
255            );
256        }
257    }
258
259    /// Set an indirect command buffer at the given index.
260    ///
261    /// C++ equivalent: `void setIndirectCommandBuffer(const MTL::IndirectCommandBuffer*, NS::UInteger)`
262    pub fn set_indirect_command_buffer_ptr(&self, buffer: *const c_void, index: UInteger) {
263        unsafe {
264            let _: () = msg_send_2(
265                self.as_ptr(),
266                sel!(setIndirectCommandBuffer:atIndex:),
267                buffer,
268                index,
269            );
270        }
271    }
272
273    /// Set a visible function table at the given index.
274    ///
275    /// C++ equivalent: `void setVisibleFunctionTable(const MTL::VisibleFunctionTable*, NS::UInteger)`
276    pub fn set_visible_function_table_ptr(&self, table: *const c_void, index: UInteger) {
277        unsafe {
278            let _: () = msg_send_2(
279                self.as_ptr(),
280                sel!(setVisibleFunctionTable:atIndex:),
281                table,
282                index,
283            );
284        }
285    }
286
287    /// Set an intersection function table at the given index.
288    ///
289    /// C++ equivalent: `void setIntersectionFunctionTable(const MTL::IntersectionFunctionTable*, NS::UInteger)`
290    pub fn set_intersection_function_table_ptr(&self, table: *const c_void, index: UInteger) {
291        unsafe {
292            let _: () = msg_send_2(
293                self.as_ptr(),
294                sel!(setIntersectionFunctionTable:atIndex:),
295                table,
296                index,
297            );
298        }
299    }
300}
301
302impl Clone for ArgumentEncoder {
303    fn clone(&self) -> Self {
304        unsafe {
305            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
306        }
307        Self(self.0)
308    }
309}
310
311impl Drop for ArgumentEncoder {
312    fn drop(&mut self) {
313        unsafe {
314            msg_send_0::<()>(self.as_ptr(), sel!(release));
315        }
316    }
317}
318
319impl Referencing for ArgumentEncoder {
320    #[inline]
321    fn as_ptr(&self) -> *const c_void {
322        self.0.as_ptr()
323    }
324}
325
326unsafe impl Send for ArgumentEncoder {}
327unsafe impl Sync for ArgumentEncoder {}