// mtl_gpu/tensor.rs

//! Metal tensor types for ML operations.
//!
//! Corresponds to `Metal/MTLTensor.hpp`.
//!
//! Tensors represent multi-dimensional arrays for machine learning operations.
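//!
//! # Example
//!
//! A minimal configuration sketch (illustrative only; it assumes these types
//! are re-exported at the crate root and that the `MTLTensor` classes are
//! available at runtime):
//!
//! ```no_run
//! use mtl_gpu::{TensorDataType, TensorDescriptor, TensorExtents, TensorUsage};
//!
//! // Describe a 2 x 3 tensor of 32-bit floats for compute use.
//! let dims = TensorExtents::with_values(&[2, 3]).expect("extents");
//! let desc = TensorDescriptor::new().expect("descriptor");
//! desc.set_dimensions(&dims);
//! desc.set_data_type(TensorDataType::FLOAT32);
//! desc.set_usage(TensorUsage::COMPUTE);
//! ```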

use std::ffi::c_void;
use std::ptr::NonNull;

use mtl_foundation::{Integer, Referencing, UInteger};
use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, msg_send_4, sel};

use crate::Buffer;
use crate::enums::{
    CPUCacheMode, HazardTrackingMode, ResourceOptions, StorageMode, TensorDataType, TensorUsage,
};
use crate::types::ResourceID;

// ============================================================================
// TensorExtents
// ============================================================================

/// Extents (dimensions/strides) for a tensor.
///
/// C++ equivalent: `MTL::TensorExtents`
#[repr(transparent)]
pub struct TensorExtents(pub(crate) NonNull<c_void>);

impl TensorExtents {
    /// Create a new `TensorExtents`.
    ///
    /// C++ equivalent: `TensorExtents* alloc()->init()`
    pub fn new() -> Option<Self> {
        unsafe {
            let class = mtl_sys::Class::get("MTLTensorExtents")?;
            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
            if ptr.is_null() {
                return None;
            }
            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
            Self::from_raw(ptr)
        }
    }

    /// Create a `TensorExtents` with the given extent values; the rank is
    /// taken from the length of `values`.
    ///
    /// C++ equivalent: `TensorExtents* alloc()->init(NS::UInteger, const NS::Integer*)`
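    ///
    /// # Example
    ///
    /// A minimal sketch (assumes the type is re-exported at the crate root
    /// and that `MTLTensorExtents` is available at runtime):
    ///
    /// ```no_run
    /// use mtl_gpu::TensorExtents;
    ///
    /// let extents = TensorExtents::with_values(&[2, 3, 4]).expect("extents");
    /// assert_eq!(extents.rank(), 3);
    /// assert_eq!(extents.extent_at_dimension_index(1), 3);
    /// ```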
    pub fn with_values(values: &[Integer]) -> Option<Self> {
        unsafe {
            let class = mtl_sys::Class::get("MTLTensorExtents")?;
            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
            if ptr.is_null() {
                return None;
            }
            let ptr: *mut c_void = msg_send_2(
                ptr,
                sel!(initWithRank:values:),
                values.len() as UInteger,
                values.as_ptr(),
            );
            Self::from_raw(ptr)
        }
    }

    /// Create from a raw pointer.
    ///
    /// # Safety
    ///
    /// The pointer must be a valid Metal TensorExtents.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Get the raw pointer.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Get the rank (number of dimensions).
    ///
    /// C++ equivalent: `NS::UInteger rank() const`
    #[inline]
    pub fn rank(&self) -> UInteger {
        unsafe { msg_send_0(self.as_ptr(), sel!(rank)) }
    }

    /// Get the extent at a specific dimension index.
    ///
    /// C++ equivalent: `NS::Integer extentAtDimensionIndex(NS::UInteger)`
    #[inline]
    pub fn extent_at_dimension_index(&self, dimension_index: UInteger) -> Integer {
        unsafe {
            msg_send_1(
                self.as_ptr(),
                sel!(extentAtDimensionIndex:),
                dimension_index,
            )
        }
    }
}

impl Clone for TensorExtents {
    fn clone(&self) -> Self {
        // Balance the clone with a `retain`; the matching `release` happens
        // in `Drop`.
        unsafe {
            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
        }
        Self(self.0)
    }
}

impl Drop for TensorExtents {
    fn drop(&mut self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}

impl Referencing for TensorExtents {
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}

unsafe impl Send for TensorExtents {}
unsafe impl Sync for TensorExtents {}

// ============================================================================
// TensorDescriptor
// ============================================================================

/// Descriptor for creating a tensor.
///
/// C++ equivalent: `MTL::TensorDescriptor`
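///
/// # Example
///
/// A configuration sketch (assumes crate-root re-exports and runtime
/// availability of `MTLTensorDescriptor`):
///
/// ```no_run
/// use mtl_gpu::{TensorDataType, TensorDescriptor, TensorUsage};
///
/// let desc = TensorDescriptor::new().expect("descriptor");
/// desc.set_data_type(TensorDataType::FLOAT32);
/// desc.set_usage(TensorUsage::COMPUTE);
/// assert_eq!(desc.data_type(), TensorDataType::FLOAT32);
/// ```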
#[repr(transparent)]
pub struct TensorDescriptor(pub(crate) NonNull<c_void>);

impl TensorDescriptor {
    /// Create a new tensor descriptor.
    ///
    /// C++ equivalent: `TensorDescriptor* alloc()->init()`
    pub fn new() -> Option<Self> {
        unsafe {
            let class = mtl_sys::Class::get("MTLTensorDescriptor")?;
            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
            if ptr.is_null() {
                return None;
            }
            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
            Self::from_raw(ptr)
        }
    }

    /// Create from a raw pointer.
    ///
    /// # Safety
    ///
    /// The pointer must be a valid Metal TensorDescriptor.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Get the raw pointer.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Get the data type.
    ///
    /// C++ equivalent: `TensorDataType dataType() const`
    #[inline]
    pub fn data_type(&self) -> TensorDataType {
        unsafe { msg_send_0(self.as_ptr(), sel!(dataType)) }
    }

    /// Set the data type.
    ///
    /// C++ equivalent: `void setDataType(MTL::TensorDataType)`
    pub fn set_data_type(&self, data_type: TensorDataType) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setDataType:), data_type);
        }
    }

    /// Get the dimensions.
    ///
    /// C++ equivalent: `TensorExtents* dimensions() const`
    pub fn dimensions(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(dimensions));
            if ptr.is_null() {
                return None;
            }
            // Retain the returned object so the `release` in `Drop` is
            // balanced.
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Set the dimensions.
    ///
    /// C++ equivalent: `void setDimensions(const MTL::TensorExtents*)`
    pub fn set_dimensions(&self, dimensions: &TensorExtents) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setDimensions:), dimensions.as_ptr());
        }
    }

    /// Get the strides.
    ///
    /// C++ equivalent: `TensorExtents* strides() const`
    pub fn strides(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(strides));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Set the strides.
    ///
    /// C++ equivalent: `void setStrides(const MTL::TensorExtents*)`
    pub fn set_strides(&self, strides: &TensorExtents) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setStrides:), strides.as_ptr());
        }
    }

    /// Get the storage mode.
    ///
    /// C++ equivalent: `StorageMode storageMode() const`
    #[inline]
    pub fn storage_mode(&self) -> StorageMode {
        unsafe { msg_send_0(self.as_ptr(), sel!(storageMode)) }
    }

    /// Set the storage mode.
    ///
    /// C++ equivalent: `void setStorageMode(MTL::StorageMode)`
    pub fn set_storage_mode(&self, storage_mode: StorageMode) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setStorageMode:), storage_mode);
        }
    }

    /// Get the CPU cache mode.
    ///
    /// C++ equivalent: `CPUCacheMode cpuCacheMode() const`
    #[inline]
    pub fn cpu_cache_mode(&self) -> CPUCacheMode {
        unsafe { msg_send_0(self.as_ptr(), sel!(cpuCacheMode)) }
    }

    /// Set the CPU cache mode.
    ///
    /// C++ equivalent: `void setCpuCacheMode(MTL::CPUCacheMode)`
    pub fn set_cpu_cache_mode(&self, cpu_cache_mode: CPUCacheMode) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setCpuCacheMode:), cpu_cache_mode);
        }
    }

    /// Get the hazard tracking mode.
    ///
    /// C++ equivalent: `HazardTrackingMode hazardTrackingMode() const`
    #[inline]
    pub fn hazard_tracking_mode(&self) -> HazardTrackingMode {
        unsafe { msg_send_0(self.as_ptr(), sel!(hazardTrackingMode)) }
    }

    /// Set the hazard tracking mode.
    ///
    /// C++ equivalent: `void setHazardTrackingMode(MTL::HazardTrackingMode)`
    pub fn set_hazard_tracking_mode(&self, hazard_tracking_mode: HazardTrackingMode) {
        unsafe {
            let _: () = msg_send_1(
                self.as_ptr(),
                sel!(setHazardTrackingMode:),
                hazard_tracking_mode,
            );
        }
    }

    /// Get the resource options.
    ///
    /// Resource options pack the storage mode, CPU cache mode, and hazard
    /// tracking mode into a single bitfield.
    ///
    /// C++ equivalent: `ResourceOptions resourceOptions() const`
    #[inline]
    pub fn resource_options(&self) -> ResourceOptions {
        unsafe { msg_send_0(self.as_ptr(), sel!(resourceOptions)) }
    }

    /// Set the resource options.
    ///
    /// C++ equivalent: `void setResourceOptions(MTL::ResourceOptions)`
    pub fn set_resource_options(&self, resource_options: ResourceOptions) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setResourceOptions:), resource_options);
        }
    }

    /// Get the usage.
    ///
    /// C++ equivalent: `TensorUsage usage() const`
    #[inline]
    pub fn usage(&self) -> TensorUsage {
        unsafe { msg_send_0(self.as_ptr(), sel!(usage)) }
    }

    /// Set the usage.
    ///
    /// C++ equivalent: `void setUsage(MTL::TensorUsage)`
    pub fn set_usage(&self, usage: TensorUsage) {
        unsafe {
            let _: () = msg_send_1(self.as_ptr(), sel!(setUsage:), usage);
        }
    }
}

impl Default for TensorDescriptor {
    fn default() -> Self {
        Self::new().expect("failed to create TensorDescriptor")
    }
}

impl Clone for TensorDescriptor {
    fn clone(&self) -> Self {
        // `copy` (NSCopying) returns a new object with a +1 retain count, so
        // no additional retain is needed.
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(copy));
            Self::from_raw(ptr).expect("copy should succeed")
        }
    }
}

impl Drop for TensorDescriptor {
    fn drop(&mut self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}

impl Referencing for TensorDescriptor {
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}

unsafe impl Send for TensorDescriptor {}
unsafe impl Sync for TensorDescriptor {}

// ============================================================================
// Tensor
// ============================================================================

/// A multi-dimensional array for machine learning.
///
/// C++ equivalent: `MTL::Tensor`
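///
/// # Example
///
/// A read-only sketch over a tensor obtained elsewhere (tensor creation goes
/// through a device and is outside this module; the crate-root re-export is
/// assumed):
///
/// ```no_run
/// use mtl_gpu::Tensor;
///
/// fn describe(tensor: &Tensor) {
///     if let Some(dims) = tensor.dimensions() {
///         for i in 0..dims.rank() {
///             println!("dim {}: {}", i, dims.extent_at_dimension_index(i));
///         }
///     }
///     println!("offset into backing buffer: {}", tensor.buffer_offset());
/// }
/// ```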
#[repr(transparent)]
pub struct Tensor(pub(crate) NonNull<c_void>);

impl Tensor {
    /// Create from a raw pointer.
    ///
    /// # Safety
    ///
    /// The pointer must be a valid Metal Tensor.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
        NonNull::new(ptr).map(Self)
    }

    /// Get the raw pointer.
    #[inline]
    pub fn as_raw(&self) -> *mut c_void {
        self.0.as_ptr()
    }

    /// Get the backing buffer.
    ///
    /// C++ equivalent: `Buffer* buffer() const`
    pub fn buffer(&self) -> Option<Buffer> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(buffer));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            Buffer::from_raw(ptr)
        }
    }

    /// Get the buffer offset.
    ///
    /// C++ equivalent: `NS::UInteger bufferOffset() const`
    #[inline]
    pub fn buffer_offset(&self) -> UInteger {
        unsafe { msg_send_0(self.as_ptr(), sel!(bufferOffset)) }
    }

    /// Get the data type.
    ///
    /// C++ equivalent: `TensorDataType dataType() const`
    #[inline]
    pub fn data_type(&self) -> TensorDataType {
        unsafe { msg_send_0(self.as_ptr(), sel!(dataType)) }
    }

    /// Get the dimensions.
    ///
    /// C++ equivalent: `TensorExtents* dimensions() const`
    pub fn dimensions(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(dimensions));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Get the strides.
    ///
    /// C++ equivalent: `TensorExtents* strides() const`
    pub fn strides(&self) -> Option<TensorExtents> {
        unsafe {
            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(strides));
            if ptr.is_null() {
                return None;
            }
            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
            TensorExtents::from_raw(ptr)
        }
    }

    /// Get the usage.
    ///
    /// C++ equivalent: `TensorUsage usage() const`
    #[inline]
    pub fn usage(&self) -> TensorUsage {
        unsafe { msg_send_0(self.as_ptr(), sel!(usage)) }
    }

    /// Get the GPU resource ID.
    ///
    /// C++ equivalent: `ResourceID gpuResourceID() const`
    #[inline]
    pub fn gpu_resource_id(&self) -> ResourceID {
        unsafe { msg_send_0(self.as_ptr(), sel!(gpuResourceID)) }
    }

    /// Copy a slice of the tensor's contents into caller-provided memory.
    ///
    /// C++ equivalent: `void getBytes(void*, const TensorExtents*, const TensorExtents*, const TensorExtents*)`
    ///
    /// # Safety
    ///
    /// `bytes` must point to a writable allocation large enough to hold the
    /// requested slice laid out with `strides`.
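    ///
    /// # Example
    ///
    /// A sketch that copies a full 2 x 3 FP32 tensor into a `Vec<f32>`
    /// (assumes crate-root re-exports; the stride and origin values are
    /// illustrative, not a statement about Metal's layout rules):
    ///
    /// ```no_run
    /// use std::ffi::c_void;
    /// use mtl_gpu::{Tensor, TensorExtents};
    ///
    /// fn read_all(tensor: &Tensor) -> Option<Vec<f32>> {
    ///     let dims = TensorExtents::with_values(&[2, 3])?;
    ///     let strides = TensorExtents::with_values(&[3, 1])?; // illustrative packed layout
    ///     let origin = TensorExtents::with_values(&[0, 0])?;
    ///     let mut out = vec![0.0f32; 2 * 3];
    ///     unsafe {
    ///         tensor.get_bytes(out.as_mut_ptr() as *mut c_void, &strides, &origin, &dims);
    ///     }
    ///     Some(out)
    /// }
    /// ```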
    pub unsafe fn get_bytes(
        &self,
        bytes: *mut c_void,
        strides: &TensorExtents,
        slice_origin: &TensorExtents,
        slice_dimensions: &TensorExtents,
    ) {
        unsafe {
            let _: () = msg_send_4(
                self.as_ptr(),
                sel!(getBytes:strides:fromSliceOrigin:sliceDimensions:),
                bytes,
                strides.as_ptr(),
                slice_origin.as_ptr(),
                slice_dimensions.as_ptr(),
            );
        }
    }

    /// Overwrite a slice of the tensor with caller-provided bytes.
    ///
    /// C++ equivalent: `void replaceSliceOrigin(const TensorExtents*, const TensorExtents*, const void*, const TensorExtents*)`
    ///
    /// # Safety
    ///
    /// `bytes` must point to readable memory containing the slice's contents
    /// laid out with `strides`.
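    ///
    /// # Example
    ///
    /// A sketch that uploads a packed 2 x 3 FP32 buffer into the tensor
    /// (assumes crate-root re-exports; stride values are illustrative):
    ///
    /// ```no_run
    /// use std::ffi::c_void;
    /// use mtl_gpu::{Tensor, TensorExtents};
    ///
    /// fn write_all(tensor: &Tensor, data: &[f32; 6]) -> Option<()> {
    ///     let origin = TensorExtents::with_values(&[0, 0])?;
    ///     let dims = TensorExtents::with_values(&[2, 3])?;
    ///     let strides = TensorExtents::with_values(&[3, 1])?; // illustrative packed layout
    ///     unsafe {
    ///         tensor.replace_slice_origin(&origin, &dims, data.as_ptr() as *const c_void, &strides);
    ///     }
    ///     Some(())
    /// }
    /// ```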
    pub unsafe fn replace_slice_origin(
        &self,
        slice_origin: &TensorExtents,
        slice_dimensions: &TensorExtents,
        bytes: *const c_void,
        strides: &TensorExtents,
    ) {
        unsafe {
            let _: () = msg_send_4(
                self.as_ptr(),
                sel!(replaceSliceOrigin:sliceDimensions:withBytes:strides:),
                slice_origin.as_ptr(),
                slice_dimensions.as_ptr(),
                bytes,
                strides.as_ptr(),
            );
        }
    }
}

impl Clone for Tensor {
    fn clone(&self) -> Self {
        unsafe {
            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
        }
        Self(self.0)
    }
}

impl Drop for Tensor {
    fn drop(&mut self) {
        unsafe {
            msg_send_0::<()>(self.as_ptr(), sel!(release));
        }
    }
}

impl Referencing for Tensor {
    #[inline]
    fn as_ptr(&self) -> *const c_void {
        self.0.as_ptr()
    }
}

unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_extents_creation() {
        // TensorExtents may not be available on all systems
        let _extents = TensorExtents::new();
    }

    #[test]
    fn test_tensor_extents_with_values() {
        let values = [2, 3, 4];
        let extents = TensorExtents::with_values(&values);
        // May not be available on all systems
        if let Some(e) = extents {
            assert_eq!(e.rank(), 3);
            assert_eq!(e.extent_at_dimension_index(0), 2);
            assert_eq!(e.extent_at_dimension_index(1), 3);
            assert_eq!(e.extent_at_dimension_index(2), 4);
        }
    }

    #[test]
    fn test_tensor_extents_size() {
        assert_eq!(
            std::mem::size_of::<TensorExtents>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    #[test]
    fn test_tensor_descriptor_creation() {
        // TensorDescriptor may not be available on all systems
        let _descriptor = TensorDescriptor::new();
    }

    #[test]
    fn test_tensor_descriptor_size() {
        assert_eq!(
            std::mem::size_of::<TensorDescriptor>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    #[test]
    fn test_tensor_size() {
        assert_eq!(
            std::mem::size_of::<Tensor>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    #[test]
    fn test_tensor_descriptor_data_type() {
        // TensorDescriptor may not be available on all systems
        if let Some(descriptor) = TensorDescriptor::new() {
            descriptor.set_data_type(TensorDataType::FLOAT32);
            assert_eq!(descriptor.data_type(), TensorDataType::FLOAT32);
        }
    }

    #[test]
    fn test_tensor_descriptor_usage() {
        // TensorDescriptor may not be available on all systems
        if let Some(descriptor) = TensorDescriptor::new() {
            descriptor.set_usage(TensorUsage::COMPUTE | TensorUsage::MACHINE_LEARNING);
            let usage = descriptor.usage();
            assert!((usage.0 & TensorUsage::COMPUTE.0) != 0);
            assert!((usage.0 & TensorUsage::MACHINE_LEARNING.0) != 0);
        }
    }
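
    #[test]
    fn test_tensor_descriptor_dimensions_round_trip() {
        // Round-trip sketch: set dimensions on a descriptor and read them
        // back. Skips silently when the MTLTensor classes are unavailable.
        if let (Some(descriptor), Some(dims)) =
            (TensorDescriptor::new(), TensorExtents::with_values(&[4, 8]))
        {
            descriptor.set_dimensions(&dims);
            if let Some(read_back) = descriptor.dimensions() {
                assert_eq!(read_back.rank(), 2);
                assert_eq!(read_back.extent_at_dimension_index(0), 4);
                assert_eq!(read_back.extent_at_dimension_index(1), 8);
            }
        }
    }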
}