Skip to main content

mtl_gpu/mtl4/
machine_learning.rs

1//! MTL4 Machine Learning implementation.
2//!
3//! Corresponds to `Metal/MTL4MachineLearningPipeline.hpp` and
4//! `Metal/MTL4MachineLearningCommandEncoder.hpp`.
5
6use std::ffi::c_void;
7use std::ptr::NonNull;
8
9use mtl_foundation::{Integer, Referencing, UInteger};
10use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
11
12use super::{ArgumentTable, FunctionDescriptor};
13use crate::{Device, Heap};
14
15// ============================================================
16// MachineLearningPipelineDescriptor
17// ============================================================
18
19/// Descriptor for creating a machine learning pipeline.
20///
21/// C++ equivalent: `MTL4::MachineLearningPipelineDescriptor`
22#[repr(transparent)]
23pub struct MachineLearningPipelineDescriptor(NonNull<c_void>);
24
25impl MachineLearningPipelineDescriptor {
26    /// Create a MachineLearningPipelineDescriptor from a raw pointer.
27    #[inline]
28    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
29        NonNull::new(ptr).map(Self)
30    }
31
32    /// Get the raw pointer.
33    #[inline]
34    pub fn as_raw(&self) -> *mut c_void {
35        self.0.as_ptr()
36    }
37
38    /// Create a new machine learning pipeline descriptor.
39    pub fn new() -> Option<Self> {
40        unsafe {
41            let class = mtl_sys::Class::get("MTL4MachineLearningPipelineDescriptor")?;
42            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
43            if ptr.is_null() {
44                return None;
45            }
46            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
47            Self::from_raw(ptr)
48        }
49    }
50
51    /// Get the label.
52    ///
53    /// C++ equivalent: `NS::String* label() const`
54    pub fn label(&self) -> Option<String> {
55        unsafe {
56            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
57            if ns_string.is_null() {
58                return None;
59            }
60            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
61            if c_str.is_null() {
62                return None;
63            }
64            Some(
65                std::ffi::CStr::from_ptr(c_str)
66                    .to_string_lossy()
67                    .into_owned(),
68            )
69        }
70    }
71
72    /// Set the label.
73    ///
74    /// C++ equivalent: `void setLabel(const NS::String*)`
75    pub fn set_label(&self, label: &str) {
76        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
77            unsafe {
78                let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
79            }
80        }
81    }
82
83    /// Get the machine learning function descriptor.
84    ///
85    /// C++ equivalent: `FunctionDescriptor* machineLearningFunctionDescriptor() const`
86    pub fn machine_learning_function_descriptor(&self) -> Option<FunctionDescriptor> {
87        unsafe {
88            let ptr: *mut c_void =
89                msg_send_0(self.as_ptr(), sel!(machineLearningFunctionDescriptor));
90            FunctionDescriptor::from_raw(ptr)
91        }
92    }
93
94    /// Set the machine learning function descriptor.
95    ///
96    /// C++ equivalent: `void setMachineLearningFunctionDescriptor(const MTL4::FunctionDescriptor*)`
97    pub fn set_machine_learning_function_descriptor(&self, descriptor: &FunctionDescriptor) {
98        unsafe {
99            let _: () = msg_send_1(
100                self.as_ptr(),
101                sel!(setMachineLearningFunctionDescriptor:),
102                descriptor.as_ptr(),
103            );
104        }
105    }
106
107    /// Get input dimensions at buffer index (as raw pointer to TensorExtents).
108    ///
109    /// C++ equivalent: `MTL::TensorExtents* inputDimensionsAtBufferIndex(NS::Integer)`
110    pub fn input_dimensions_at_buffer_index_raw(&self, buffer_index: Integer) -> *mut c_void {
111        unsafe {
112            msg_send_1(
113                self.as_ptr(),
114                sel!(inputDimensionsAtBufferIndex:),
115                buffer_index,
116            )
117        }
118    }
119
120    /// Set input dimensions at buffer index (from raw pointer to TensorExtents).
121    ///
122    /// C++ equivalent: `void setInputDimensions(const MTL::TensorExtents*, NS::Integer)`
123    pub fn set_input_dimensions_raw(&self, dimensions: *const c_void, buffer_index: Integer) {
124        unsafe {
125            let _: () = msg_send_2(
126                self.as_ptr(),
127                sel!(setInputDimensions:atBufferIndex:),
128                dimensions,
129                buffer_index,
130            );
131        }
132    }
133
134    /// Reset the descriptor to its default state.
135    ///
136    /// C++ equivalent: `void reset()`
137    pub fn reset(&self) {
138        unsafe {
139            let _: () = msg_send_0(self.as_ptr(), sel!(reset));
140        }
141    }
142}
143
144impl Default for MachineLearningPipelineDescriptor {
145    fn default() -> Self {
146        Self::new().expect("Failed to create MTL4MachineLearningPipelineDescriptor")
147    }
148}
149
150impl Clone for MachineLearningPipelineDescriptor {
151    fn clone(&self) -> Self {
152        unsafe {
153            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
154        }
155        Self(self.0)
156    }
157}
158
159impl Drop for MachineLearningPipelineDescriptor {
160    fn drop(&mut self) {
161        unsafe {
162            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
163        }
164    }
165}
166
167impl Referencing for MachineLearningPipelineDescriptor {
168    #[inline]
169    fn as_ptr(&self) -> *const c_void {
170        self.0.as_ptr()
171    }
172}
173
// NOTE(review): these impls assert that the wrapped Objective-C object may be
// moved to and shared between threads. Nothing in this file enforces that, and
// the setters take `&self`, so `Sync` permits concurrent mutation through
// shared references — confirm against Metal's threading guarantees.
unsafe impl Send for MachineLearningPipelineDescriptor {}
unsafe impl Sync for MachineLearningPipelineDescriptor {}
176
177impl std::fmt::Debug for MachineLearningPipelineDescriptor {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        f.debug_struct("MachineLearningPipelineDescriptor")
180            .field("label", &self.label())
181            .finish()
182    }
183}
184
185// ============================================================
186// MachineLearningPipelineReflection
187// ============================================================
188
189/// Reflection data for a machine learning pipeline.
190///
191/// C++ equivalent: `MTL4::MachineLearningPipelineReflection`
192#[repr(transparent)]
193pub struct MachineLearningPipelineReflection(NonNull<c_void>);
194
195impl MachineLearningPipelineReflection {
196    /// Create a MachineLearningPipelineReflection from a raw pointer.
197    #[inline]
198    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
199        NonNull::new(ptr).map(Self)
200    }
201
202    /// Get the raw pointer.
203    #[inline]
204    pub fn as_raw(&self) -> *mut c_void {
205        self.0.as_ptr()
206    }
207
208    /// Get the bindings array (as raw pointer to NSArray).
209    ///
210    /// C++ equivalent: `NS::Array* bindings() const`
211    pub fn bindings_raw(&self) -> *mut c_void {
212        unsafe { msg_send_0(self.as_ptr(), sel!(bindings)) }
213    }
214}
215
216impl Clone for MachineLearningPipelineReflection {
217    fn clone(&self) -> Self {
218        unsafe {
219            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
220        }
221        Self(self.0)
222    }
223}
224
225impl Drop for MachineLearningPipelineReflection {
226    fn drop(&mut self) {
227        unsafe {
228            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
229        }
230    }
231}
232
233impl Referencing for MachineLearningPipelineReflection {
234    #[inline]
235    fn as_ptr(&self) -> *const c_void {
236        self.0.as_ptr()
237    }
238}
239
// NOTE(review): these impls assert that the wrapped Objective-C object may be
// moved to and shared between threads. Nothing in this file enforces that —
// confirm against Metal's threading guarantees for reflection objects.
unsafe impl Send for MachineLearningPipelineReflection {}
unsafe impl Sync for MachineLearningPipelineReflection {}
242
243impl std::fmt::Debug for MachineLearningPipelineReflection {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        f.debug_struct("MachineLearningPipelineReflection").finish()
246    }
247}
248
249// ============================================================
250// MachineLearningPipelineState
251// ============================================================
252
253/// A compiled machine learning pipeline state.
254///
255/// C++ equivalent: `MTL4::MachineLearningPipelineState`
256#[repr(transparent)]
257pub struct MachineLearningPipelineState(NonNull<c_void>);
258
259impl MachineLearningPipelineState {
260    /// Create a MachineLearningPipelineState from a raw pointer.
261    #[inline]
262    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
263        NonNull::new(ptr).map(Self)
264    }
265
266    /// Get the raw pointer.
267    #[inline]
268    pub fn as_raw(&self) -> *mut c_void {
269        self.0.as_ptr()
270    }
271
272    /// Get the device.
273    ///
274    /// C++ equivalent: `MTL::Device* device() const`
275    pub fn device(&self) -> Option<Device> {
276        unsafe {
277            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
278            Device::from_raw(ptr)
279        }
280    }
281
282    /// Get the label.
283    ///
284    /// C++ equivalent: `NS::String* label() const`
285    pub fn label(&self) -> Option<String> {
286        unsafe {
287            let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
288            if ns_string.is_null() {
289                return None;
290            }
291            let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
292            if c_str.is_null() {
293                return None;
294            }
295            Some(
296                std::ffi::CStr::from_ptr(c_str)
297                    .to_string_lossy()
298                    .into_owned(),
299            )
300        }
301    }
302
303    /// Get the intermediates heap size.
304    ///
305    /// C++ equivalent: `NS::UInteger intermediatesHeapSize() const`
306    pub fn intermediates_heap_size(&self) -> UInteger {
307        unsafe { msg_send_0(self.as_ptr(), sel!(intermediatesHeapSize)) }
308    }
309
310    /// Get the pipeline reflection.
311    ///
312    /// C++ equivalent: `MachineLearningPipelineReflection* reflection() const`
313    pub fn reflection(&self) -> Option<MachineLearningPipelineReflection> {
314        unsafe {
315            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(reflection));
316            MachineLearningPipelineReflection::from_raw(ptr)
317        }
318    }
319}
320
321impl Clone for MachineLearningPipelineState {
322    fn clone(&self) -> Self {
323        unsafe {
324            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
325        }
326        Self(self.0)
327    }
328}
329
330impl Drop for MachineLearningPipelineState {
331    fn drop(&mut self) {
332        unsafe {
333            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
334        }
335    }
336}
337
338impl Referencing for MachineLearningPipelineState {
339    #[inline]
340    fn as_ptr(&self) -> *const c_void {
341        self.0.as_ptr()
342    }
343}
344
// NOTE(review): these impls assert that the wrapped Objective-C object may be
// moved to and shared between threads. Nothing in this file enforces that —
// confirm against Metal's threading guarantees for pipeline state objects.
unsafe impl Send for MachineLearningPipelineState {}
unsafe impl Sync for MachineLearningPipelineState {}
347
348impl std::fmt::Debug for MachineLearningPipelineState {
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        f.debug_struct("MachineLearningPipelineState")
351            .field("label", &self.label())
352            .field("intermediates_heap_size", &self.intermediates_heap_size())
353            .finish()
354    }
355}
356
357// ============================================================
358// MachineLearningCommandEncoder
359// ============================================================
360
361/// A command encoder for machine learning operations.
362///
363/// C++ equivalent: `MTL4::MachineLearningCommandEncoder`
364#[repr(transparent)]
365pub struct MachineLearningCommandEncoder(NonNull<c_void>);
366
367impl MachineLearningCommandEncoder {
368    /// Create a MachineLearningCommandEncoder from a raw pointer.
369    #[inline]
370    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
371        NonNull::new(ptr).map(Self)
372    }
373
374    /// Get the raw pointer.
375    #[inline]
376    pub fn as_raw(&self) -> *mut c_void {
377        self.0.as_ptr()
378    }
379
380    /// Set the pipeline state.
381    ///
382    /// C++ equivalent: `void setPipelineState(const MTL4::MachineLearningPipelineState*)`
383    pub fn set_pipeline_state(&self, pipeline_state: &MachineLearningPipelineState) {
384        unsafe {
385            let _: () = msg_send_1(
386                self.as_ptr(),
387                sel!(setPipelineState:),
388                pipeline_state.as_ptr(),
389            );
390        }
391    }
392
393    /// Set the argument table.
394    ///
395    /// C++ equivalent: `void setArgumentTable(const MTL4::ArgumentTable*)`
396    pub fn set_argument_table(&self, argument_table: &ArgumentTable) {
397        unsafe {
398            let _: () = msg_send_1(
399                self.as_ptr(),
400                sel!(setArgumentTable:),
401                argument_table.as_ptr(),
402            );
403        }
404    }
405
406    /// Dispatch the neural network.
407    ///
408    /// C++ equivalent: `void dispatchNetwork(const MTL::Heap*)`
409    pub fn dispatch_network(&self, intermediates_heap: &Heap) {
410        unsafe {
411            let _: () = msg_send_1(
412                self.as_ptr(),
413                sel!(dispatchNetworkWithIntermediatesHeap:),
414                intermediates_heap.as_ptr(),
415            );
416        }
417    }
418
419    /// End encoding.
420    ///
421    /// C++ equivalent: `void endEncoding()`
422    pub fn end_encoding(&self) {
423        unsafe {
424            let _: () = msg_send_0(self.as_ptr(), sel!(endEncoding));
425        }
426    }
427}
428
429impl Clone for MachineLearningCommandEncoder {
430    fn clone(&self) -> Self {
431        unsafe {
432            mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
433        }
434        Self(self.0)
435    }
436}
437
438impl Drop for MachineLearningCommandEncoder {
439    fn drop(&mut self) {
440        unsafe {
441            mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
442        }
443    }
444}
445
446impl Referencing for MachineLearningCommandEncoder {
447    #[inline]
448    fn as_ptr(&self) -> *const c_void {
449        self.0.as_ptr()
450    }
451}
452
// NOTE(review): these impls assert that the wrapped Objective-C encoder may be
// moved to and shared between threads. Nothing in this file enforces that, and
// the encoding methods take `&self`, so `Sync` permits concurrent encoding
// through shared references — confirm against Metal's threading guarantees.
unsafe impl Send for MachineLearningCommandEncoder {}
unsafe impl Sync for MachineLearningCommandEncoder {}
455
456impl std::fmt::Debug for MachineLearningCommandEncoder {
457    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458        f.debug_struct("MachineLearningCommandEncoder").finish()
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `T` has the same size as a raw `*mut c_void` pointer.
    #[track_caller]
    fn assert_pointer_sized<T>() {
        let wrapper_size = std::mem::size_of::<T>();
        let pointer_size = std::mem::size_of::<*mut c_void>();
        assert_eq!(wrapper_size, pointer_size);
    }

    #[test]
    fn test_machine_learning_pipeline_descriptor_size() {
        assert_pointer_sized::<MachineLearningPipelineDescriptor>();
    }

    #[test]
    fn test_machine_learning_pipeline_reflection_size() {
        assert_pointer_sized::<MachineLearningPipelineReflection>();
    }

    #[test]
    fn test_machine_learning_pipeline_state_size() {
        assert_pointer_sized::<MachineLearningPipelineState>();
    }

    #[test]
    fn test_machine_learning_command_encoder_size() {
        assert_pointer_sized::<MachineLearningCommandEncoder>();
    }
}