use std::ffi::c_void;
use std::ptr::NonNull;

use mtl_foundation::{Integer, Referencing, UInteger};
use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_4, sel};

use crate::Buffer;
use crate::enums::{
    CPUCacheMode, HazardTrackingMode, ResourceOptions, StorageMode, TensorDataType, TensorUsage,
};
use crate::types::ResourceID;

/// Extents that describe the size of each dimension (or the strides) of a tensor.
///
/// Wraps the Objective-C `MTLTensorExtents` class.
#[repr(transparent)]
pub struct TensorExtents(pub(crate) NonNull<c_void>);

impl TensorExtents {
    /// Creates a new extents object using the class's default initializer.
    ///
    /// Returns `None` if the `MTLTensorExtents` class is unavailable or
    /// allocation fails.
    pub fn new() -> Option<Self> {
        unsafe {
            let class = mtl_sys::Class::get("MTLTensorExtents")?;
            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
            if ptr.is_null() {
                return None;
            }
            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
            Self::from_raw(ptr)
        }
    }

    /// Creates extents whose rank is `values.len()`, with one extent per
    /// element of `values`.
    pub fn with_values(values: &[Integer]) -> Option<Self> {
        unsafe {
            let class = mtl_sys::Class::get("MTLTensorExtents")?;
            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
            if ptr.is_null() {
                return None;
            }
            let ptr: *mut c_void = msg_send_2(
                ptr,
                sel!(initWithRank:values:),
                values.len() as UInteger,
                values.as_ptr(),
            );
            Self::from_raw(ptr)
        }
    }

    /// Wraps a raw `MTLTensorExtents` pointer, taking ownership of one retain
    /// count. Returns `None` if the pointer is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a valid `MTLTensorExtents` instance
    /// whose ownership is transferred to the returned value.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Returns the underlying Objective-C pointer without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Returns the number of dimensions.
    #[inline]
    pub fn rank(&self) -> UInteger {
        unsafe { msg_send_0(self.as_ptr(), sel!(rank)) }
    }

    /// Returns the extent at `dimension_index`.
    #[inline]
    pub fn extent_at_dimension_index(&self, dimension_index: UInteger) -> Integer {
        unsafe {
            msg_send_1(
                self.as_ptr(),
                sel!(extentAtDimensionIndex:),
                dimension_index,
            )
        }
    }
}

impl Clone for TensorExtents {
    fn clone(&self) -> Self {
        unsafe {
            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
        }
        Self(self.0)
    }
}

impl Drop for TensorExtents {
    fn drop(&mut self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}

impl Referencing for TensorExtents {
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}

unsafe impl Send for TensorExtents {}
unsafe impl Sync for TensorExtents {}

/// A description of a tensor's data type, shape, layout, and resource
/// properties, used when creating [`Tensor`] objects.
///
/// Wraps the Objective-C `MTLTensorDescriptor` class.
#[repr(transparent)]
pub struct TensorDescriptor(pub(crate) NonNull<c_void>);

impl TensorDescriptor {
    /// Creates a new descriptor with default property values.
    ///
    /// Returns `None` if the `MTLTensorDescriptor` class is unavailable or
    /// allocation fails.
    pub fn new() -> Option<Self> {
        unsafe {
            let class = mtl_sys::Class::get("MTLTensorDescriptor")?;
            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
            if ptr.is_null() {
                return None;
            }
            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
            Self::from_raw(ptr)
        }
    }

    /// Wraps a raw `MTLTensorDescriptor` pointer, taking ownership of one
    /// retain count. Returns `None` if the pointer is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a valid `MTLTensorDescriptor` instance
    /// whose ownership is transferred to the returned value.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Returns the underlying Objective-C pointer without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Returns the data type of the tensor elements.
    #[inline]
    pub fn data_type(&self) -> TensorDataType {
        unsafe { msg_send_0(self.as_ptr(), sel!(dataType)) }
    }

    /// Sets the data type of the tensor elements.
    pub fn set_data_type(&self, data_type: TensorDataType) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setDataType:), data_type);
        }
    }

    /// Returns the size of each dimension, or `None` if none have been set.
    pub fn dimensions(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(dimensions));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Sets the size of each dimension.
    pub fn set_dimensions(&self, dimensions: &TensorExtents) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setDimensions:), dimensions.as_ptr());
        }
    }

    /// Returns the stride of each dimension, or `None` if none have been set.
    pub fn strides(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(strides));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Sets the stride of each dimension.
    pub fn set_strides(&self, strides: &TensorExtents) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setStrides:), strides.as_ptr());
        }
    }

    /// Returns the storage mode.
    #[inline]
    pub fn storage_mode(&self) -> StorageMode {
        unsafe { msg_send_0(self.as_ptr(), sel!(storageMode)) }
    }

    /// Sets the storage mode.
    pub fn set_storage_mode(&self, storage_mode: StorageMode) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setStorageMode:), storage_mode);
        }
    }

    /// Returns the CPU cache mode.
    #[inline]
    pub fn cpu_cache_mode(&self) -> CPUCacheMode {
        unsafe { msg_send_0(self.as_ptr(), sel!(cpuCacheMode)) }
    }

    /// Sets the CPU cache mode.
    pub fn set_cpu_cache_mode(&self, cpu_cache_mode: CPUCacheMode) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setCpuCacheMode:), cpu_cache_mode);
        }
    }

    /// Returns the hazard tracking mode.
    #[inline]
    pub fn hazard_tracking_mode(&self) -> HazardTrackingMode {
        unsafe { msg_send_0(self.as_ptr(), sel!(hazardTrackingMode)) }
    }

    /// Sets the hazard tracking mode.
    pub fn set_hazard_tracking_mode(&self, hazard_tracking_mode: HazardTrackingMode) {
        unsafe {
            let _: () = msg_send_1(
                self.as_ptr(),
                sel!(setHazardTrackingMode:),
                hazard_tracking_mode,
            );
        }
    }

    /// Returns the combined resource options.
    #[inline]
    pub fn resource_options(&self) -> ResourceOptions {
        unsafe { msg_send_0(self.as_ptr(), sel!(resourceOptions)) }
    }

    /// Sets the combined resource options.
    pub fn set_resource_options(&self, resource_options: ResourceOptions) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setResourceOptions:), resource_options);
        }
    }

    /// Returns the intended usage of the tensor.
    #[inline]
    pub fn usage(&self) -> TensorUsage {
        unsafe { msg_send_0(self.as_ptr(), sel!(usage)) }
    }

    /// Sets the intended usage of the tensor.
    pub fn set_usage(&self, usage: TensorUsage) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setUsage:), usage);
        }
    }
}

impl Default for TensorDescriptor {
    fn default() -> Self {
        Self::new().expect("failed to create TensorDescriptor")
    }
}

impl Clone for TensorDescriptor {
    fn clone(&self) -> Self {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(copy));
            Self::from_raw(ptr).expect("copy should succeed")
        }
    }
}

impl Drop for TensorDescriptor {
    fn drop(&mut self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}

impl Referencing for TensorDescriptor {
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}

unsafe impl Send for TensorDescriptor {}
unsafe impl Sync for TensorDescriptor {}

/// A resource that stores multi-dimensional data on a GPU device.
///
/// Wraps an object conforming to the Objective-C `MTLTensor` protocol.
#[repr(transparent)]
pub struct Tensor(pub(crate) NonNull<c_void>);

impl Tensor {
    /// Wraps a raw `MTLTensor` pointer, taking ownership of one retain count.
    /// Returns `None` if the pointer is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a valid object conforming to
    /// `MTLTensor` whose ownership is transferred to the returned value.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Returns the underlying Objective-C pointer without transferring ownership.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Returns the buffer backing this tensor, if it was created from one.
    pub fn buffer(&self) -> Option<Buffer> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(buffer));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            Buffer::from_raw(ptr)
        }
    }

    /// Returns the offset into the backing buffer at which the tensor's data begins.
    #[inline]
    pub fn buffer_offset(&self) -> UInteger {
        unsafe { msg_send_0(self.as_ptr(), sel!(bufferOffset)) }
    }

    /// Returns the data type of the tensor elements.
    #[inline]
    pub fn data_type(&self) -> TensorDataType {
        unsafe { msg_send_0(self.as_ptr(), sel!(dataType)) }
    }

    /// Returns the size of each dimension.
    pub fn dimensions(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(dimensions));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Returns the stride of each dimension.
    pub fn strides(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(strides));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Returns the usage the tensor was created with.
    #[inline]
    pub fn usage(&self) -> TensorUsage {
        unsafe { msg_send_0(self.as_ptr(), sel!(usage)) }
    }

    /// Returns the tensor's GPU resource ID.
    #[inline]
    pub fn gpu_resource_id(&self) -> ResourceID {
        unsafe { msg_send_0(self.as_ptr(), sel!(gpuResourceID)) }
    }

    /// Copies a slice of the tensor's contents into `bytes`, laid out with
    /// `strides`, starting at `slice_origin` and covering `slice_dimensions`.
    ///
    /// `bytes` must point to a writable allocation large enough to hold the
    /// requested slice.
    pub fn get_bytes(
        &self,
        bytes: *mut c_void,
        strides: &TensorExtents,
        slice_origin: &TensorExtents,
        slice_dimensions: &TensorExtents,
    ) {
        unsafe {
            let _: () = msg_send_4(
                self.as_ptr(),
                sel!(getBytes:strides:fromSliceOrigin:sliceDimensions:),
                bytes,
                strides.as_ptr(),
                slice_origin.as_ptr(),
                slice_dimensions.as_ptr(),
            );
        }
    }

    /// Replaces the slice starting at `slice_origin` and covering
    /// `slice_dimensions` with data read from `bytes`, laid out with `strides`.
    ///
    /// `bytes` must point to a readable allocation containing the slice data.
    pub fn replace_slice_origin(
        &self,
        slice_origin: &TensorExtents,
        slice_dimensions: &TensorExtents,
        bytes: *const c_void,
        strides: &TensorExtents,
    ) {
        unsafe {
            let _: () = msg_send_4(
                self.as_ptr(),
                sel!(replaceSliceOrigin:sliceDimensions:withBytes:strides:),
                slice_origin.as_ptr(),
                slice_dimensions.as_ptr(),
                bytes,
                strides.as_ptr(),
            );
        }
    }
}

impl Clone for Tensor {
    fn clone(&self) -> Self {
        unsafe {
            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
        }
        Self(self.0)
    }
}

impl Drop for Tensor {
    fn drop(&mut self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}

impl Referencing for Tensor {
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}

unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_extents_creation() {
        let _extents = TensorExtents::new();
        // May be None without a Metal runtime; this only checks that creation does not crash.
    }

    #[test]
    fn test_tensor_extents_with_values() {
        let values = [2, 3, 4];
        let extents = TensorExtents::with_values(&values);
        if let Some(e) = extents {
            // Only assert when creation succeeded (requires a Metal runtime).
            assert_eq!(e.rank(), 3);
            assert_eq!(e.extent_at_dimension_index(0), 2);
            assert_eq!(e.extent_at_dimension_index(1), 3);
            assert_eq!(e.extent_at_dimension_index(2), 4);
        }
    }

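    // Added illustrative test (not from the original source): cloning a
    // TensorExtents retains the same underlying object, so the clone must
    // stay valid and report the same values after the original is dropped.
    // Assumes a Metal runtime; the `if let` skips the body otherwise.
    #[test]
    fn test_tensor_extents_clone() {
        if let Some(extents) = TensorExtents::with_values(&[5, 6]) {
            let cloned = extents.clone();
            drop(extents);
            assert_eq!(cloned.rank(), 2);
            assert_eq!(cloned.extent_at_dimension_index(0), 5);
            assert_eq!(cloned.extent_at_dimension_index(1), 6);
        }
    }
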
    #[test]
    fn test_tensor_extents_size() {
        assert_eq!(
            std::mem::size_of::<TensorExtents>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    #[test]
    fn test_tensor_descriptor_creation() {
        let _descriptor = TensorDescriptor::new();
        // May be None without a Metal runtime; this only checks that creation does not crash.
    }

    #[test]
    fn test_tensor_descriptor_size() {
        assert_eq!(
            std::mem::size_of::<TensorDescriptor>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    #[test]
    fn test_tensor_size() {
        assert_eq!(
            std::mem::size_of::<Tensor>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

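    // Added illustrative test (not from the original source): setting
    // dimensions on a descriptor and reading them back should preserve the
    // rank and per-dimension extents. Assumes a Metal runtime; the `if let`
    // guards skip the assertions otherwise.
    #[test]
    fn test_tensor_descriptor_dimensions() {
        if let (Some(descriptor), Some(dims)) =
            (TensorDescriptor::new(), TensorExtents::with_values(&[2, 3, 4]))
        {
            descriptor.set_dimensions(&dims);
            if let Some(read_back) = descriptor.dimensions() {
                assert_eq!(read_back.rank(), 3);
                assert_eq!(read_back.extent_at_dimension_index(0), 2);
                assert_eq!(read_back.extent_at_dimension_index(1), 3);
                assert_eq!(read_back.extent_at_dimension_index(2), 4);
            }
        }
    }
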
    #[test]
    fn test_tensor_descriptor_data_type() {
        if let Some(descriptor) = TensorDescriptor::new() {
            // Only assert when creation succeeded (requires a Metal runtime).
            descriptor.set_data_type(TensorDataType::FLOAT32);
            assert_eq!(descriptor.data_type(), TensorDataType::FLOAT32);
        }
    }

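    // Added illustrative test (not from the original source): Clone for
    // TensorDescriptor sends `copy`, so a clone should be an independent
    // object that reports the same configuration. Assumes a Metal runtime;
    // the `if let` skips the body otherwise.
    #[test]
    fn test_tensor_descriptor_clone() {
        if let Some(descriptor) = TensorDescriptor::new() {
            descriptor.set_data_type(TensorDataType::FLOAT32);
            let copy = descriptor.clone();
            // The copy should not share the original's underlying pointer.
            assert_ne!(copy.as_raw(), descriptor.as_raw());
            assert_eq!(copy.data_type(), TensorDataType::FLOAT32);
        }
    }
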
    #[test]
    fn test_tensor_descriptor_usage() {
        if let Some(descriptor) = TensorDescriptor::new() {
            // Only assert when creation succeeded (requires a Metal runtime).
            descriptor.set_usage(TensorUsage::COMPUTE | TensorUsage::MACHINE_LEARNING);
            let usage = descriptor.usage();
            assert!((usage.0 & TensorUsage::COMPUTE.0) != 0);
            assert!((usage.0 & TensorUsage::MACHINE_LEARNING.0) != 0);
        }
    }
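
    // Added illustrative test (not from the original source): the remaining
    // descriptor getters should be callable on a freshly created descriptor
    // without crashing; no particular default values are asserted. Assumes a
    // Metal runtime; the `if let` skips the body otherwise.
    #[test]
    fn test_tensor_descriptor_getters() {
        if let Some(descriptor) = TensorDescriptor::new() {
            let _ = descriptor.storage_mode();
            let _ = descriptor.cpu_cache_mode();
            let _ = descriptor.hazard_tracking_mode();
            let _ = descriptor.resource_options();
        }
    }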
}