Skip to main content

mtl_gpu/encoder/compute_encoder/
dispatch.rs

1//! Dispatch methods for ComputeCommandEncoder.
2
3use std::ffi::c_void;
4
5use mtl_foundation::{Referencing, UInteger};
6use mtl_sys::{msg_send_1, sel};
7
8use crate::Buffer;
9use crate::types::{Region, Size};
10
11use super::ComputeCommandEncoder;
12
13impl ComputeCommandEncoder {
14    // =========================================================================
15    // Stage-In Region
16    // =========================================================================
17
18    /// Set the stage-in region.
19    ///
20    /// C++ equivalent: `void setStageInRegion(MTL::Region)`
21    #[inline]
22    pub fn set_stage_in_region(&self, region: Region) {
23        unsafe {
24            msg_send_1::<(), Region>(self.as_ptr(), sel!(setStageInRegion:), region);
25        }
26    }
27
28    /// Set the stage-in region from an indirect buffer.
29    ///
30    /// C++ equivalent: `void setStageInRegion(const Buffer*, NS::UInteger)`
31    #[inline]
32    pub fn set_stage_in_region_with_indirect_buffer(
33        &self,
34        indirect_buffer: &Buffer,
35        indirect_buffer_offset: UInteger,
36    ) {
37        unsafe {
38            mtl_sys::msg_send_2::<(), *const c_void, UInteger>(
39                self.as_ptr(),
40                sel!(setStageInRegionWithIndirectBuffer: indirectBufferOffset:),
41                indirect_buffer.as_ptr(),
42                indirect_buffer_offset,
43            );
44        }
45    }
46
47    // =========================================================================
48    // Dispatch
49    // =========================================================================
50
51    /// Dispatch threadgroups.
52    ///
53    /// C++ equivalent: `void dispatchThreadgroups(MTL::Size, MTL::Size)`
54    #[inline]
55    pub fn dispatch_threadgroups(
56        &self,
57        threadgroups_per_grid: Size,
58        threads_per_threadgroup: Size,
59    ) {
60        unsafe {
61            mtl_sys::msg_send_2::<(), Size, Size>(
62                self.as_ptr(),
63                sel!(dispatchThreadgroups: threadsPerThreadgroup:),
64                threadgroups_per_grid,
65                threads_per_threadgroup,
66            );
67        }
68    }
69
70    /// Dispatch threadgroups with an indirect buffer.
71    ///
72    /// C++ equivalent: `void dispatchThreadgroups(const Buffer*, NS::UInteger, MTL::Size)`
73    #[inline]
74    pub fn dispatch_threadgroups_with_indirect_buffer(
75        &self,
76        indirect_buffer: &Buffer,
77        indirect_buffer_offset: UInteger,
78        threads_per_threadgroup: Size,
79    ) {
80        unsafe {
81            mtl_sys::msg_send_3::<(), *const c_void, UInteger, Size>(
82                self.as_ptr(),
83                sel!(dispatchThreadgroupsWithIndirectBuffer: indirectBufferOffset: threadsPerThreadgroup:),
84                indirect_buffer.as_ptr(),
85                indirect_buffer_offset,
86                threads_per_threadgroup,
87            );
88        }
89    }
90
91    /// Dispatch threads directly (non-uniform dispatch).
92    ///
93    /// C++ equivalent: `void dispatchThreads(MTL::Size, MTL::Size)`
94    #[inline]
95    pub fn dispatch_threads(&self, threads_per_grid: Size, threads_per_threadgroup: Size) {
96        unsafe {
97            mtl_sys::msg_send_2::<(), Size, Size>(
98                self.as_ptr(),
99                sel!(dispatchThreads: threadsPerThreadgroup:),
100                threads_per_grid,
101                threads_per_threadgroup,
102            );
103        }
104    }
105}