// mtl_gpu/pipeline/compute_state.rs
use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::{Referencing, UInteger};
9use mtl_sys::{msg_send_0, msg_send_1, msg_send_2, sel};
10
11use crate::enums::ShaderValidation;
12use crate::types::{ResourceID, Size};
13
14use super::ComputePipelineReflection;
15
/// Owned wrapper around an Objective-C `MTLComputePipelineState` object.
///
/// Holds a non-null object pointer; the wrapper owns one retain count —
/// `Clone` sends `retain` and `Drop` sends `release` (see the trait impls
/// below in this file).
#[repr(transparent)]
pub struct ComputePipelineState(pub(crate) NonNull<c_void>);
21
22impl ComputePipelineState {
23 #[inline]
29 pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
30 NonNull::new(ptr).map(Self)
31 }
32
33 #[inline]
35 pub fn as_raw(&self) -> *mut c_void {
36 self.0.as_ptr()
37 }
38
39 pub fn label(&self) -> Option<String> {
47 unsafe {
48 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
49 if ptr.is_null() {
50 return None;
51 }
52 let utf8_ptr: *const std::ffi::c_char =
53 mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
54 if utf8_ptr.is_null() {
55 return None;
56 }
57 let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
58 Some(c_str.to_string_lossy().into_owned())
59 }
60 }
61
62 pub fn device(&self) -> crate::Device {
66 unsafe {
67 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
68 let _: *mut c_void = msg_send_0(ptr, sel!(retain));
69 crate::Device::from_raw(ptr).expect("pipeline state has no device")
70 }
71 }
72
73 #[inline]
77 pub fn max_total_threads_per_threadgroup(&self) -> UInteger {
78 unsafe { msg_send_0(self.as_ptr(), sel!(maxTotalThreadsPerThreadgroup)) }
79 }
80
81 #[inline]
85 pub fn thread_execution_width(&self) -> UInteger {
86 unsafe { msg_send_0(self.as_ptr(), sel!(threadExecutionWidth)) }
87 }
88
89 #[inline]
93 pub fn static_threadgroup_memory_length(&self) -> UInteger {
94 unsafe { msg_send_0(self.as_ptr(), sel!(staticThreadgroupMemoryLength)) }
95 }
96
97 #[inline]
101 pub fn support_indirect_command_buffers(&self) -> bool {
102 unsafe { msg_send_0(self.as_ptr(), sel!(supportIndirectCommandBuffers)) }
103 }
104
105 #[inline]
109 pub fn imageblock_memory_length(&self, dimensions: Size) -> UInteger {
110 unsafe {
111 msg_send_1(
112 self.as_ptr(),
113 sel!(imageblockMemoryLengthForDimensions:),
114 dimensions,
115 )
116 }
117 }
118
119 #[inline]
123 pub fn gpu_resource_id(&self) -> ResourceID {
124 unsafe { msg_send_0(self.as_ptr(), sel!(gpuResourceID)) }
125 }
126
127 #[inline]
131 pub fn shader_validation(&self) -> ShaderValidation {
132 unsafe { msg_send_0(self.as_ptr(), sel!(shaderValidation)) }
133 }
134
135 #[inline]
139 pub fn required_threads_per_threadgroup(&self) -> Size {
140 unsafe { msg_send_0(self.as_ptr(), sel!(requiredThreadsPerThreadgroup)) }
141 }
142
143 pub fn function_handle_with_name(&self, name: &str) -> Option<crate::FunctionHandle> {
151 let ns_name = mtl_foundation::String::from_str(name)?;
152 unsafe {
153 let ptr: *mut c_void = msg_send_1(
154 self.as_ptr(),
155 sel!(functionHandleWithName:),
156 ns_name.as_ptr(),
157 );
158 crate::FunctionHandle::from_raw(ptr)
159 }
160 }
161
162 pub fn function_handle_with_function(
166 &self,
167 function: &crate::Function,
168 ) -> Option<crate::FunctionHandle> {
169 unsafe {
170 let ptr: *mut c_void = msg_send_1(
171 self.as_ptr(),
172 sel!(functionHandleWithFunction:),
173 function.as_ptr(),
174 );
175 crate::FunctionHandle::from_raw(ptr)
176 }
177 }
178
179 pub fn function_handle_with_binary_function(
183 &self,
184 binary_function: *const c_void,
185 ) -> Option<crate::FunctionHandle> {
186 unsafe {
187 let ptr: *mut c_void = msg_send_1(
188 self.as_ptr(),
189 sel!(functionHandleWithBinaryFunction:),
190 binary_function,
191 );
192 crate::FunctionHandle::from_raw(ptr)
193 }
194 }
195
196 pub fn new_compute_pipeline_state_with_functions(
204 &self,
205 functions: *const c_void,
206 ) -> Result<ComputePipelineState, mtl_foundation::Error> {
207 unsafe {
208 let mut error: *mut c_void = std::ptr::null_mut();
209 let ptr: *mut c_void = msg_send_2(
210 self.as_ptr(),
211 sel!(newComputePipelineStateWithFunctions:error:),
212 functions,
213 &mut error,
214 );
215 if ptr.is_null() {
216 if !error.is_null() {
217 return Err(
218 mtl_foundation::Error::from_ptr(error).expect("error should be valid")
219 );
220 }
221 return Err(mtl_foundation::Error::error(
222 std::ptr::null_mut(),
223 -1,
224 std::ptr::null_mut(),
225 )
226 .expect("failed to create error"));
227 }
228 Ok(ComputePipelineState::from_raw(ptr).unwrap())
229 }
230 }
231
232 pub fn new_compute_pipeline_state_with_binary_functions(
236 &self,
237 binary_functions: *const c_void,
238 ) -> Result<ComputePipelineState, mtl_foundation::Error> {
239 unsafe {
240 let mut error: *mut c_void = std::ptr::null_mut();
241 let ptr: *mut c_void = msg_send_2(
242 self.as_ptr(),
243 sel!(newComputePipelineStateWithBinaryFunctions:error:),
244 binary_functions,
245 &mut error,
246 );
247 if ptr.is_null() {
248 if !error.is_null() {
249 return Err(
250 mtl_foundation::Error::from_ptr(error).expect("error should be valid")
251 );
252 }
253 return Err(mtl_foundation::Error::error(
254 std::ptr::null_mut(),
255 -1,
256 std::ptr::null_mut(),
257 )
258 .expect("failed to create error"));
259 }
260 Ok(ComputePipelineState::from_raw(ptr).unwrap())
261 }
262 }
263
264 pub fn new_intersection_function_table(
272 &self,
273 descriptor: &crate::IntersectionFunctionTableDescriptor,
274 ) -> Option<crate::IntersectionFunctionTable> {
275 unsafe {
276 let ptr: *mut c_void = msg_send_1(
277 self.as_ptr(),
278 sel!(newIntersectionFunctionTableWithDescriptor:),
279 descriptor.as_ptr(),
280 );
281 crate::IntersectionFunctionTable::from_raw(ptr)
282 }
283 }
284
285 pub fn new_visible_function_table(
289 &self,
290 descriptor: &crate::VisibleFunctionTableDescriptor,
291 ) -> Option<crate::VisibleFunctionTable> {
292 unsafe {
293 let ptr: *mut c_void = msg_send_1(
294 self.as_ptr(),
295 sel!(newVisibleFunctionTableWithDescriptor:),
296 descriptor.as_ptr(),
297 );
298 crate::VisibleFunctionTable::from_raw(ptr)
299 }
300 }
301
302 pub fn reflection(&self) -> Option<ComputePipelineReflection> {
310 unsafe {
311 let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(reflection));
312 ComputePipelineReflection::from_raw(ptr)
313 }
314 }
315}
316
317impl Clone for ComputePipelineState {
318 fn clone(&self) -> Self {
319 unsafe {
320 msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
321 }
322 Self(self.0)
323 }
324}
325
326impl Drop for ComputePipelineState {
327 fn drop(&mut self) {
328 unsafe {
329 msg_send_0::<()>(self.as_ptr(), sel!(release));
330 }
331 }
332}
333
334impl Referencing for ComputePipelineState {
335 #[inline]
336 fn as_ptr(&self) -> *const c_void {
337 self.0.as_ptr()
338 }
339}
340
// SAFETY: NOTE(review) — these impls assume the underlying
// `MTLComputePipelineState` object is safe to share and send across threads
// (pipeline-state objects are immutable after creation per Metal's
// documentation); confirm against the Metal threading guarantees before
// relying on this.
unsafe impl Send for ComputePipelineState {}
unsafe impl Sync for ComputePipelineState {}
343
344impl std::fmt::Debug for ComputePipelineState {
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 f.debug_struct("ComputePipelineState")
347 .field("label", &self.label())
348 .field(
349 "max_total_threads_per_threadgroup",
350 &self.max_total_threads_per_threadgroup(),
351 )
352 .field("thread_execution_width", &self.thread_execution_width())
353 .finish()
354 }
355}