Skip to main content

mtl_gpu/device/
creation.rs

1//! Device creation and enumeration.
2//!
3//! Corresponds to the free functions in `Metal/MTLDevice.hpp`:
4//! - `MTLCreateSystemDefaultDevice()`
5//! - `MTLCopyAllDevices()` (macOS only)
6//! - `MTLCopyAllDevicesWithObserver()` (macOS only)
7//! - `MTLRemoveDeviceObserver()` (macOS only)
8
9use std::ffi::c_void;
10
11use mtl_foundation::Referencing;
12use mtl_sys::{DeviceObserver as SysDeviceObserver, create_system_default_device};
13
14#[cfg(target_os = "macos")]
15use mtl_sys::copy_all_devices as sys_copy_all_devices;
16
17use super::Device;
18
/// Value type used for CPU/GPU timestamp queries.
///
/// This is simply an unsigned 64-bit integer alias.
///
/// C++ equivalent: `using Timestamp = std::uint64_t`
pub type Timestamp = std::primitive::u64;
23
24/// Create the system default Metal device.
25///
26/// Returns the default GPU for this system. On systems with multiple GPUs,
27/// this typically returns the most capable GPU.
28///
29/// C++ equivalent: `MTL::Device* MTL::CreateSystemDefaultDevice()`
30///
31/// # Example
32///
33/// ```ignore
34/// use mtl_gpu::device;
35///
36/// let device = device::system_default().expect("no Metal device available");
37/// println!("GPU: {}", device.name());
38/// ```
39#[inline]
40pub fn system_default() -> Option<Device> {
41    unsafe {
42        let ptr = create_system_default_device()?;
43        // The returned pointer is autoreleased, so we need to retain it
44        let _: *mut c_void = mtl_sys::msg_send_0(ptr, mtl_sys::sel!(retain));
45        Device::from_raw(ptr)
46    }
47}
48
/// Copy all available Metal devices (macOS only).
///
/// Returns one [`Device`] per GPU currently known to Metal. An empty vector is
/// returned if Metal hands back no device array.
///
/// C++ equivalent: `NS::Array* MTL::CopyAllDevices()`
///
/// # Example
///
/// ```ignore
/// use mtl_gpu::device;
///
/// let devices = device::copy_all_devices();
/// for device in devices {
///     println!("Found GPU: {}", device.name());
/// }
/// ```
#[cfg(target_os = "macos")]
pub fn copy_all_devices() -> Vec<Device> {
    // SAFETY: `array` is the NSArray returned by `MTLCopyAllDevices`; it is only
    // messaged with NSArray selectors and released exactly once at the end.
    unsafe {
        let Some(array) = sys_copy_all_devices() else {
            return Vec::new();
        };

        let len: usize = mtl_sys::msg_send_0(array, mtl_sys::sel!(count));

        let mut found = Vec::with_capacity(len);
        found.extend((0..len).filter_map(|idx| {
            let raw: *mut c_void =
                mtl_sys::msg_send_1(array, mtl_sys::sel!(objectAtIndex:), idx);
            let dev = Device::from_raw(raw)?;
            // The array only lends its elements; take our own reference so the
            // device outlives the array release below.
            dev.retain();
            Some(dev)
        }));

        // The "Copy" function handed us ownership of the array itself.
        mtl_sys::msg_send_0::<()>(array, mtl_sys::sel!(release));

        found
    }
}
93
/// Handle to a Metal device-observer registration (macOS only).
///
/// Returned by [`copy_all_devices_with_observer`]. This type has no `Drop`
/// implementation: dropping it does NOT unregister the observer, so pass it to
/// [`remove_device_observer`] when you no longer want device notifications.
#[cfg(target_os = "macos")]
#[must_use = "dropping a `DeviceObserver` does not unregister it; pass it to `remove_device_observer`"]
pub struct DeviceObserver(SysDeviceObserver);

#[cfg(target_os = "macos")]
impl DeviceObserver {
    /// Wrap a raw observer handle.
    ///
    /// # Safety
    ///
    /// The handle must be a valid device observer, such as the out-parameter
    /// filled in by `MTLCopyAllDevicesWithObserver`.
    pub unsafe fn from_raw(observer: SysDeviceObserver) -> Self {
        Self(observer)
    }

