Skip to main content

mtl_gpu/enums/
tensor.rs

//! Tensor enumerations.
//!
//! Corresponds to `Metal/MTLTensor.hpp`.
4
5use mtl_foundation::Integer;
6
7/// Tensor data type.
8///
9/// C++ equivalent: `MTL::TensorDataType`
10#[repr(transparent)]
11#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
12pub struct TensorDataType(pub Integer);
13
14impl TensorDataType {
15    pub const NONE: Self = Self(0);
16    pub const FLOAT32: Self = Self(3);
17    pub const FLOAT16: Self = Self(16);
18    pub const BFLOAT16: Self = Self(121);
19    pub const INT8: Self = Self(45);
20    pub const UINT8: Self = Self(49);
21    pub const INT16: Self = Self(37);
22    pub const UINT16: Self = Self(41);
23    pub const INT32: Self = Self(29);
24    pub const UINT32: Self = Self(33);
25}
26
27/// Tensor error codes.
28///
29/// C++ equivalent: `MTL::TensorError`
30#[repr(transparent)]
31#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
32pub struct TensorError(pub Integer);
33
34impl TensorError {
35    pub const NONE: Self = Self(0);
36    pub const INTERNAL_ERROR: Self = Self(1);
37    pub const INVALID_DESCRIPTOR: Self = Self(2);
38}
39
40/// Tensor usage options.
41///
42/// C++ equivalent: `MTL::TensorUsage`
43#[repr(transparent)]
44#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
45pub struct TensorUsage(pub mtl_foundation::UInteger);
46
47impl TensorUsage {
48    pub const COMPUTE: Self = Self(1);
49    pub const RENDER: Self = Self(1 << 1);
50    pub const MACHINE_LEARNING: Self = Self(1 << 2);
51}
52
53impl std::ops::BitOr for TensorUsage {
54    type Output = Self;
55    fn bitor(self, rhs: Self) -> Self {
56        Self(self.0 | rhs.0)
57    }
58}
59
60impl std::ops::BitAnd for TensorUsage {
61    type Output = Self;
62    fn bitand(self, rhs: Self) -> Self {
63        Self(self.0 & rhs.0)
64    }
65}
66
67impl std::ops::BitOrAssign for TensorUsage {
68    fn bitor_assign(&mut self, rhs: Self) {
69        self.0 |= rhs.0;
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_data_type_values() {
        assert_eq!(TensorDataType::NONE.0, 0);
        assert_eq!(TensorDataType::FLOAT32.0, 3);
        assert_eq!(TensorDataType::FLOAT16.0, 16);
        assert_eq!(TensorDataType::BFLOAT16.0, 121);
        assert_eq!(TensorDataType::INT8.0, 45);
        assert_eq!(TensorDataType::UINT8.0, 49);
        assert_eq!(TensorDataType::INT16.0, 37);
        assert_eq!(TensorDataType::UINT16.0, 41);
        assert_eq!(TensorDataType::INT32.0, 29);
        assert_eq!(TensorDataType::UINT32.0, 33);
    }

    /// Guards against a copy-paste slip giving two data types the same code.
    #[test]
    fn test_tensor_data_type_values_are_distinct() {
        use std::collections::HashSet;

        let all = [
            TensorDataType::NONE,
            TensorDataType::FLOAT32,
            TensorDataType::FLOAT16,
            TensorDataType::BFLOAT16,
            TensorDataType::INT8,
            TensorDataType::UINT8,
            TensorDataType::INT16,
            TensorDataType::UINT16,
            TensorDataType::INT32,
            TensorDataType::UINT32,
        ];
        let unique: HashSet<TensorDataType> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len());
    }

    #[test]
    fn test_tensor_error_values() {
        assert_eq!(TensorError::NONE.0, 0);
        assert_eq!(TensorError::INTERNAL_ERROR.0, 1);
        assert_eq!(TensorError::INVALID_DESCRIPTOR.0, 2);
    }

    /// The derived `Default` impls should produce the "nothing" value of each type.
    #[test]
    fn test_defaults_are_zero() {
        assert_eq!(TensorDataType::default(), TensorDataType::NONE);
        assert_eq!(TensorError::default(), TensorError::NONE);
        assert_eq!(TensorUsage::default().0, 0);
    }

    #[test]
    fn test_tensor_usage_values() {
        assert_eq!(TensorUsage::COMPUTE.0, 1);
        assert_eq!(TensorUsage::RENDER.0, 2);
        assert_eq!(TensorUsage::MACHINE_LEARNING.0, 4);
    }

    #[test]
    fn test_tensor_usage_bitor() {
        let usage = TensorUsage::COMPUTE | TensorUsage::RENDER;
        assert_eq!(usage.0, 3);
    }

    #[test]
    fn test_tensor_usage_bitand() {
        let usage = TensorUsage::COMPUTE | TensorUsage::MACHINE_LEARNING;
        assert_eq!(usage & TensorUsage::COMPUTE, TensorUsage::COMPUTE);
        assert_eq!(
            usage & TensorUsage::MACHINE_LEARNING,
            TensorUsage::MACHINE_LEARNING
        );
        assert_eq!((usage & TensorUsage::RENDER).0, 0);
    }

    #[test]
    fn test_tensor_usage_bitor_assign() {
        let mut usage = TensorUsage::default();
        usage |= TensorUsage::RENDER;
        usage |= TensorUsage::MACHINE_LEARNING;
        assert_eq!(usage.0, 6);
        // Setting an already-set bit is a no-op.
        usage |= TensorUsage::RENDER;
        assert_eq!(usage.0, 6);
    }
}