Skip to main content

mtl_gpu/device/
acceleration.rs

//! Device methods for creating acceleration structures and querying their build sizes.
//!
//! Corresponds to acceleration structure methods in `Metal/MTLDevice.hpp`.
4
5use std::ffi::c_void;
6
7use mtl_foundation::{Referencing, UInteger};
8use mtl_sys::{msg_send_1, sel};
9
10use super::Device;
11use crate::acceleration::{AccelerationStructure, AccelerationStructureSizes};
12
13impl Device {
14    // =========================================================================
15    // Acceleration Structure Creation
16    // =========================================================================
17
18    /// Create an acceleration structure with the given descriptor.
19    ///
20    /// C++ equivalent: `AccelerationStructure* newAccelerationStructure(const AccelerationStructureDescriptor*)`
21    ///
22    /// # Safety
23    ///
24    /// The descriptor pointer must be valid.
25    pub unsafe fn new_acceleration_structure_with_descriptor(
26        &self,
27        descriptor: *const c_void,
28    ) -> Option<AccelerationStructure> {
29        unsafe {
30            let ptr: *mut c_void = msg_send_1(
31                self.as_ptr(),
32                sel!(newAccelerationStructureWithDescriptor:),
33                descriptor,
34            );
35            AccelerationStructure::from_raw(ptr)
36        }
37    }
38
39    /// Create an acceleration structure with the given size.
40    ///
41    /// C++ equivalent: `AccelerationStructure* newAccelerationStructure(NS::UInteger size)`
42    pub fn new_acceleration_structure_with_size(
43        &self,
44        size: UInteger,
45    ) -> Option<AccelerationStructure> {
46        unsafe {
47            let ptr: *mut c_void =
48                msg_send_1(self.as_ptr(), sel!(newAccelerationStructureWithSize:), size);
49            AccelerationStructure::from_raw(ptr)
50        }
51    }
52
53    // =========================================================================
54    // Acceleration Structure Size Queries
55    // =========================================================================
56
57    /// Get the sizes needed for building an acceleration structure.
58    ///
59    /// C++ equivalent: `AccelerationStructureSizes accelerationStructureSizes(const AccelerationStructureDescriptor*)`
60    ///
61    /// # Safety
62    ///
63    /// The descriptor pointer must be valid.
64    pub unsafe fn acceleration_structure_sizes_with_descriptor(
65        &self,
66        descriptor: *const c_void,
67    ) -> AccelerationStructureSizes {
68        unsafe {
69            msg_send_1(
70                self.as_ptr(),
71                sel!(accelerationStructureSizesWithDescriptor:),
72                descriptor,
73            )
74        }
75    }
76}
77
#[cfg(test)]
mod tests {
    use crate::acceleration::PrimitiveAccelerationStructureDescriptor;
    use crate::device::system_default;

    #[test]
    fn test_supports_raytracing() {
        let device = system_default().expect("no Metal device");
        // Smoke test only: the answer depends on the GPU, so just check the call succeeds.
        let _supports = device.supports_raytracing();
    }

    #[test]
    fn test_acceleration_structure_sizes() {
        let device = system_default().expect("no Metal device");

        if !device.supports_raytracing() {
            println!("Skipping test - device does not support ray tracing");
            return;
        }

        let desc =
            PrimitiveAccelerationStructureDescriptor::new().expect("failed to create descriptor");

        // SAFETY: `desc` is a live descriptor that outlives this call.
        let sizes = unsafe { device.acceleration_structure_sizes_with_descriptor(desc.as_raw()) };

        // The descriptor has no geometry attached. The reported sizes depend on the GPU
        // and OS, so they are printed for inspection rather than asserted.
        println!(
            "Acceleration structure size: {}",
            sizes.acceleration_structure_size
        );
        println!(
            "Build scratch buffer size: {}",
            sizes.build_scratch_buffer_size
        );
        println!(
            "Refit scratch buffer size: {}",
            sizes.refit_scratch_buffer_size
        );
    }

    #[test]
    fn test_new_acceleration_structure_with_size() {
        let device = system_default().expect("no Metal device");

        if !device.supports_raytracing() {
            println!("Skipping test - device does not support ray tracing");
            return;
        }

        // Create a small acceleration structure; Metal may round the allocation up.
        let accel = device
            .new_acceleration_structure_with_size(1024)
            .expect("failed to create acceleration structure");
        assert!(accel.size() >= 1024);
    }
}