    /// Return the raw observer handle; `self` keeps ownership of the registration.
    #[must_use]
    pub fn as_raw(&self) -> SysDeviceObserver {
        self.0
    }
}
117
/// Copy all devices with an observer for hot-plug notifications (macOS only).
///
/// Returns every device currently available together with an observer handle.
/// `handler` is called with the affected device and a [`DeviceNotificationName`]
/// whenever Metal reports a device event.
///
/// The `&Device` given to `handler` is only borrowed for the duration of the
/// call; the underlying reference stays owned by Metal. `handler` should not
/// panic: it runs inside the block handed to Metal, so a panic would have to
/// unwind through Metal's (non-Rust) calling code.
///
/// C++ equivalent: `NS::Array* MTL::CopyAllDevicesWithObserver(NS::Object**, handler)`
///
/// # Example
///
/// ```ignore
/// use mtl_gpu::device;
///
/// let (devices, observer) = device::copy_all_devices_with_observer(|_device, notification| {
///     println!("Device event: {:?}", notification);
/// });
///
/// // ... use devices ...
///
/// // When done observing:
/// device::remove_device_observer(observer);
/// ```
#[cfg(target_os = "macos")]
pub fn copy_all_devices_with_observer<F>(handler: F) -> (Vec<Device>, DeviceObserver)
where
    F: Fn(&Device, DeviceNotificationName) + Send + 'static,
{
    use mtl_sys::MTLCopyAllDevicesWithObserver;

    // Wrap the user handler in a block that Metal invokes on device events.
    let block = mtl_sys::TwoArgBlock::from_fn(
        move |device_ptr: *mut c_void, notification_name_ptr: *mut c_void| {
            // SAFETY: Metal passes a device pointer and a notification-name
            // NSString that are valid for the duration of the callback.
            unsafe {
                if let Some(device) = Device::from_raw(device_ptr) {
                    // Metal owns this reference, so it must never be released
                    // here. `ManuallyDrop` guarantees that even if `handler`
                    // panics; a `mem::forget` placed after the call would be
                    // skipped during unwinding and over-release the device.
                    let device = std::mem::ManuallyDrop::new(device);
                    let notification = parse_notification_name(notification_name_ptr);
                    handler(&device, notification);
                }
            }
        },
    );

    let mut observer: SysDeviceObserver = std::ptr::null_mut();

    // SAFETY: `observer` is a valid out-parameter and `block` is a live block.
    let array_ptr =
        unsafe { MTLCopyAllDevicesWithObserver(&mut observer as *mut _, block.as_ptr()) };

    // Intentionally leak our handle to the block: Metal keeps invoking it until
    // the observer is removed.
    std::mem::forget(block);

    let devices = if array_ptr.is_null() {
        Vec::new()
    } else {
        // SAFETY: `array_ptr` is the non-null NSArray returned above; it is only
        // messaged with NSArray selectors and released exactly once.
        unsafe {
            let count: usize = mtl_sys::msg_send_0(array_ptr, mtl_sys::sel!(count));
            let mut devices = Vec::with_capacity(count);
            for i in 0..count {
                let device_ptr: *mut c_void =
                    mtl_sys::msg_send_1(array_ptr, mtl_sys::sel!(objectAtIndex:), i);
                if let Some(device) = Device::from_raw(device_ptr) {
                    // objectAtIndex: does not transfer ownership; take our own
                    // reference so the device outlives the array.
                    device.retain();
                    devices.push(device);
                }
            }
            // CopyAllDevicesWithObserver transferred ownership of the array.
            mtl_sys::msg_send_0::<()>(array_ptr, mtl_sys::sel!(release));
            devices
        }
    };

    // SAFETY: `observer` was filled in by MTLCopyAllDevicesWithObserver.
    (devices, unsafe { DeviceObserver::from_raw(observer) })
}
193
/// Map the notification-name `NSString` passed to the observer block onto
/// [`DeviceNotificationName`].
///
/// Matching is by substring, checked in the order `WasAdded`,
/// `RemovalRequested`, `WasRemoved`. A null string, a null UTF-8 pointer, or an
/// unrecognised name all map to [`DeviceNotificationName::WasRemoved`].
#[cfg(target_os = "macos")]
fn parse_notification_name(ns_string: *mut c_void) -> DeviceNotificationName {
    // Substrings of the standard Metal notification names, in match order.
    const KNOWN: [(&str, DeviceNotificationName); 3] = [
        ("WasAdded", DeviceNotificationName::WasAdded),
        ("RemovalRequested", DeviceNotificationName::RemovalRequested),
        ("WasRemoved", DeviceNotificationName::WasRemoved),
    ];
    // Used whenever the name cannot be read or is not recognised.
    const FALLBACK: DeviceNotificationName = DeviceNotificationName::WasRemoved;

    if ns_string.is_null() {
        return FALLBACK;
    }

    // SAFETY: `ns_string` is non-null and, presumably, an NSString supplied by
    // Metal's observer callback.
    let utf8: *const std::ffi::c_char =
        unsafe { mtl_sys::msg_send_0(ns_string, mtl_sys::sel!(UTF8String)) };
    if utf8.is_null() {
        return FALLBACK;
    }

    // SAFETY: `UTF8String` returned a non-null, NUL-terminated C string.
    let name = unsafe { std::ffi::CStr::from_ptr(utf8) }.to_string_lossy();

    KNOWN
        .iter()
        .find(|(needle, _)| name.contains(*needle))
        .map_or(FALLBACK, |&(_, kind)| kind)
}
222
/// Kind of device event delivered to a device observer (macOS only).
///
/// Derives `Hash` (alongside `Eq`) so values can be used as keys in hash sets
/// and maps.
#[cfg(target_os = "macos")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceNotificationName {
    /// A device was added to the system.
    WasAdded,
    /// A device removal was requested.
    RemovalRequested,
    /// A device was removed from the system.
    WasRemoved,
}
234
/// Unregister a device observer (macOS only).
///
/// Takes the handle returned by `copy_all_devices_with_observer` by value and
/// passes its raw handle to Metal, after which no further device notifications
/// are requested for it.
///
/// C++ equivalent: `void MTL::RemoveDeviceObserver(const NS::Object* pObserver)`
#[cfg(target_os = "macos")]
pub fn remove_device_observer(observer: DeviceObserver) {
    let DeviceObserver(raw) = observer;
    // SAFETY: `raw` came from a `DeviceObserver`, whose constructor requires a
    // valid observer handle.
    unsafe { mtl_sys::MTLRemoveDeviceObserver(raw) };
}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_default() {
        let device = system_default().expect("should have at least one Metal device");
        assert!(!device.name().is_empty(), "device should have a name");
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_copy_all_devices() {
        let devices = copy_all_devices();
        assert!(!devices.is_empty(), "should have at least one device");
    }

    /// Does not need a GPU: a null notification name falls back to `WasRemoved`.
    #[cfg(target_os = "macos")]
    #[test]
    fn test_parse_notification_name_null_falls_back() {
        assert_eq!(
            parse_notification_name(std::ptr::null_mut()),
            DeviceNotificationName::WasRemoved
        );
    }
}