mtl_gpu/mtl4/
machine_learning.rs1use std::ffi::c_void;
7use std::ptr::NonNull;
8
9use mtl_foundation::{Integer, Referencing, UInteger};
10use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
11
12use super::{ArgumentTable, FunctionDescriptor};
13use crate::{Device, Heap};
14
15#[repr(transparent)]
23pub struct MachineLearningPipelineDescriptor(NonNull<c_void>);
24
25impl MachineLearningPipelineDescriptor {
26 #[inline]
28 pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
29 NonNull::new(ptr).map(Self)
30 }
31
32 #[inline]
34 pub fn as_raw(&self) -> *mut c_void {
35 self.0.as_ptr()
36 }
37
38 pub fn new() -> Option<Self> {
40 unsafe {
41 let class = mtl_sys::Class::get("MTL4MachineLearningPipelineDescriptor")?;
42 let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
43 if ptr.is_null() {
44 return None;
45 }
46 let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
47 Self::from_raw(ptr)
48 }
49 }
50
51 pub fn label(&self) -> Option<String> {
55 unsafe {
56 let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
57 if ns_string.is_null() {
58 return None;
59 }
60 let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
61 if c_str.is_null() {
62 return None;
63 }
64 Some(
65 std::ffi::CStr::from_ptr(c_str)
66 .to_string_lossy()
67 .into_owned(),
68 )
69 }
70 }
71
72 pub fn set_label(&self, label: &str) {
76 if let Some(ns_label) = mtl_foundation::String::from_str(label) {
77 unsafe {
78 let _: () = msg_send_1(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
79 }
80 }
81 }
82
83 pub fn machine_learning_function_descriptor(&self) -> Option<FunctionDescriptor> {
87 unsafe {
88 let ptr: *mut c_void =
89 msg_send_0(self.as_ptr(), sel!(machineLearningFunctionDescriptor));
90 FunctionDescriptor::from_raw(ptr)
91 }
92 }
93
94 pub fn set_machine_learning_function_descriptor(&self, descriptor: &FunctionDescriptor) {
98 unsafe {
99 let _: () = msg_send_1(
100 self.as_ptr(),
101 sel!(setMachineLearningFunctionDescriptor:),
102 descriptor.as_ptr(),
103 );
104 }
105 }
106
107 pub fn input_dimensions_at_buffer_index_raw(&self, buffer_index: Integer) -> *mut c_void {
111 unsafe {
112 msg_send_1(
113 self.as_ptr(),
114 sel!(inputDimensionsAtBufferIndex:),
115 buffer_index,
116 )
117 }
118 }
119
120 pub fn set_input_dimensions_raw(&self, dimensions: *const c_void, buffer_index: Integer) {
124 unsafe {
125 let _: () = msg_send_2(
126 self.as_ptr(),
127 sel!(setInputDimensions:atBufferIndex:),
128 dimensions,
129 buffer_index,
130 );
131 }
132 }
133
134 pub fn reset(&self) {
138 unsafe {
139 let _: () = msg_send_0(self.as_ptr(), sel!(reset));
140 }
141 }
142}
143
144impl Default for MachineLearningPipelineDescriptor {
145 fn default() -> Self {
146 Self::new().expect("Failed to create MTL4MachineLearningPipelineDescriptor")
147 }
148}
149
150impl Clone for MachineLearningPipelineDescriptor {
151 fn clone(&self) -> Self {
152 unsafe {
153 mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
154 }
155 Self(self.0)
156 }
157}
158
159impl Drop for MachineLearningPipelineDescriptor {
160 fn drop(&mut self) {
161 unsafe {
162 mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
163 }
164 }
165}
166
167impl Referencing for MachineLearningPipelineDescriptor {
168 #[inline]
169 fn as_ptr(&self) -> *const c_void {
170 self.0.as_ptr()
171 }
172}
173
174unsafe impl Send for MachineLearningPipelineDescriptor {}
175unsafe impl Sync for MachineLearningPipelineDescriptor {}
176
177impl std::fmt::Debug for MachineLearningPipelineDescriptor {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("MachineLearningPipelineDescriptor")
180 .field("label", &self.label())
181 .finish()
182 }
183}
184
185#[repr(transparent)]
193pub struct MachineLearningPipelineReflection(NonNull<c_void>);
194
195impl MachineLearningPipelineReflection {
196 #[inline]
198 pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
199 NonNull::new(ptr).map(Self)
200 }
201
202 #[inline]
204 pub fn as_raw(&self) -> *mut c_void {
205 self.0.as_ptr()
206 }
207
208 pub fn bindings_raw(&self) -> *mut c_void {
212 unsafe { msg_send_0(self.as_ptr(), sel!(bindings)) }
213 }
214}
215
216impl Clone for MachineLearningPipelineReflection {
217 fn clone(&self) -> Self {
218 unsafe {
219 mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
220 }
221 Self(self.0)
222 }
223}
224
225impl Drop for MachineLearningPipelineReflection {
226 fn drop(&mut self) {
227 unsafe {
228 mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
229 }
230 }
231}
232
233impl Referencing for MachineLearningPipelineReflection {
234 #[inline]
235 fn as_ptr(&self) -> *const c_void {
236 self.0.as_ptr()
237 }
238}
239
240unsafe impl Send for MachineLearningPipelineReflection {}
241unsafe impl Sync for MachineLearningPipelineReflection {}
242
243impl std::fmt::Debug for MachineLearningPipelineReflection {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 f.debug_struct("MachineLearningPipelineReflection").finish()
246 }
247}
248
249#[repr(transparent)]
257pub struct MachineLearningPipelineState(NonNull<c_void>);
258
259impl MachineLearningPipelineState {
260 #[inline]
262 pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
263 NonNull::new(ptr).map(Self)
264 }
265
266 #[inline]
268 pub fn as_raw(&self) -> *mut c_void {
269 self.0.as_ptr()
270 }
271
272 pub fn device(&self) -> Option<Device> {
276 unsafe {
277 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
278 Device::from_raw(ptr)
279 }
280 }
281
282 pub fn label(&self) -> Option<String> {
286 unsafe {
287 let ns_string: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
288 if ns_string.is_null() {
289 return None;
290 }
291 let c_str: *const i8 = msg_send_0(ns_string, sel!(UTF8String));
292 if c_str.is_null() {
293 return None;
294 }
295 Some(
296 std::ffi::CStr::from_ptr(c_str)
297 .to_string_lossy()
298 .into_owned(),
299 )
300 }
301 }
302
303 pub fn intermediates_heap_size(&self) -> UInteger {
307 unsafe { msg_send_0(self.as_ptr(), sel!(intermediatesHeapSize)) }
308 }
309
310 pub fn reflection(&self) -> Option<MachineLearningPipelineReflection> {
314 unsafe {
315 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(reflection));
316 MachineLearningPipelineReflection::from_raw(ptr)
317 }
318 }
319}
320
321impl Clone for MachineLearningPipelineState {
322 fn clone(&self) -> Self {
323 unsafe {
324 mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
325 }
326 Self(self.0)
327 }
328}
329
330impl Drop for MachineLearningPipelineState {
331 fn drop(&mut self) {
332 unsafe {
333 mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
334 }
335 }
336}
337
338impl Referencing for MachineLearningPipelineState {
339 #[inline]
340 fn as_ptr(&self) -> *const c_void {
341 self.0.as_ptr()
342 }
343}
344
345unsafe impl Send for MachineLearningPipelineState {}
346unsafe impl Sync for MachineLearningPipelineState {}
347
348impl std::fmt::Debug for MachineLearningPipelineState {
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 f.debug_struct("MachineLearningPipelineState")
351 .field("label", &self.label())
352 .field("intermediates_heap_size", &self.intermediates_heap_size())
353 .finish()
354 }
355}
356
357#[repr(transparent)]
365pub struct MachineLearningCommandEncoder(NonNull<c_void>);
366
367impl MachineLearningCommandEncoder {
368 #[inline]
370 pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
371 NonNull::new(ptr).map(Self)
372 }
373
374 #[inline]
376 pub fn as_raw(&self) -> *mut c_void {
377 self.0.as_ptr()
378 }
379
380 pub fn set_pipeline_state(&self, pipeline_state: &MachineLearningPipelineState) {
384 unsafe {
385 let _: () = msg_send_1(
386 self.as_ptr(),
387 sel!(setPipelineState:),
388 pipeline_state.as_ptr(),
389 );
390 }
391 }
392
393 pub fn set_argument_table(&self, argument_table: &ArgumentTable) {
397 unsafe {
398 let _: () = msg_send_1(
399 self.as_ptr(),
400 sel!(setArgumentTable:),
401 argument_table.as_ptr(),
402 );
403 }
404 }
405
406 pub fn dispatch_network(&self, intermediates_heap: &Heap) {
410 unsafe {
411 let _: () = msg_send_1(
412 self.as_ptr(),
413 sel!(dispatchNetworkWithIntermediatesHeap:),
414 intermediates_heap.as_ptr(),
415 );
416 }
417 }
418
419 pub fn end_encoding(&self) {
423 unsafe {
424 let _: () = msg_send_0(self.as_ptr(), sel!(endEncoding));
425 }
426 }
427}
428
429impl Clone for MachineLearningCommandEncoder {
430 fn clone(&self) -> Self {
431 unsafe {
432 mtl_sys::msg_send_0::<*mut c_void>(self.as_ptr(), mtl_sys::sel!(retain));
433 }
434 Self(self.0)
435 }
436}
437
438impl Drop for MachineLearningCommandEncoder {
439 fn drop(&mut self) {
440 unsafe {
441 mtl_sys::msg_send_0::<()>(self.as_ptr(), mtl_sys::sel!(release));
442 }
443 }
444}
445
446impl Referencing for MachineLearningCommandEncoder {
447 #[inline]
448 fn as_ptr(&self) -> *const c_void {
449 self.0.as_ptr()
450 }
451}
452
453unsafe impl Send for MachineLearningCommandEncoder {}
454unsafe impl Sync for MachineLearningCommandEncoder {}
455
456impl std::fmt::Debug for MachineLearningCommandEncoder {
457 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458 f.debug_struct("MachineLearningCommandEncoder").finish()
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `T` is exactly one pointer wide, as each `#[repr(transparent)]` wrapper
    /// around `NonNull<c_void>` in this module should be.
    fn assert_pointer_sized<T>() {
        assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<*mut c_void>());
    }

    #[test]
    fn test_machine_learning_pipeline_descriptor_size() {
        assert_pointer_sized::<MachineLearningPipelineDescriptor>();
    }

    #[test]
    fn test_machine_learning_pipeline_reflection_size() {
        assert_pointer_sized::<MachineLearningPipelineReflection>();
    }

    #[test]
    fn test_machine_learning_pipeline_state_size() {
        assert_pointer_sized::<MachineLearningPipelineState>();
    }

    #[test]
    fn test_machine_learning_command_encoder_size() {
        assert_pointer_sized::<MachineLearningCommandEncoder>();
    }
}