Skip to main content

mtl_gpu/device/
indirect.rs

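//! Indirect command buffer and other miscellaneous `Device` factory methods.
//!
//! Corresponds to the `MTLDevice` creation methods in `Metal/MTLDevice.hpp` for indirect
//! command buffers, rasterization rate maps, counter sample buffers and counter sets,
//! argument encoders, log states, texture view pools, and tensors.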
4
5use std::ffi::c_void;
6
7use mtl_foundation::{Referencing, UInteger};
8use mtl_sys::{msg_send_1, msg_send_2, msg_send_3, sel};
9
10use super::Device;
11use crate::argument::{ArgumentEncoder, BufferBinding};
12use crate::counter::{CounterSampleBuffer, CounterSampleBufferDescriptor, CounterSet};
13use crate::enums::ResourceOptions;
14use crate::indirect::{IndirectCommandBuffer, IndirectCommandBufferDescriptor};
15use crate::log_state::{LogState, LogStateDescriptor};
16use crate::rasterization_rate::{RasterizationRateMap, RasterizationRateMapDescriptor};
17use crate::tensor::{Tensor, TensorDescriptor};
18use crate::texture_view_pool::{ResourceViewPoolDescriptor, TextureViewPool};
19
20impl Device {
21    // =========================================================================
22    // Indirect Command Buffer
23    // =========================================================================
24
25    /// Create a new indirect command buffer.
26    ///
27    /// C++ equivalent: `IndirectCommandBuffer* newIndirectCommandBuffer(const IndirectCommandBufferDescriptor*, NS::UInteger, ResourceOptions)`
28    pub fn new_indirect_command_buffer(
29        &self,
30        descriptor: &IndirectCommandBufferDescriptor,
31        max_count: UInteger,
32        options: ResourceOptions,
33    ) -> Option<IndirectCommandBuffer> {
34        unsafe {
35            let ptr: *mut c_void = msg_send_3(
36                self.as_ptr(),
37                sel!(newIndirectCommandBufferWithDescriptor:maxCommandCount:options:),
38                descriptor.as_ptr(),
39                max_count,
40                options,
41            );
42            if ptr.is_null() {
43                None
44            } else {
45                IndirectCommandBuffer::from_raw(ptr)
46            }
47        }
48    }
49
50    // =========================================================================
51    // Rasterization Rate Map
52    // =========================================================================
53
54    /// Create a new rasterization rate map.
55    ///
56    /// C++ equivalent: `RasterizationRateMap* newRasterizationRateMap(const RasterizationRateMapDescriptor*)`
57    pub fn new_rasterization_rate_map(
58        &self,
59        descriptor: &RasterizationRateMapDescriptor,
60    ) -> Option<RasterizationRateMap> {
61        unsafe {
62            let ptr: *mut c_void = msg_send_1(
63                self.as_ptr(),
64                sel!(newRasterizationRateMapWithDescriptor:),
65                descriptor.as_ptr(),
66            );
67            if ptr.is_null() {
68                None
69            } else {
70                RasterizationRateMap::from_raw(ptr)
71            }
72        }
73    }
74
75    // =========================================================================
76    // Counter Sample Buffer
77    // =========================================================================
78
79    /// Create a new counter sample buffer.
80    ///
81    /// C++ equivalent: `CounterSampleBuffer* newCounterSampleBuffer(const CounterSampleBufferDescriptor*, NS::Error**)`
82    pub fn new_counter_sample_buffer(
83        &self,
84        descriptor: &CounterSampleBufferDescriptor,
85    ) -> Result<CounterSampleBuffer, mtl_foundation::Error> {
86        unsafe {
87            let mut error: *mut c_void = std::ptr::null_mut();
88            let ptr: *mut c_void = msg_send_2(
89                self.as_ptr(),
90                sel!(newCounterSampleBufferWithDescriptor:error:),
91                descriptor.as_ptr(),
92                &mut error as *mut _,
93            );
94
95            if ptr.is_null() {
96                if !error.is_null() {
97                    return Err(
98                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
99                    );
100                }
101                return Err(mtl_foundation::Error::error(
102                    std::ptr::null_mut(),
103                    -1,
104                    std::ptr::null_mut(),
105                )
106                .expect("failed to create error"));
107            }
108
109            Ok(CounterSampleBuffer::from_raw(ptr).expect("failed to create counter sample buffer"))
110        }
111    }
112
113    /// Get the available counter sets for this device.
114    ///
115    /// Returns a raw pointer to an NSArray of CounterSet objects.
116    ///
117    /// C++ equivalent: `NS::Array* counterSets() const`
118    #[inline]
119    pub fn counter_sets_raw(&self) -> *mut c_void {
120        unsafe { mtl_sys::msg_send_0(self.as_ptr(), sel!(counterSets)) }
121    }
122
123    /// Get the number of counter sets available.
124    pub fn counter_set_count(&self) -> UInteger {
125        unsafe {
126            let array = self.counter_sets_raw();
127            if array.is_null() {
128                return 0;
129            }
130            mtl_sys::msg_send_0(array as *const c_void, sel!(count))
131        }
132    }
133
134    /// Get a counter set at the specified index.
135    pub fn counter_set_at_index(&self, index: UInteger) -> Option<CounterSet> {
136        unsafe {
137            let array = self.counter_sets_raw();
138            if array.is_null() {
139                return None;
140            }
141            let ptr: *mut c_void = msg_send_1(array as *const c_void, sel!(objectAtIndex:), index);
142            if ptr.is_null() {
143                return None;
144            }
145            mtl_sys::msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
146            CounterSet::from_raw(ptr)
147        }
148    }
149
150    // =========================================================================
151    // Argument Encoder
152    // =========================================================================
153
154    /// Create a new argument encoder from an array of arguments.
155    ///
156    /// C++ equivalent: `ArgumentEncoder* newArgumentEncoder(const NS::Array* arguments)`
157    ///
158    /// # Safety
159    ///
160    /// The arguments pointer must be a valid NSArray of Argument objects.
161    pub unsafe fn new_argument_encoder_with_arguments(
162        &self,
163        arguments: *const c_void,
164    ) -> Option<ArgumentEncoder> {
165        unsafe {
166            let ptr: *mut c_void = msg_send_1(
167                self.as_ptr(),
168                sel!(newArgumentEncoderWithArguments:),
169                arguments,
170            );
171            if ptr.is_null() {
172                None
173            } else {
174                ArgumentEncoder::from_raw(ptr)
175            }
176        }
177    }
178
179    /// Create a new argument encoder from a buffer binding.
180    ///
181    /// C++ equivalent: `ArgumentEncoder* newArgumentEncoder(const BufferBinding*)`
182    pub fn new_argument_encoder_with_buffer_binding(
183        &self,
184        buffer_binding: &BufferBinding,
185    ) -> Option<ArgumentEncoder> {
186        unsafe {
187            let ptr: *mut c_void = msg_send_1(
188                self.as_ptr(),
189                sel!(newArgumentEncoderWithBufferBinding:),
190                buffer_binding.as_ptr(),
191            );
192            if ptr.is_null() {
193                None
194            } else {
195                ArgumentEncoder::from_raw(ptr)
196            }
197        }
198    }
199
200    // =========================================================================
201    // Log State
202    // =========================================================================
203
204    /// Create a new log state.
205    ///
206    /// C++ equivalent: `LogState* newLogState(const LogStateDescriptor*, NS::Error**)`
207    pub fn new_log_state(
208        &self,
209        descriptor: &LogStateDescriptor,
210    ) -> Result<LogState, mtl_foundation::Error> {
211        unsafe {
212            let mut error: *mut c_void = std::ptr::null_mut();
213            let ptr: *mut c_void = msg_send_2(
214                self.as_ptr(),
215                sel!(newLogStateWithDescriptor:error:),
216                descriptor.as_ptr(),
217                &mut error as *mut _,
218            );
219
220            if ptr.is_null() {
221                if !error.is_null() {
222                    return Err(
223                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
224                    );
225                }
226                return Err(mtl_foundation::Error::error(
227                    std::ptr::null_mut(),
228                    -1,
229                    std::ptr::null_mut(),
230                )
231                .expect("failed to create error"));
232            }
233
234            Ok(LogState::from_raw(ptr).expect("failed to create log state"))
235        }
236    }
237
238    // =========================================================================
239    // Texture View Pool
240    // =========================================================================
241
242    /// Create a new texture view pool.
243    ///
244    /// C++ equivalent: `TextureViewPool* newTextureViewPool(const ResourceViewPoolDescriptor*, NS::Error**)`
245    pub fn new_texture_view_pool(
246        &self,
247        descriptor: &ResourceViewPoolDescriptor,
248    ) -> Result<TextureViewPool, mtl_foundation::Error> {
249        unsafe {
250            let mut error: *mut c_void = std::ptr::null_mut();
251            let ptr: *mut c_void = msg_send_2(
252                self.as_ptr(),
253                sel!(newTextureViewPoolWithDescriptor:error:),
254                descriptor.as_ptr(),
255                &mut error as *mut _,
256            );
257
258            if ptr.is_null() {
259                if !error.is_null() {
260                    return Err(
261                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
262                    );
263                }
264                return Err(mtl_foundation::Error::error(
265                    std::ptr::null_mut(),
266                    -1,
267                    std::ptr::null_mut(),
268                )
269                .expect("failed to create error"));
270            }
271
272            Ok(TextureViewPool::from_raw(ptr).expect("failed to create texture view pool"))
273        }
274    }
275
276    // =========================================================================
277    // Tensor
278    // =========================================================================
279
280    /// Create a new tensor.
281    ///
282    /// C++ equivalent: `Tensor* newTensor(const TensorDescriptor*, NS::Error**)`
283    pub fn new_tensor(
284        &self,
285        descriptor: &TensorDescriptor,
286    ) -> Result<Tensor, mtl_foundation::Error> {
287        unsafe {
288            let mut error: *mut c_void = std::ptr::null_mut();
289            let ptr: *mut c_void = msg_send_2(
290                self.as_ptr(),
291                sel!(newTensorWithDescriptor:error:),
292                descriptor.as_ptr(),
293                &mut error as *mut _,
294            );
295
296            if ptr.is_null() {
297                if !error.is_null() {
298                    return Err(
299                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
300                    );
301                }
302                return Err(mtl_foundation::Error::error(
303                    std::ptr::null_mut(),
304                    -1,
305                    std::ptr::null_mut(),
306                )
307                .expect("failed to create error"));
308            }
309
310            Ok(Tensor::from_raw(ptr).expect("failed to create tensor"))
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    /// Compile-only smoke test for this module.
    ///
    /// The body is intentionally empty. Exercising any of the `Device` factory
    /// methods above would require real Metal hardware, so this test only
    /// confirms that the module builds in a test configuration.
    #[test]
    fn test_device_indirect_methods_exist() {}
}