Skip to main content

mtl_gpu/sampler/
mod.rs

1//! Metal sampler state.
2//!
3//! Corresponds to `Metal/MTLSampler.hpp`.
4//!
5//! Sampler states define how textures are sampled in shaders.
6
7use std::ffi::c_void;
8use std::ptr::NonNull;
9
10use mtl_foundation::{Referencing, UInteger};
11use mtl_sys::{msg_send_0, msg_send_1, sel};
12
13use crate::enums::{
14    CompareFunction, SamplerAddressMode, SamplerBorderColor, SamplerMinMagFilter, SamplerMipFilter,
15    SamplerReductionMode,
16};
17use crate::types::ResourceID;
18
/// Owned handle to a Metal `MTLSamplerState` object, which determines how texture
/// coordinates are mapped to texels when a shader samples a texture.
///
/// The handle owns one Objective-C reference to the object: cloning sends `retain`
/// and dropping sends `release`. Because the pointer is a `NonNull`, an
/// `Option<SamplerState>` is still pointer-sized.
///
/// C++ equivalent: `MTL::SamplerState`
#[repr(transparent)]
pub struct SamplerState(pub(crate) NonNull<c_void>);
24
25impl SamplerState {
26    /// Create a SamplerState from a raw pointer.
27    ///
28    /// # Safety
29    ///
30    /// The pointer must be a valid Metal sampler state object.
31    #[inline]
32    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
33        NonNull::new(ptr).map(Self)
34    }
35
36    /// Get the raw pointer to the sampler state.
37    #[inline]
38    pub fn as_raw(&self) -> *mut c_void {
39        self.0.as_ptr()
40    }
41
42    // =========================================================================
43    // Properties
44    // =========================================================================
45
46    /// Get the label for this sampler state.
47    ///
48    /// C++ equivalent: `NS::String* label() const`
49    pub fn label(&self) -> Option<String> {
50        unsafe {
51            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
52            if ptr.is_null() {
53                return None;
54            }
55            let utf8_ptr: *const std::ffi::c_char =
56                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
57            if utf8_ptr.is_null() {
58                return None;
59            }
60            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
61            Some(c_str.to_string_lossy().into_owned())
62        }
63    }
64
65    /// Get the device that created this sampler state.
66    ///
67    /// C++ equivalent: `Device* device() const`
68    pub fn device(&self) -> crate::Device {
69        unsafe {
70            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(device));
71            let _: *mut c_void = msg_send_0(ptr, sel!(retain));
72            crate::Device::from_raw(ptr).expect("sampler state has no device")
73        }
74    }
75
76    /// Get the GPU resource ID for bindless access.
77    ///
78    /// C++ equivalent: `ResourceID gpuResourceID() const`
79    #[inline]
80    pub fn gpu_resource_id(&self) -> ResourceID {
81        unsafe { msg_send_0(self.as_ptr(), sel!(gpuResourceID)) }
82    }
83}
84
85impl Clone for SamplerState {
86    fn clone(&self) -> Self {
87        unsafe {
88            msg_send_0::<*mut c_void>(self.as_ptr(), sel!(retain));
89        }
90        Self(self.0)
91    }
92}
93
94impl Drop for SamplerState {
95    fn drop(&mut self) {
96        unsafe {
97            msg_send_0::<()>(self.as_ptr(), sel!(release));
98        }
99    }
100}
101
102impl Referencing for SamplerState {
103    #[inline]
104    fn as_ptr(&self) -> *const c_void {
105        self.0.as_ptr()
106    }
107}
108
109unsafe impl Send for SamplerState {}
110unsafe impl Sync for SamplerState {}
111
112impl std::fmt::Debug for SamplerState {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("SamplerState")
115            .field("label", &self.label())
116            .finish()
117    }
118}
119
// ============================================================================
// Sampler Descriptor
// ============================================================================

