Skip to main content

mtl_gpu/sync/
mod.rs

1//! Metal synchronization primitives.
2//!
3//! Corresponds to `Metal/MTLEvent.hpp` and `Metal/MTLFence.hpp`.
4
5use std::ffi::c_void;
6use std::ptr::NonNull;
7
8use mtl_foundation::Referencing;
9use mtl_sys::{msg_send_0, msg_send_1, sel};
10
11// ============================================================================
12// Event
13// ============================================================================
14
15/// A simple synchronization primitive.
16///
17/// C++ equivalent: `MTL::Event`
18#[repr(transparent)]
19pub struct Event(pub(crate) NonNull<c_void>);
20
21impl Event {
22    /// Create an Event from a raw pointer.
23    ///
24    /// # Safety
25    ///
26    /// The pointer must be a valid Metal event object.
27    #[inline]
28    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
29        NonNull::new(ptr).map(Self)
30    }
31
32    /// Get the raw pointer.
33    #[inline]
34    pub fn as_raw(&self) -> *mut c_void {
35        self.0.as_ptr()
36    }
37
38    /// Get the label.
39    ///
40    /// C++ equivalent: `NS::String* label() const`
41    pub fn label(&self) -> Option<String> {
42        unsafe {
43            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
44            if ptr.is_null() {
45                return None;
46            }
47            let utf8_ptr: *const std::ffi::c_char =
48                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
49            if utf8_ptr.is_null() {
50                return None;
51            }
52            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
53            Some(c_str.to_string_lossy().into_owned())
54        }
55    }
56
57    /// Set the label.
58    ///
59    /// C++ equivalent: `void setLabel(const NS::String*)`
60    pub fn set_label(&self, label: &str) {
61        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
62            unsafe {
63                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
64            }
65        }
66    }
67
68    /// Get the device.
69    ///
70    /// C++ equivalent: `Device* device() const`
71    pub fn device(&self) -> crate::Device {
72        unsafe {
73            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
74            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
75            crate::Device::from_raw(ptr).expect("event has no device")
76        }
77    }
78}
79
80impl Clone for Event {
81    fn clone(&self) -> Self {
82        unsafe {
83            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
84        }
85        Self(self.0)
86    }
87}
88
89impl Drop for Event {
90    fn drop(&mut self) {
91        unsafe {
92            msg_send_0::<()>(self.as_ptr(), sel!(release));
93        }
94    }
95}
96
97impl Referencing for Event {
98    #[inline]
99    fn as_ptr(&self) -> *const c_void {
100        self.0.as_ptr()
101    }
102}
103
104unsafe impl Send for Event {}
105unsafe impl Sync for Event {}
106
107impl std::fmt::Debug for Event {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        f.debug_struct("Event")
110            .field("label", &self.label())
111            .finish()
112    }
113}
114
115// ============================================================================
116// SharedEvent
117// ============================================================================
118
119/// A cross-process synchronization primitive.
120///
121/// C++ equivalent: `MTL::SharedEvent`
122#[repr(transparent)]
123pub struct SharedEvent(pub(crate) NonNull<c_void>);
124
125impl SharedEvent {
126    /// Create a SharedEvent from a raw pointer.
127    ///
128    /// # Safety
129    ///
130    /// The pointer must be a valid Metal shared event object.
131    #[inline]
132    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
133        NonNull::new(ptr).map(Self)
134    }
135
136    /// Get the raw pointer.
137    #[inline]
138    pub fn as_raw(&self) -> *mut c_void {
139        self.0.as_ptr()
140    }
141
142    /// Get the label.
143    ///
144    /// C++ equivalent: `NS::String* label() const`
145    pub fn label(&self) -> Option<String> {
146        unsafe {
147            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
148            if ptr.is_null() {
149                return None;
150            }
151            let utf8_ptr: *const std::ffi::c_char =
152                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
153            if utf8_ptr.is_null() {
154                return None;
155            }
156            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
157            Some(c_str.to_string_lossy().into_owned())
158        }
159    }
160
161    /// Set the label.
162    ///
163    /// C++ equivalent: `void setLabel(const NS::String*)`
164    pub fn set_label(&self, label: &str) {
165        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
166            unsafe {
167                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
168            }
169        }
170    }
171
172    /// Get the device.
173    ///
174    /// C++ equivalent: `Device* device() const`
175    pub fn device(&self) -> crate::Device {
176        unsafe {
177            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
178            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
179            crate::Device::from_raw(ptr).expect("shared event has no device")
180        }
181    }
182
183    /// Get the current signaled value.
184    ///
185    /// C++ equivalent: `uint64_t signaledValue() const`
186    #[inline]
187    pub fn signaled_value(&self) -> u64 {
188        unsafe { msg_send_0(self.as_ptr(), sel!(signaledValue)) }
189    }
190
191    /// Set the signaled value.
192    ///
193    /// C++ equivalent: `void setSignaledValue(uint64_t)`
194    #[inline]
195    pub fn set_signaled_value(&self, value: u64) {
196        unsafe {
197            msg_send_1::<(), u64>(self.as_ptr(), sel!(setSignaledValue:), value);
198        }
199    }
200
201    /// Create a handle for sharing across processes.
202    ///
203    /// C++ equivalent: `SharedEventHandle* newSharedEventHandle()`
204    pub fn new_shared_event_handle(&self) -> Option<SharedEventHandle> {
205        unsafe {
206            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(newSharedEventHandle));
207            SharedEventHandle::from_raw(ptr)
208        }
209    }
210
211    /// Notify a listener when the event reaches a value.
212    ///
213    /// C++ equivalent: `void notifyListener(SharedEventListener*, uint64_t, void (^)(SharedEvent*, uint64_t))`
214    ///
215    /// # Safety
216    ///
217    /// The listener pointer must be valid.
218    pub unsafe fn notify_listener<F>(&self, listener: *const c_void, value: u64, block: F)
219    where
220        F: Fn(*mut c_void, u64) + Send + 'static,
221    {
222        let block = mtl_sys::EventBlock::from_fn(block);
223        unsafe {
224            mtl_sys::msg_send_3::<(), *const c_void, u64, *const c_void>(
225                self.as_ptr(),
226                sel!(notifyListener: atValue: block:),
227                listener,
228                value,
229                block.as_ptr(),
230            );
231        }
232        std::mem::forget(block);
233    }
234}
235
236impl Clone for SharedEvent {
237    fn clone(&self) -> Self {
238        unsafe {
239            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
240        }
241        Self(self.0)
242    }
243}
244
245impl Drop for SharedEvent {
246    fn drop(&mut self) {
247        unsafe {
248            msg_send_0::<()>(self.as_ptr(), sel!(release));
249        }
250    }
251}
252
253impl Referencing for SharedEvent {
254    #[inline]
255    fn as_ptr(&self) -> *const c_void {
256        self.0.as_ptr()
257    }
258}
259
260unsafe impl Send for SharedEvent {}
261unsafe impl Sync for SharedEvent {}
262
263impl std::fmt::Debug for SharedEvent {
264    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265        f.debug_struct("SharedEvent")
266            .field("label", &self.label())
267            .field("signaled_value", &self.signaled_value())
268            .finish()
269    }
270}
271
272// ============================================================================
273// SharedEventHandle
274// ============================================================================
275
276/// A handle for sharing events across processes.
277///
278/// C++ equivalent: `MTL::SharedEventHandle`
279#[repr(transparent)]
280pub struct SharedEventHandle(pub(crate) NonNull<c_void>);
281
282impl SharedEventHandle {
283    /// Create a SharedEventHandle from a raw pointer.
284    ///
285    /// # Safety
286    ///
287    /// The pointer must be a valid Metal shared event handle object.
288    #[inline]
289    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
290        NonNull::new(ptr).map(Self)
291    }
292
293    /// Get the raw pointer.
294    #[inline]
295    pub fn as_raw(&self) -> *mut c_void {
296        self.0.as_ptr()
297    }
298
299    /// Get the label.
300    ///
301    /// C++ equivalent: `NS::String* label() const`
302    pub fn label(&self) -> Option<String> {
303        unsafe {
304            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
305            if ptr.is_null() {
306                return None;
307            }
308            let utf8_ptr: *const std::ffi::c_char =
309                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
310            if utf8_ptr.is_null() {
311                return None;
312            }
313            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
314            Some(c_str.to_string_lossy().into_owned())
315        }
316    }
317}
318
319impl Clone for SharedEventHandle {
320    fn clone(&self) -> Self {
321        unsafe {
322            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
323        }
324        Self(self.0)
325    }
326}
327
328impl Drop for SharedEventHandle {
329    fn drop(&mut self) {
330        unsafe {
331            msg_send_0::<()>(self.as_ptr(), sel!(release));
332        }
333    }
334}
335
336impl Referencing for SharedEventHandle {
337    #[inline]
338    fn as_ptr(&self) -> *const c_void {
339        self.0.as_ptr()
340    }
341}
342
343unsafe impl Send for SharedEventHandle {}
344unsafe impl Sync for SharedEventHandle {}
345
346// ============================================================================
347// SharedEventListener
348// ============================================================================
349
350/// A listener that can be notified when a shared event reaches a value.
351///
352/// C++ equivalent: `MTL::SharedEventListener`
353#[repr(transparent)]
354pub struct SharedEventListener(pub(crate) NonNull<c_void>);
355
356impl SharedEventListener {
357    /// Allocate a new shared event listener.
358    ///
359    /// C++ equivalent: `static SharedEventListener* alloc()`
360    pub fn alloc() -> Option<Self> {
361        unsafe {
362            let cls = mtl_sys::Class::get("MTLSharedEventListener")?;
363            let ptr: *mut c_void = msg_send_0(cls.as_ptr(), sel!(alloc));
364            Self::from_raw(ptr)
365        }
366    }
367
368    /// Initialize an allocated listener with the default dispatch queue.
369    ///
370    /// C++ equivalent: `SharedEventListener* init()`
371    pub fn init(&self) -> Option<Self> {
372        unsafe {
373            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(init));
374            Self::from_raw(ptr)
375        }
376    }
377
378    /// Create a new shared event listener with the default dispatch queue.
379    pub fn new() -> Option<Self> {
380        Self::alloc()?.init()
381    }
382
383    /// Get the shared (global) listener singleton.
384    ///
385    /// C++ equivalent: `static SharedEventListener* sharedListener()`
386    pub fn shared() -> Option<Self> {
387        unsafe {
388            let cls = mtl_sys::Class::get("MTLSharedEventListener")?;
389            let ptr: *mut c_void = msg_send_0(cls.as_ptr(), sel!(sharedListener));
390            if ptr.is_null() {
391                return None;
392            }
393            msg_send_0::<*mut c_void>(ptr as *const c_void, sel!(retain));
394            Self::from_raw(ptr)
395        }
396    }
397
398    /// Create from a raw pointer.
399    ///
400    /// # Safety
401    ///
402    /// The pointer must be a valid Metal shared event listener.
403    #[inline]
404    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
405        NonNull::new(ptr).map(Self)
406    }
407
408    /// Get the raw pointer.
409    #[inline]
410    pub fn as_raw(&self) -> *mut c_void {
411        self.0.as_ptr()
412    }
413
414    /// Get the dispatch queue for this listener.
415    ///
416    /// C++ equivalent: `dispatch_queue_t dispatchQueue() const`
417    ///
418    /// Returns the raw dispatch_queue_t pointer.
419    #[inline]
420    pub fn dispatch_queue_raw(&self) -> *mut c_void {
421        unsafe { msg_send_0(self.as_ptr(), sel!(dispatchQueue)) }
422    }
423}
424
425impl Clone for SharedEventListener {
426    fn clone(&self) -> Self {
427        unsafe {
428            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
429        }
430        Self(self.0)
431    }
432}
433
434impl Drop for SharedEventListener {
435    fn drop(&mut self) {
436        unsafe {
437            msg_send_0::<()>(self.as_ptr(), sel!(release));
438        }
439    }
440}
441
442impl Referencing for SharedEventListener {
443    #[inline]
444    fn as_ptr(&self) -> *const c_void {
445        self.0.as_ptr()
446    }
447}
448
449unsafe impl Send for SharedEventListener {}
450unsafe impl Sync for SharedEventListener {}
451
452impl std::fmt::Debug for SharedEventListener {
453    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454        f.debug_struct("SharedEventListener").finish()
455    }
456}
457
458// ============================================================================
459// Fence
460// ============================================================================
461
462/// A GPU-side synchronization primitive.
463///
464/// C++ equivalent: `MTL::Fence`
465#[repr(transparent)]
466pub struct Fence(pub(crate) NonNull<c_void>);
467
468impl Fence {
469    /// Create a Fence from a raw pointer.
470    ///
471    /// # Safety
472    ///
473    /// The pointer must be a valid Metal fence object.
474    #[inline]
475    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
476        NonNull::new(ptr).map(Self)
477    }
478
479    /// Get the raw pointer.
480    #[inline]
481    pub fn as_raw(&self) -> *mut c_void {
482        self.0.as_ptr()
483    }
484
485    /// Get the label.
486    ///
487    /// C++ equivalent: `NS::String* label() const`
488    pub fn label(&self) -> Option<String> {
489        unsafe {
490            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
491            if ptr.is_null() {
492                return None;
493            }
494            let utf8_ptr: *const std::ffi::c_char =
495                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
496            if utf8_ptr.is_null() {
497                return None;
498            }
499            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
500            Some(c_str.to_string_lossy().into_owned())
501        }
502    }
503
504    /// Set the label.
505    ///
506    /// C++ equivalent: `void setLabel(const NS::String*)`
507    pub fn set_label(&self, label: &str) {
508        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
509            unsafe {
510                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
511            }
512        }
513    }
514
515    /// Get the device.
516    ///
517    /// C++ equivalent: `Device* device() const`
518    pub fn device(&self) -> crate::Device {
519        unsafe {
520            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
521            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
522            crate::Device::from_raw(ptr).expect("fence has no device")
523        }
524    }
525}
526
527impl Clone for Fence {
528    fn clone(&self) -> Self {
529        unsafe {
530            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
531        }
532        Self(self.0)
533    }
534}
535
536impl Drop for Fence {
537    fn drop(&mut self) {
538        unsafe {
539            msg_send_0::<()>(self.as_ptr(), sel!(release));
540        }
541    }
542}
543
544impl Referencing for Fence {
545    #[inline]
546    fn as_ptr(&self) -> *const c_void {
547        self.0.as_ptr()
548    }
549}
550
551unsafe impl Send for Fence {}
552unsafe impl Sync for Fence {}
553
554impl std::fmt::Debug for Fence {
555    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
556        f.debug_struct("Fence")
557            .field("label", &self.label())
558            .finish()
559    }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that both `T` and `Option<T>` are exactly one pointer wide.
    ///
    /// Every wrapper in this module is `#[repr(transparent)]` over
    /// `NonNull<c_void>`. So it must be pointer-sized, and `Option<T>` (returned by
    /// the `from_raw` constructors) must use the null niche. The standard library
    /// guarantees that niche for `repr(transparent)` wrappers of `NonNull`.
    fn assert_pointer_sized<T>() {
        let ptr_size = std::mem::size_of::<*mut c_void>();
        assert_eq!(std::mem::size_of::<T>(), ptr_size);
        assert_eq!(std::mem::size_of::<Option<T>>(), ptr_size);
    }

    #[test]
    fn test_event_size() {
        assert_pointer_sized::<Event>();
    }

    #[test]
    fn test_shared_event_size() {
        assert_pointer_sized::<SharedEvent>();
    }

    #[test]
    fn test_shared_event_handle_size() {
        assert_pointer_sized::<SharedEventHandle>();
    }

    #[test]
    fn test_fence_size() {
        assert_pointer_sized::<Fence>();
    }

    #[test]
    fn test_shared_event_listener_size() {
        assert_pointer_sized::<SharedEventListener>();
    }
}