Skip to main content

mtl_gpu/argument/
tensor_binding.rs

1//! Tensor binding information.
2
3use std::ffi::c_void;
4use std::ptr::NonNull;
5
6use mtl_foundation::{Referencing, UInteger};
7use mtl_sys::{msg_send_0, sel};
8
9use crate::enums::{BindingAccess, BindingType, DataType, TensorDataType};
10
/// Handle to a Metal tensor binding reflection object.
///
/// Holds a non-null Objective-C object pointer; `#[repr(transparent)]` gives
/// the handle exactly the size and ABI of that pointer, and the `NonNull`
/// niche makes `Option<TensorBinding>` pointer-sized as well.
///
/// C++ equivalent: `MTL::TensorBinding`
#[repr(transparent)]
pub struct TensorBinding(pub(crate) NonNull<c_void>);
16
17impl TensorBinding {
18    /// Create from a raw pointer.
19    ///
20    /// # Safety
21    ///
22    /// The pointer must be a valid Metal TensorBinding.
23    #[inline]
24    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
25        NonNull::new(ptr).map(Self)
26    }
27
28    /// Get the raw pointer.
29    #[inline]
30    pub fn as_raw(&self) -> *mut c_void {
31        self.0.as_ptr()
32    }
33
34    // Inherited from Binding
35
36    /// Get the access mode.
37    #[inline]
38    pub fn access(&self) -> BindingAccess {
39        unsafe { msg_send_0(self.as_ptr(), sel!(access)) }
40    }
41
42    /// Get the index.
43    #[inline]
44    pub fn index(&self) -> UInteger {
45        unsafe { msg_send_0(self.as_ptr(), sel!(index)) }
46    }
47
48    /// Check if this is an argument.
49    #[inline]
50    pub fn is_argument(&self) -> bool {
51        unsafe { msg_send_0(self.as_ptr(), sel!(isArgument)) }
52    }
53
54    /// Check if this binding is used.
55    #[inline]
56    pub fn is_used(&self) -> bool {
57        unsafe { msg_send_0(self.as_ptr(), sel!(isUsed)) }
58    }
59
60    /// Get the name.
61    pub fn name(&self) -> Option<String> {
62        unsafe {
63            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(name));
64            if ptr.is_null() {
65                return None;
66            }
67            let utf8_ptr: *const std::ffi::c_char =
68                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
69            if utf8_ptr.is_null() {
70                return None;
71            }
72            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
73            Some(c_str.to_string_lossy().into_owned())
74        }
75    }
76
77    /// Get the binding type.
78    #[inline]
79    pub fn binding_type(&self) -> BindingType {
80        unsafe { msg_send_0(self.as_ptr(), sel!(type)) }
81    }
82
83    // TensorBinding-specific
84
85    /// Get the dimensions as a raw pointer.
86    ///
87    /// C++ equivalent: `TensorExtents* dimensions() const`
88    #[inline]
89    pub fn dimensions_ptr(&self) -> *const c_void {
90        unsafe { msg_send_0(self.as_ptr(), sel!(dimensions)) }
91    }
92
93    /// Get the index type.
94    ///
95    /// C++ equivalent: `DataType indexType() const`
96    #[inline]
97    pub fn index_type(&self) -> DataType {
98        unsafe { msg_send_0(self.as_ptr(), sel!(indexType)) }
99    }
100
101    /// Get the tensor data type.
102    ///
103    /// C++ equivalent: `TensorDataType tensorDataType() const`
104    #[inline]
105    pub fn tensor_data_type(&self) -> TensorDataType {
106        unsafe { msg_send_0(self.as_ptr(), sel!(tensorDataType)) }
107    }
108}
109
110impl Clone for TensorBinding {
111    fn clone(&self) -> Self {
112        unsafe {
113            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
114        }
115        Self(self.0)
116    }
117}
118
119impl Drop for TensorBinding {
120    fn drop(&mut self) {
121        unsafe {
122            msg_send_0::<()>(self.as_ptr(), sel!(release));
123        }
124    }
125}
126
127impl Referencing for TensorBinding {
128    #[inline]
129    fn as_ptr(&self) -> *const c_void {
130        self.0.as_ptr()
131    }
132}
133
134unsafe impl Send for TensorBinding {}
135unsafe impl Sync for TensorBinding {}