spirv_std_macros/
lib.rs

1// FIXME(eddyb) update/review these lints.
2//
3// BEGIN - Embark standard lints v0.4
4// do not change or add/remove here, but one can add exceptions after this section
5// for more info see: <https://github.com/EmbarkStudios/rust-ecosystem/issues/59>
6#![deny(unsafe_code)]
7#![warn(
8    clippy::all,
9    clippy::await_holding_lock,
10    clippy::char_lit_as_u8,
11    clippy::checked_conversions,
12    clippy::dbg_macro,
13    clippy::debug_assert_with_mut_call,
14    clippy::doc_markdown,
15    clippy::empty_enum,
16    clippy::enum_glob_use,
17    clippy::exit,
18    clippy::expl_impl_clone_on_copy,
19    clippy::explicit_deref_methods,
20    clippy::explicit_into_iter_loop,
21    clippy::fallible_impl_from,
22    clippy::filter_map_next,
23    clippy::float_cmp_const,
24    clippy::fn_params_excessive_bools,
25    clippy::if_let_mutex,
26    clippy::implicit_clone,
27    clippy::imprecise_flops,
28    clippy::inefficient_to_string,
29    clippy::invalid_upcast_comparisons,
30    clippy::large_types_passed_by_value,
31    clippy::let_unit_value,
32    clippy::linkedlist,
33    clippy::lossy_float_literal,
34    clippy::macro_use_imports,
35    clippy::manual_ok_or,
36    clippy::map_err_ignore,
37    clippy::map_flatten,
38    clippy::map_unwrap_or,
39    clippy::match_same_arms,
40    clippy::match_wildcard_for_single_variants,
41    clippy::mem_forget,
42    clippy::mut_mut,
43    clippy::mutex_integer,
44    clippy::needless_borrow,
45    clippy::needless_continue,
46    clippy::option_option,
47    clippy::path_buf_push_overwrite,
48    clippy::ptr_as_ptr,
49    clippy::ref_option_ref,
50    clippy::rest_pat_in_fully_bound_structs,
51    clippy::same_functions_in_if_condition,
52    clippy::semicolon_if_nothing_returned,
53    clippy::string_add_assign,
54    clippy::string_add,
55    clippy::string_lit_as_bytes,
56    clippy::string_to_string,
57    clippy::todo,
58    clippy::trait_duplication_in_bounds,
59    clippy::unimplemented,
60    clippy::unnested_or_patterns,
61    clippy::unused_self,
62    clippy::useless_transmute,
63    clippy::verbose_file_reads,
64    clippy::zero_sized_map_values,
65    future_incompatible,
66    nonstandard_style,
67    rust_2018_idioms
68)]
69// END - Embark standard lints v0.4
70// crate-specific exceptions:
71// #![allow()]
72#![doc = include_str!("../README.md")]
73
74mod debug_printf;
75mod image;
76mod sample_param_permutations;
77mod scalar_or_vector_composite;
78
79use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
80use proc_macro::TokenStream;
81use proc_macro2::{Delimiter, Group, Ident, TokenTree};
82use quote::{ToTokens, TokenStreamExt, format_ident, quote};
83use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
84
85/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
86/// `spirv_std::image::Image<...>` type.
87///
88/// The grammar for the macro is as follows:
89///
90/// ```rust,ignore
91/// Image!(
92///     <dimensionality>,
93///     <type=...|format=...>,
94///     [sampled[=<true|false>],]
95///     [multisampled[=<true|false>],]
96///     [arrayed[=<true|false>],]
97///     [depth[=<true|false>],]
98/// )
99/// ```
100///
101/// `=true` can be omitted as shorthand - e.g. `sampled` is short for `sampled=true`.
102///
103/// A basic example looks like this:
104/// ```rust,ignore
105/// #[spirv(vertex)]
106/// fn main(#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled)) {}
107/// ```
108///
109/// ## Arguments
110///
111/// - `dimensionality` — Dimensionality of an image.
112///   Accepted values: `1D`, `2D`, `3D`, `rect`, `cube`, `subpass`.
113/// - `type` — The sampled type of an image, mutually exclusive with `format`,
114///   when set the image format is unknown.
115///   Accepted values: `f32`, `f64`, `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`.
116/// - `format` — The image format of the image, mutually exclusive with `type`.
117///   Accepted values: Snake case versions of [`ImageFormat`] variants, e.g. `rgba32f`,
118///   `rgba8_snorm`.
119/// - `sampled` — Whether it is known that the image will be used with a sampler.
120///   Accepted values: `true` or `false`. Default: `unknown`.
121/// - `multisampled` — Whether the image contains multisampled content.
122///   Accepted values: `true` or `false`. Default: `false`.
123/// - `arrayed` — Whether the image contains arrayed content.
124///   Accepted values: `true` or `false`. Default: `false`.
125/// - `depth` — Whether it is known that the image is a depth image.
126///   Accepted values: `true` or `false`. Default: `unknown`.
127///
128/// [`ImageFormat`]: spirv_std_types::image_params::ImageFormat
129///
130/// Keep in mind that `sampled` here is a different concept than the `SampledImage` type:
131/// `sampled=true` means that this image requires a sampler to be able to access, while the
132/// `SampledImage` type bundles that sampler together with the image into a single type (e.g.
133/// `sampler2D` in GLSL, vs. `texture2D`).
134#[proc_macro]
135// The `Image` is supposed to be used in the type position, which
136// uses `PascalCase`.
137#[allow(nonstandard_style)]
138pub fn Image(item: TokenStream) -> TokenStream {
139    let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
140
141    output.into()
142}
143
144/// Replaces all (nested) occurrences of the `#[spirv(..)]` attribute with
145/// `#[cfg_attr(target_arch="spirv", rust_gpu::spirv(..))]`.
146#[proc_macro_attribute]
147pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
148    let spirv = format_ident!("{}", &spirv_attr_with_version());
149
150    // prepend with #[rust_gpu::spirv(..)]
151    let attr: proc_macro2::TokenStream = attr.into();
152    let mut tokens = quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] };
153
154    let item: proc_macro2::TokenStream = item.into();
155    for tt in item {
156        match tt {
157            TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
158                let mut group_tokens = proc_macro2::TokenStream::new();
159                let mut last_token_hashtag = false;
160                for tt in group.stream() {
161                    let is_token_hashtag =
162                        matches!(&tt, TokenTree::Punct(punct) if punct.as_char() == '#');
163                    match tt {
164                        TokenTree::Group(group)
165                            if group.delimiter() == Delimiter::Bracket
166                                && last_token_hashtag
167                                && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
168                        {
169                            // group matches [spirv ...]
170                            // group stream doesn't include the brackets
171                            let inner = group
172                                .stream()
173                                .into_iter()
174                                .skip(1)
175                                .collect::<proc_macro2::TokenStream>();
176                            group_tokens.extend(
177                                quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] },
178                            );
179                        }
180                        _ => group_tokens.append(tt),
181                    }
182                    last_token_hashtag = is_token_hashtag;
183                }
184                let mut out = Group::new(Delimiter::Parenthesis, group_tokens);
185                out.set_span(group.span());
186                tokens.append(out);
187            }
188            _ => tokens.append(tt),
189        }
190    }
191    tokens.into()
192}
193
194/// For testing only! Is not reexported in `spirv-std`, but reachable via
195/// `spirv_std::macros::spirv_recursive_for_testing`.
196///
197/// May be more expensive than plain `spirv`, since we're checking a lot more symbols. So I've opted to
198/// have this be a separate macro, instead of modifying the standard `spirv` one.
199#[proc_macro_attribute]
200pub fn spirv_recursive_for_testing(attr: TokenStream, item: TokenStream) -> TokenStream {
201    fn recurse(spirv: &Ident, stream: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
202        let mut last_token_hashtag = false;
203        stream.into_iter().map(|tt| {
204            let mut is_token_hashtag = false;
205            let out = match tt {
206                TokenTree::Group(group)
207                if group.delimiter() == Delimiter::Bracket
208                    && last_token_hashtag
209                    && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
210                    {
211                        // group matches [spirv ...]
212                        // group stream doesn't include the brackets
213                        let inner = group
214                            .stream()
215                            .into_iter()
216                            .skip(1)
217                            .collect::<proc_macro2::TokenStream>();
218                        quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] }
219                    },
220                TokenTree::Group(group) => {
221                    let mut out = Group::new(group.delimiter(), recurse(spirv, group.stream()));
222                    out.set_span(group.span());
223                    TokenTree::Group(out).into()
224                },
225                TokenTree::Punct(punct) => {
226                    is_token_hashtag = punct.as_char() == '#';
227                    TokenTree::Punct(punct).into()
228                }
229                tt => tt.into(),
230            };
231            last_token_hashtag = is_token_hashtag;
232            out
233        }).collect()
234    }
235
236    let attr: proc_macro2::TokenStream = attr.into();
237    let item: proc_macro2::TokenStream = item.into();
238
239    // prepend with #[rust_gpu::spirv(..)]
240    let spirv = format_ident!("{}", &spirv_attr_with_version());
241    let inner = recurse(&spirv, item);
242    quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] #inner }.into()
243}
244
245/// Marks a function as runnable only on the GPU, and will panic on
246/// CPU platforms.
247#[proc_macro_attribute]
248pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
249    let syn::ItemFn {
250        attrs,
251        vis,
252        sig,
253        block,
254    } = syn::parse_macro_input!(item as syn::ItemFn);
255
256    let fn_name = sig.ident.clone();
257
258    let sig_cpu = syn::Signature {
259        abi: None,
260        ..sig.clone()
261    };
262
263    let output = quote::quote! {
264        // Don't warn on unused arguments on the CPU side.
265        #[cfg(not(target_arch="spirv"))]
266        #[allow(unused_variables)]
267        #(#attrs)* #vis #sig_cpu {
268            unimplemented!(
269                concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms.")
270            )
271        }
272
273        #[cfg(target_arch="spirv")]
274        #(#attrs)* #vis #sig {
275            #block
276        }
277    };
278
279    output.into()
280}
281
282/// Print a formatted string using the debug printf extension.
283///
284/// Examples:
285///
286/// ```rust,ignore
287/// debug_printf!("uv: %v2f\n", uv);
288/// debug_printf!("pos.x: %f, pos.z: %f, int: %i\n", pos.x, pos.z, int);
289/// ```
290///
291/// See <https://github.com/KhronosGroup/Vulkan-ValidationLayers/blob/main/docs/debug_printf.md#debug-printf-format-string> for formatting rules.
292#[proc_macro]
293pub fn debug_printf(input: TokenStream) -> TokenStream {
294    debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
295}
296
297/// Similar to `debug_printf` but appends a newline to the format string.
298#[proc_macro]
299pub fn debug_printfln(input: TokenStream) -> TokenStream {
300    let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
301    input.format_string.push('\n');
302    debug_printf_inner(input)
303}
304
305/// Generates permutations of an `ImageWithMethods` implementation containing sampling functions
306/// that have asm instruction ending with a placeholder `$PARAMS` operand. The last parameter
307/// of each function must be named `params`, its type will be rewritten. Relevant generic
308/// arguments are added to the impl generics.
309/// See `SAMPLE_PARAM_GENERICS` for a list of names you cannot use as generic arguments.
310#[proc_macro_attribute]
311#[doc(hidden)]
312pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
313    sample_param_permutations::gen_sample_param_permutations(item)
314}
315
316#[proc_macro_derive(ScalarComposite)]
317pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream {
318    scalar_or_vector_composite::derive(item.into())
319        .unwrap_or_else(syn::Error::into_compile_error)
320        .into()
321}