Skip to main content

mtl_gpu/device/
mtl4.rs

1//! MTL4-related Device methods.
2//!
3//! Corresponds to Metal 4 factory methods in `Metal/MTLDevice.hpp`.
4
5use std::ffi::c_void;
6
7use mtl_foundation::Referencing;
8use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
9
10use super::Device;
11use crate::function_table::FunctionHandle;
12use crate::mtl4::{
13    Archive, ArgumentTable, ArgumentTableDescriptor, CommandAllocator, CommandAllocatorDescriptor,
14    CommandQueue, CommandQueueDescriptor, Compiler, CompilerDescriptor, CounterHeap,
15    CounterHeapDescriptor, CounterHeapType, PipelineDataSetSerializer,
16    PipelineDataSetSerializerDescriptor,
17};
18
19impl Device {
20    // =========================================================================
21    // Command Allocator
22    // =========================================================================
23
24    /// Create a new MTL4 command allocator with default settings.
25    ///
26    /// C++ equivalent: `MTL4::CommandAllocator* newCommandAllocator()`
27    pub fn new_command_allocator(&self) -> Option<CommandAllocator> {
28        unsafe {
29            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(newCommandAllocator));
30            if ptr.is_null() {
31                None
32            } else {
33                CommandAllocator::from_raw(ptr)
34            }
35        }
36    }
37
38    /// Create a new MTL4 command allocator with a descriptor.
39    ///
40    /// C++ equivalent: `MTL4::CommandAllocator* newCommandAllocator(const MTL4::CommandAllocatorDescriptor*, NS::Error**)`
41    pub fn new_command_allocator_with_descriptor(
42        &self,
43        descriptor: &CommandAllocatorDescriptor,
44    ) -> Result<CommandAllocator, mtl_foundation::Error> {
45        unsafe {
46            let mut error: *mut c_void = std::ptr::null_mut();
47            let ptr: *mut c_void = msg_send_2(
48                self.as_ptr(),
49                sel!(newCommandAllocatorWithDescriptor:error:),
50                descriptor.as_ptr(),
51                &mut error as *mut _,
52            );
53
54            if ptr.is_null() {
55                if !error.is_null() {
56                    return Err(
57                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
58                    );
59                }
60                return Err(mtl_foundation::Error::error(
61                    std::ptr::null_mut(),
62                    -1,
63                    std::ptr::null_mut(),
64                )
65                .expect("failed to create error"));
66            }
67
68            Ok(CommandAllocator::from_raw(ptr).expect("failed to create command allocator"))
69        }
70    }
71
72    // =========================================================================
73    // Command Queue
74    // =========================================================================
75
76    /// Create a new MTL4 command queue with default settings.
77    ///
78    /// C++ equivalent: `MTL4::CommandQueue* newMTL4CommandQueue()`
79    pub fn new_mtl4_command_queue(&self) -> Option<CommandQueue> {
80        unsafe {
81            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(newMTL4CommandQueue));
82            if ptr.is_null() {
83                None
84            } else {
85                CommandQueue::from_raw(ptr)
86            }
87        }
88    }
89
90    /// Create a new MTL4 command queue with a descriptor.
91    ///
92    /// C++ equivalent: `MTL4::CommandQueue* newMTL4CommandQueue(const MTL4::CommandQueueDescriptor*, NS::Error**)`
93    pub fn new_mtl4_command_queue_with_descriptor(
94        &self,
95        descriptor: &CommandQueueDescriptor,
96    ) -> Result<CommandQueue, mtl_foundation::Error> {
97        unsafe {
98            let mut error: *mut c_void = std::ptr::null_mut();
99            let ptr: *mut c_void = msg_send_2(
100                self.as_ptr(),
101                sel!(newMTL4CommandQueueWithDescriptor:error:),
102                descriptor.as_ptr(),
103                &mut error as *mut _,
104            );
105
106            if ptr.is_null() {
107                if !error.is_null() {
108                    return Err(
109                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
110                    );
111                }
112                return Err(mtl_foundation::Error::error(
113                    std::ptr::null_mut(),
114                    -1,
115                    std::ptr::null_mut(),
116                )
117                .expect("failed to create error"));
118            }
119
120            Ok(CommandQueue::from_raw(ptr).expect("failed to create command queue"))
121        }
122    }
123
124    // =========================================================================
125    // Argument Table
126    // =========================================================================
127
128    /// Create a new MTL4 argument table.
129    ///
130    /// C++ equivalent: `MTL4::ArgumentTable* newArgumentTable(const MTL4::ArgumentTableDescriptor*, NS::Error**)`
131    pub fn new_argument_table(
132        &self,
133        descriptor: &ArgumentTableDescriptor,
134    ) -> Result<ArgumentTable, mtl_foundation::Error> {
135        unsafe {
136            let mut error: *mut c_void = std::ptr::null_mut();
137            let ptr: *mut c_void = msg_send_2(
138                self.as_ptr(),
139                sel!(newArgumentTableWithDescriptor:error:),
140                descriptor.as_ptr(),
141                &mut error as *mut _,
142            );
143
144            if ptr.is_null() {
145                if !error.is_null() {
146                    return Err(
147                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
148                    );
149                }
150                return Err(mtl_foundation::Error::error(
151                    std::ptr::null_mut(),
152                    -1,
153                    std::ptr::null_mut(),
154                )
155                .expect("failed to create error"));
156            }
157
158            Ok(ArgumentTable::from_raw(ptr).expect("failed to create argument table"))
159        }
160    }
161
162    // =========================================================================
163    // Compiler
164    // =========================================================================
165
166    /// Create a new MTL4 compiler.
167    ///
168    /// C++ equivalent: `MTL4::Compiler* newCompiler(const MTL4::CompilerDescriptor*, NS::Error**)`
169    pub fn new_compiler(
170        &self,
171        descriptor: &CompilerDescriptor,
172    ) -> Result<Compiler, mtl_foundation::Error> {
173        unsafe {
174            let mut error: *mut c_void = std::ptr::null_mut();
175            let ptr: *mut c_void = msg_send_2(
176                self.as_ptr(),
177                sel!(newCompilerWithDescriptor:error:),
178                descriptor.as_ptr(),
179                &mut error as *mut _,
180            );
181
182            if ptr.is_null() {
183                if !error.is_null() {
184                    return Err(
185                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
186                    );
187                }
188                return Err(mtl_foundation::Error::error(
189                    std::ptr::null_mut(),
190                    -1,
191                    std::ptr::null_mut(),
192                )
193                .expect("failed to create error"));
194            }
195
196            Ok(Compiler::from_raw(ptr).expect("failed to create compiler"))
197        }
198    }
199
200    // =========================================================================
201    // Counter Heap
202    // =========================================================================
203
204    /// Create a new MTL4 counter heap.
205    ///
206    /// C++ equivalent: `MTL4::CounterHeap* newCounterHeap(const MTL4::CounterHeapDescriptor*, NS::Error**)`
207    pub fn new_counter_heap(
208        &self,
209        descriptor: &CounterHeapDescriptor,
210    ) -> Result<CounterHeap, mtl_foundation::Error> {
211        unsafe {
212            let mut error: *mut c_void = std::ptr::null_mut();
213            let ptr: *mut c_void = msg_send_2(
214                self.as_ptr(),
215                sel!(newCounterHeapWithDescriptor:error:),
216                descriptor.as_ptr(),
217                &mut error as *mut _,
218            );
219
220            if ptr.is_null() {
221                if !error.is_null() {
222                    return Err(
223                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
224                    );
225                }
226                return Err(mtl_foundation::Error::error(
227                    std::ptr::null_mut(),
228                    -1,
229                    std::ptr::null_mut(),
230                )
231                .expect("failed to create error"));
232            }
233
234            Ok(CounterHeap::from_raw(ptr).expect("failed to create counter heap"))
235        }
236    }
237
238    /// Get the size of a counter heap entry.
239    ///
240    /// C++ equivalent: `NS::UInteger sizeOfCounterHeapEntry(MTL4::CounterHeapType type)`
241    pub fn size_of_counter_heap_entry(
242        &self,
243        heap_type: CounterHeapType,
244    ) -> mtl_foundation::UInteger {
245        unsafe { msg_send_1(self.as_ptr(), sel!(sizeOfCounterHeapEntry:), heap_type) }
246    }
247
248    // =========================================================================
249    // Pipeline Data Set Serializer
250    // =========================================================================
251
252    /// Create a new MTL4 pipeline data set serializer.
253    ///
254    /// C++ equivalent: `MTL4::PipelineDataSetSerializer* newPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializerDescriptor*)`
255    pub fn new_pipeline_data_set_serializer(
256        &self,
257        descriptor: &PipelineDataSetSerializerDescriptor,
258    ) -> Option<PipelineDataSetSerializer> {
259        unsafe {
260            let ptr: *mut c_void = msg_send_1(
261                self.as_ptr(),
262                sel!(newPipelineDataSetSerializerWithDescriptor:),
263                descriptor.as_ptr(),
264            );
265            if ptr.is_null() {
266                None
267            } else {
268                PipelineDataSetSerializer::from_raw(ptr)
269            }
270        }
271    }
272
273    // =========================================================================
274    // Archive
275    // =========================================================================
276
277    /// Create a new MTL4 archive from a URL.
278    ///
279    /// C++ equivalent: `MTL4::Archive* newArchive(const NS::URL*, NS::Error**)`
280    ///
281    /// # Safety
282    ///
283    /// The url pointer must be a valid NS::URL object.
284    pub unsafe fn new_archive_with_url(
285        &self,
286        url: *const c_void,
287    ) -> Result<Archive, mtl_foundation::Error> {
288        unsafe {
289            let mut error: *mut c_void = std::ptr::null_mut();
290            let ptr: *mut c_void = msg_send_2(
291                self.as_ptr(),
292                sel!(newArchiveWithURL:error:),
293                url,
294                &mut error as *mut _,
295            );
296
297            if ptr.is_null() {
298                if !error.is_null() {
299                    return Err(
300                        mtl_foundation::Error::from_ptr(error).expect("error should be valid")
301                    );
302                }
303                return Err(mtl_foundation::Error::error(
304                    std::ptr::null_mut(),
305                    -1,
306                    std::ptr::null_mut(),
307                )
308                .expect("failed to create error"));
309            }
310
311            Ok(Archive::from_raw(ptr).expect("failed to create archive"))
312        }
313    }
314
315    // =========================================================================
316    // Function Handles
317    // =========================================================================
318
319    /// Get a function handle from a compiled function.
320    ///
321    /// C++ equivalent: `FunctionHandle* functionHandle(const MTL::Function*)`
322    pub fn function_handle(&self, function: &crate::Function) -> Option<FunctionHandle> {
323        unsafe {
324            let ptr: *mut c_void = msg_send_1(
325                self.as_ptr(),
326                sel!(functionHandleWithFunction:),
327                function.as_ptr(),
328            );
329            if ptr.is_null() {
330                None
331            } else {
332                let _: *mut c_void = msg_send_0(ptr, sel!(retain));
333                FunctionHandle::from_raw(ptr)
334            }
335        }
336    }
337
338    /// Get a function handle from a binary function.
339    ///
340    /// C++ equivalent: `FunctionHandle* functionHandle(const MTL4::BinaryFunction*)`
341    pub fn function_handle_with_binary_function(
342        &self,
343        function: &crate::mtl4::BinaryFunction,
344    ) -> Option<FunctionHandle> {
345        unsafe {
346            let ptr: *mut c_void = msg_send_1(
347                self.as_ptr(),
348                sel!(functionHandleWithFunction:),
349                function.as_ptr(),
350            );
351            if ptr.is_null() {
352                None
353            } else {
354                let _: *mut c_void = msg_send_0(ptr, sel!(retain));
355                FunctionHandle::from_raw(ptr)
356            }
357        }
358    }
359}
360
#[cfg(test)]
mod tests {
    /// Placeholder with an intentionally empty body.
    ///
    /// It does not call or reference any `Device` method. It only shows that this module builds
    /// under `cfg(test)`. Exercising the MTL4 factory methods requires a Metal 4-capable GPU.
    #[test]
    fn test_mtl4_device_methods_exist() {}
}