/// Owned handle to a Metal `MTLSamplerDescriptor`.
///
/// This is the mutable configuration (filters, address modes, LOD clamps,
/// anisotropy, compare function, ...) that describes a sampler state.
///
/// The handle owns one Objective-C reference to the object, which is released when
/// the handle is dropped. `Option<SamplerDescriptor>` stays pointer-sized thanks to
/// the `NonNull` niche.
///
/// C++ equivalent: `MTL::SamplerDescriptor`
#[repr(transparent)]
pub struct SamplerDescriptor(pub(crate) NonNull<c_void>);
129
130impl SamplerDescriptor {
131    /// Create a new sampler descriptor.
132    ///
133    /// C++ equivalent: `static SamplerDescriptor* alloc()->init()`
134    pub fn new() -> Option<Self> {
135        unsafe {
136            let class = mtl_sys::Class::get("MTLSamplerDescriptor")?;
137            let ptr: *mut c_void = msg_send_0(class.as_ptr(), sel!(alloc));
138            if ptr.is_null() {
139                return None;
140            }
141            let ptr: *mut c_void = msg_send_0(ptr, sel!(init));
142            Self::from_raw(ptr)
143        }
144    }
145
146    /// Create a SamplerDescriptor from a raw pointer.
147    ///
148    /// # Safety
149    ///
150    /// The pointer must be a valid Metal sampler descriptor object.
151    #[inline]
152    pub unsafe fn from_raw(ptr: *mut c_void) -> Option<Self> {
153        NonNull::new(ptr).map(Self)
154    }
155
156    /// Get the raw pointer.
157    #[inline]
158    pub fn as_raw(&self) -> *mut c_void {
159        self.0.as_ptr()
160    }
161
162    // =========================================================================
163    // Filter Properties
164    // =========================================================================
165
166    /// Get the minification filter.
167    ///
168    /// C++ equivalent: `SamplerMinMagFilter minFilter() const`
169    #[inline]
170    pub fn min_filter(&self) -> SamplerMinMagFilter {
171        unsafe { msg_send_0(self.as_ptr(), sel!(minFilter)) }
172    }
173
174    /// Set the minification filter.
175    ///
176    /// C++ equivalent: `void setMinFilter(SamplerMinMagFilter)`
177    #[inline]
178    pub fn set_min_filter(&self, filter: SamplerMinMagFilter) {
179        unsafe {
180            msg_send_1::<(), SamplerMinMagFilter>(self.as_ptr(), sel!(setMinFilter:), filter);
181        }
182    }
183
184    /// Get the magnification filter.
185    ///
186    /// C++ equivalent: `SamplerMinMagFilter magFilter() const`
187    #[inline]
188    pub fn mag_filter(&self) -> SamplerMinMagFilter {
189        unsafe { msg_send_0(self.as_ptr(), sel!(magFilter)) }
190    }
191
192    /// Set the magnification filter.
193    ///
194    /// C++ equivalent: `void setMagFilter(SamplerMinMagFilter)`
195    #[inline]
196    pub fn set_mag_filter(&self, filter: SamplerMinMagFilter) {
197        unsafe {
198            msg_send_1::<(), SamplerMinMagFilter>(self.as_ptr(), sel!(setMagFilter:), filter);
199        }
200    }
201
202    /// Get the mipmap filter.
203    ///
204    /// C++ equivalent: `SamplerMipFilter mipFilter() const`
205    #[inline]
206    pub fn mip_filter(&self) -> SamplerMipFilter {
207        unsafe { msg_send_0(self.as_ptr(), sel!(mipFilter)) }
208    }
209
210    /// Set the mipmap filter.
211    ///
212    /// C++ equivalent: `void setMipFilter(SamplerMipFilter)`
213    #[inline]
214    pub fn set_mip_filter(&self, filter: SamplerMipFilter) {
215        unsafe {
216            msg_send_1::<(), SamplerMipFilter>(self.as_ptr(), sel!(setMipFilter:), filter);
217        }
218    }
219
220    // =========================================================================
221    // Address Mode Properties
222    // =========================================================================
223
224    /// Get the S (horizontal) address mode.
225    ///
226    /// C++ equivalent: `SamplerAddressMode sAddressMode() const`
227    #[inline]
228    pub fn s_address_mode(&self) -> SamplerAddressMode {
229        unsafe { msg_send_0(self.as_ptr(), sel!(sAddressMode)) }
230    }
231
232    /// Set the S (horizontal) address mode.
233    ///
234    /// C++ equivalent: `void setSAddressMode(SamplerAddressMode)`
235    #[inline]
236    pub fn set_s_address_mode(&self, mode: SamplerAddressMode) {
237        unsafe {
238            msg_send_1::<(), SamplerAddressMode>(self.as_ptr(), sel!(setSAddressMode:), mode);
239        }
240    }
241
242    /// Get the T (vertical) address mode.
243    ///
244    /// C++ equivalent: `SamplerAddressMode tAddressMode() const`
245    #[inline]
246    pub fn t_address_mode(&self) -> SamplerAddressMode {
247        unsafe { msg_send_0(self.as_ptr(), sel!(tAddressMode)) }
248    }
249
250    /// Set the T (vertical) address mode.
251    ///
252    /// C++ equivalent: `void setTAddressMode(SamplerAddressMode)`
253    #[inline]
254    pub fn set_t_address_mode(&self, mode: SamplerAddressMode) {
255        unsafe {
256            msg_send_1::<(), SamplerAddressMode>(self.as_ptr(), sel!(setTAddressMode:), mode);
257        }
258    }
259
260    /// Get the R (depth) address mode.
261    ///
262    /// C++ equivalent: `SamplerAddressMode rAddressMode() const`
263    #[inline]
264    pub fn r_address_mode(&self) -> SamplerAddressMode {
265        unsafe { msg_send_0(self.as_ptr(), sel!(rAddressMode)) }
266    }
267
268    /// Set the R (depth) address mode.
269    ///
270    /// C++ equivalent: `void setRAddressMode(SamplerAddressMode)`
271    #[inline]
272    pub fn set_r_address_mode(&self, mode: SamplerAddressMode) {
273        unsafe {
274            msg_send_1::<(), SamplerAddressMode>(self.as_ptr(), sel!(setRAddressMode:), mode);
275        }
276    }
277
278    /// Get the border color.
279    ///
280    /// C++ equivalent: `SamplerBorderColor borderColor() const`
281    #[inline]
282    pub fn border_color(&self) -> SamplerBorderColor {
283        unsafe { msg_send_0(self.as_ptr(), sel!(borderColor)) }
284    }
285
286    /// Set the border color.
287    ///
288    /// C++ equivalent: `void setBorderColor(SamplerBorderColor)`
289    #[inline]
290    pub fn set_border_color(&self, color: SamplerBorderColor) {
291        unsafe {
292            msg_send_1::<(), SamplerBorderColor>(self.as_ptr(), sel!(setBorderColor:), color);
293        }
294    }
295
296    // =========================================================================
297    // LOD Properties
298    // =========================================================================
299
300    /// Get the minimum LOD clamp.
301    ///
302    /// C++ equivalent: `float lodMinClamp() const`
303    #[inline]
304    pub fn lod_min_clamp(&self) -> f32 {
305        unsafe { msg_send_0(self.as_ptr(), sel!(lodMinClamp)) }
306    }
307
308    /// Set the minimum LOD clamp.
309    ///
310    /// C++ equivalent: `void setLodMinClamp(float)`
311    #[inline]
312    pub fn set_lod_min_clamp(&self, clamp: f32) {
313        unsafe {
314            msg_send_1::<(), f32>(self.as_ptr(), sel!(setLodMinClamp:), clamp);
315        }
316    }
317
318    /// Get the maximum LOD clamp.
319    ///
320    /// C++ equivalent: `float lodMaxClamp() const`
321    #[inline]
322    pub fn lod_max_clamp(&self) -> f32 {
323        unsafe { msg_send_0(self.as_ptr(), sel!(lodMaxClamp)) }
324    }
325
326    /// Set the maximum LOD clamp.
327    ///
328    /// C++ equivalent: `void setLodMaxClamp(float)`
329    #[inline]
330    pub fn set_lod_max_clamp(&self, clamp: f32) {
331        unsafe {
332            msg_send_1::<(), f32>(self.as_ptr(), sel!(setLodMaxClamp:), clamp);
333        }
334    }
335
336    /// Get the LOD bias.
337    ///
338    /// C++ equivalent: `float lodBias() const`
339    #[inline]
340    pub fn lod_bias(&self) -> f32 {
341        unsafe { msg_send_0(self.as_ptr(), sel!(lodBias)) }
342    }
343
344    /// Set the LOD bias.
345    ///
346    /// C++ equivalent: `void setLodBias(float)`
347    #[inline]
348    pub fn set_lod_bias(&self, bias: f32) {
349        unsafe {
350            msg_send_1::<(), f32>(self.as_ptr(), sel!(setLodBias:), bias);
351        }
352    }
353
354    /// Get whether LOD averaging is enabled.
355    ///
356    /// C++ equivalent: `bool lodAverage() const`
357    #[inline]
358    pub fn lod_average(&self) -> bool {
359        unsafe { msg_send_0(self.as_ptr(), sel!(lodAverage)) }
360    }
361
362    /// Set whether LOD averaging is enabled.
363    ///
364    /// C++ equivalent: `void setLodAverage(bool)`
365    #[inline]
366    pub fn set_lod_average(&self, average: bool) {
367        unsafe {
368            msg_send_1::<(), bool>(self.as_ptr(), sel!(setLodAverage:), average);
369        }
370    }
371
372    // =========================================================================
373    // Anisotropy
374    // =========================================================================
375
376    /// Get the maximum anisotropy.
377    ///
378    /// C++ equivalent: `NS::UInteger maxAnisotropy() const`
379    #[inline]
380    pub fn max_anisotropy(&self) -> UInteger {
381        unsafe { msg_send_0(self.as_ptr(), sel!(maxAnisotropy)) }
382    }
383
384    /// Set the maximum anisotropy.
385    ///
386    /// C++ equivalent: `void setMaxAnisotropy(NS::UInteger)`
387    #[inline]
388    pub fn set_max_anisotropy(&self, max: UInteger) {
389        unsafe {
390            msg_send_1::<(), UInteger>(self.as_ptr(), sel!(setMaxAnisotropy:), max);
391        }
392    }
393
394    // =========================================================================
395    // Compare Function
396    // =========================================================================
397
398    /// Get the compare function.
399    ///
400    /// C++ equivalent: `CompareFunction compareFunction() const`
401    #[inline]
402    pub fn compare_function(&self) -> CompareFunction {
403        unsafe { msg_send_0(self.as_ptr(), sel!(compareFunction)) }
404    }
405
406    /// Set the compare function.
407    ///
408    /// C++ equivalent: `void setCompareFunction(CompareFunction)`
409    #[inline]
410    pub fn set_compare_function(&self, func: CompareFunction) {
411        unsafe {
412            msg_send_1::<(), CompareFunction>(self.as_ptr(), sel!(setCompareFunction:), func);
413        }
414    }
415
416    // =========================================================================
417    // Reduction Mode
418    // =========================================================================
419
420    /// Get the reduction mode.
421    ///
422    /// C++ equivalent: `SamplerReductionMode reductionMode() const`
423    #[inline]
424    pub fn reduction_mode(&self) -> SamplerReductionMode {
425        unsafe { msg_send_0(self.as_ptr(), sel!(reductionMode)) }
426    }
427
428    /// Set the reduction mode.
429    ///
430    /// C++ equivalent: `void setReductionMode(SamplerReductionMode)`
431    #[inline]
432    pub fn set_reduction_mode(&self, mode: SamplerReductionMode) {
433        unsafe {
434            msg_send_1::<(), SamplerReductionMode>(self.as_ptr(), sel!(setReductionMode:), mode);
435        }
436    }
437
438    // =========================================================================
439    // Normalized Coordinates
440    // =========================================================================
441
442    /// Check if normalized coordinates are used.
443    ///
444    /// C++ equivalent: `bool normalizedCoordinates() const`
445    #[inline]
446    pub fn normalized_coordinates(&self) -> bool {
447        unsafe { msg_send_0(self.as_ptr(), sel!(normalizedCoordinates)) }
448    }
449
450    /// Set whether normalized coordinates are used.
451    ///
452    /// C++ equivalent: `void setNormalizedCoordinates(bool)`
453    #[inline]
454    pub fn set_normalized_coordinates(&self, normalized: bool) {
455        unsafe {
456            msg_send_1::<(), bool>(self.as_ptr(), sel!(setNormalizedCoordinates:), normalized);
457        }
458    }
459
460    // =========================================================================
461    // Label
462    // =========================================================================
463
464    /// Get the label.
465    ///
466    /// C++ equivalent: `NS::String* label() const`
467    pub fn label(&self) -> Option<String> {
468        unsafe {
469            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(label));
470            if ptr.is_null() {
471                return None;
472            }
473            let utf8_ptr: *const std::ffi::c_char =
474                mtl_sys::msg_send_0(ptr as *const c_void, sel!(UTF8String));
475            if utf8_ptr.is_null() {
476                return None;
477            }
478            let c_str = std::ffi::CStr::from_ptr(utf8_ptr);
479            Some(c_str.to_string_lossy().into_owned())
480        }
481    }
482
483    /// Set the label.
484    ///
485    /// C++ equivalent: `void setLabel(const NS::String*)`
486    pub fn set_label(&self, label: &str) {
487        if let Some(ns_label) = mtl_foundation::String::from_str(label) {
488            unsafe {
489                msg_send_1::<(), *const c_void>(self.as_ptr(), sel!(setLabel:), ns_label.as_ptr());
490            }
491        }
492    }
493
494    /// Check if LOD average is supported.
495    ///
496    /// C++ equivalent: `bool supportArgumentBuffers() const`
497    #[inline]
498    pub fn support_argument_buffers(&self) -> bool {
499        unsafe { msg_send_0(self.as_ptr(), sel!(supportArgumentBuffers)) }
500    }
501
502    /// Set whether argument buffers are supported.
503    ///
504    /// C++ equivalent: `void setSupportArgumentBuffers(bool)`
505    #[inline]
506    pub fn set_support_argument_buffers(&self, support: bool) {
507        unsafe {
508            msg_send_1::<(), bool>(self.as_ptr(), sel!(setSupportArgumentBuffers:), support);
509        }
510    }
511}
512
513impl Default for SamplerDescriptor {
514    fn default() -> Self {
515        Self::new().expect("failed to create sampler descriptor")
516    }
517}
518
519impl Clone for SamplerDescriptor {
520    fn clone(&self) -> Self {
521        unsafe {
522            let ptr: *mut c_void = msg_send_0(self.as_ptr(), sel!(copy));
523            Self::from_raw(ptr).expect("failed to copy sampler descriptor")
524        }
525    }
526}
527
528impl Drop for SamplerDescriptor {
529    fn drop(&mut self) {
530        unsafe {
531            msg_send_0::<()>(self.as_ptr(), sel!(release));
532        }
533    }
534}
535
536impl Referencing for SamplerDescriptor {
537    #[inline]
538    fn as_ptr(&self) -> *const c_void {
539        self.0.as_ptr()
540    }
541}
542
543unsafe impl Send for SamplerDescriptor {}
544unsafe impl Sync for SamplerDescriptor {}
545
546impl std::fmt::Debug for SamplerDescriptor {
547    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548        f.debug_struct("SamplerDescriptor")
549            .field("min_filter", &self.min_filter())
550            .field("mag_filter", &self.mag_filter())
551            .field("mip_filter", &self.mip_filter())
552            .field("label", &self.label())
553            .finish()
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// `SamplerState` is `#[repr(transparent)]` over a single pointer.
    #[test]
    fn test_sampler_state_size() {
        assert_eq!(
            std::mem::size_of::<SamplerState>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    /// `SamplerDescriptor` is also `#[repr(transparent)]` over a single pointer.
    #[test]
    fn test_sampler_descriptor_size() {
        assert_eq!(
            std::mem::size_of::<SamplerDescriptor>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    /// The `NonNull` niche keeps the `Option` returned by `from_raw` pointer-sized.
    #[test]
    fn test_option_is_pointer_sized() {
        assert_eq!(
            std::mem::size_of::<Option<SamplerState>>(),
            std::mem::size_of::<*mut c_void>()
        );
        assert_eq!(
            std::mem::size_of::<Option<SamplerDescriptor>>(),
            std::mem::size_of::<*mut c_void>()
        );
    }

    /// A null pointer is rejected without ever being messaged or released.
    #[test]
    fn test_from_raw_null_is_none() {
        assert!(unsafe { SamplerState::from_raw(std::ptr::null_mut()) }.is_none());
        assert!(unsafe { SamplerDescriptor::from_raw(std::ptr::null_mut()) }.is_none());
    }

    #[test]
    fn test_sampler_descriptor_creation() {
        let desc = SamplerDescriptor::new();
        assert!(desc.is_some());
    }

    /// Setters and getters talk to the same Objective-C properties.
    #[test]
    fn test_sampler_descriptor_round_trip() {
        let desc = SamplerDescriptor::new().expect("failed to create sampler descriptor");

        desc.set_max_anisotropy(4);
        assert_eq!(desc.max_anisotropy(), 4);

        desc.set_label("shadow sampler");
        assert_eq!(desc.label().as_deref(), Some("shadow sampler"));
    }
}