spirv_std_macros/
lib.rs

1// FIXME(eddyb) update/review these lints.
2//
3// BEGIN - Embark standard lints v0.4
4// do not change or add/remove here, but one can add exceptions after this section
5// for more info see: <https://github.com/EmbarkStudios/rust-ecosystem/issues/59>
6#![deny(unsafe_code)]
7#![warn(
8    clippy::all,
9    clippy::await_holding_lock,
10    clippy::char_lit_as_u8,
11    clippy::checked_conversions,
12    clippy::dbg_macro,
13    clippy::debug_assert_with_mut_call,
14    clippy::doc_markdown,
15    clippy::empty_enum,
16    clippy::enum_glob_use,
17    clippy::exit,
18    clippy::expl_impl_clone_on_copy,
19    clippy::explicit_deref_methods,
20    clippy::explicit_into_iter_loop,
21    clippy::fallible_impl_from,
22    clippy::filter_map_next,
23    clippy::float_cmp_const,
24    clippy::fn_params_excessive_bools,
25    clippy::if_let_mutex,
26    clippy::implicit_clone,
27    clippy::imprecise_flops,
28    clippy::inefficient_to_string,
29    clippy::invalid_upcast_comparisons,
30    clippy::large_types_passed_by_value,
31    clippy::let_unit_value,
32    clippy::linkedlist,
33    clippy::lossy_float_literal,
34    clippy::macro_use_imports,
35    clippy::manual_ok_or,
36    clippy::map_err_ignore,
37    clippy::map_flatten,
38    clippy::map_unwrap_or,
39    clippy::match_same_arms,
40    clippy::match_wildcard_for_single_variants,
41    clippy::mem_forget,
42    clippy::mut_mut,
43    clippy::mutex_integer,
44    clippy::needless_borrow,
45    clippy::needless_continue,
46    clippy::option_option,
47    clippy::path_buf_push_overwrite,
48    clippy::ptr_as_ptr,
49    clippy::ref_option_ref,
50    clippy::rest_pat_in_fully_bound_structs,
51    clippy::same_functions_in_if_condition,
52    clippy::semicolon_if_nothing_returned,
53    clippy::string_add_assign,
54    clippy::string_add,
55    clippy::string_lit_as_bytes,
56    clippy::string_to_string,
57    clippy::todo,
58    clippy::trait_duplication_in_bounds,
59    clippy::unimplemented,
60    clippy::unnested_or_patterns,
61    clippy::unused_self,
62    clippy::useless_transmute,
63    clippy::verbose_file_reads,
64    clippy::zero_sized_map_values,
65    future_incompatible,
66    nonstandard_style,
67    rust_2018_idioms
68)]
69// END - Embark standard lints v0.4
70// crate-specific exceptions:
71// #![allow()]
72#![doc = include_str!("../README.md")]
73
74mod image;
75
76use proc_macro::TokenStream;
77use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
78
79use syn::{ImplItemFn, visit_mut::VisitMut};
80
81use quote::{ToTokens, TokenStreamExt, format_ident, quote};
82use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
83use std::fmt::Write;
84
85/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
86/// `spirv_std::image::Image<...>` type.
87///
88/// The grammar for the macro is as follows:
89///
90/// ```rust,ignore
91/// Image!(
92///     <dimensionality>,
93///     <type=...|format=...>,
94///     [sampled[=<true|false>],]
95///     [multisampled[=<true|false>],]
96///     [arrayed[=<true|false>],]
97///     [depth[=<true|false>],]
98/// )
99/// ```
100///
101/// `=true` can be omitted as shorthand - e.g. `sampled` is short for `sampled=true`.
102///
103/// A basic example looks like this:
104/// ```rust,ignore
105/// #[spirv(vertex)]
106/// fn main(#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled)) {}
107/// ```
108///
109/// ## Arguments
110///
111/// - `dimensionality` — Dimensionality of an image.
112///   Accepted values: `1D`, `2D`, `3D`, `rect`, `cube`, `subpass`.
/// - `type` — The sampled type of an image, mutually exclusive with `format`.
///   When `type` is set, the image format is unknown.
115///   Accepted values: `f32`, `f64`, `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`.
116/// - `format` — The image format of the image, mutually exclusive with `type`.
117///   Accepted values: Snake case versions of [`ImageFormat`] variants, e.g. `rgba32f`,
118///   `rgba8_snorm`.
119/// - `sampled` — Whether it is known that the image will be used with a sampler.
120///   Accepted values: `true` or `false`. Default: `unknown`.
121/// - `multisampled` — Whether the image contains multisampled content.
122///   Accepted values: `true` or `false`. Default: `false`.
123/// - `arrayed` — Whether the image contains arrayed content.
124///   Accepted values: `true` or `false`. Default: `false`.
125/// - `depth` — Whether it is known that the image is a depth image.
126///   Accepted values: `true` or `false`. Default: `unknown`.
127///
128/// [`ImageFormat`]: spirv_std_types::image_params::ImageFormat
129///
/// Keep in mind that `sampled` here is a different concept than the `SampledImage` type:
/// `sampled=true` means that this image requires a sampler in order to be accessed, while the
/// `SampledImage` type bundles that sampler together with the image into a single type (e.g.
/// `sampler2D` in GLSL, vs. `texture2D`).
134#[proc_macro]
135// The `Image` is supposed to be used in the type position, which
136// uses `PascalCase`.
137#[allow(nonstandard_style)]
138pub fn Image(item: TokenStream) -> TokenStream {
139    let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
140
141    output.into()
142}
143
144/// Replaces all (nested) occurrences of the `#[spirv(..)]` attribute with
145/// `#[cfg_attr(target_arch="spirv", rust_gpu::spirv(..))]`.
146#[proc_macro_attribute]
147pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
148    let spirv = format_ident!("{}", &spirv_attr_with_version());
149
150    // prepend with #[rust_gpu::spirv(..)]
151    let attr: proc_macro2::TokenStream = attr.into();
152    let mut tokens = quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] };
153
154    let item: proc_macro2::TokenStream = item.into();
155    for tt in item {
156        match tt {
157            TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
158                let mut group_tokens = proc_macro2::TokenStream::new();
159                let mut last_token_hashtag = false;
160                for tt in group.stream() {
161                    let is_token_hashtag =
162                        matches!(&tt, TokenTree::Punct(punct) if punct.as_char() == '#');
163                    match tt {
164                        TokenTree::Group(group)
165                            if group.delimiter() == Delimiter::Bracket
166                                && last_token_hashtag
167                                && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
168                        {
169                            // group matches [spirv ...]
170                            // group stream doesn't include the brackets
171                            let inner = group
172                                .stream()
173                                .into_iter()
174                                .skip(1)
175                                .collect::<proc_macro2::TokenStream>();
176                            group_tokens.extend(
177                                quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] },
178                            );
179                        }
180                        _ => group_tokens.append(tt),
181                    }
182                    last_token_hashtag = is_token_hashtag;
183                }
184                let mut out = Group::new(Delimiter::Parenthesis, group_tokens);
185                out.set_span(group.span());
186                tokens.append(out);
187            }
188            _ => tokens.append(tt),
189        }
190    }
191    tokens.into()
192}
193
194/// For testing only! Is not reexported in `spirv-std`, but reachable via
195/// `spirv_std::macros::spirv_recursive_for_testing`.
196///
197/// May be more expensive than plain `spirv`, since we're checking a lot more symbols. So I've opted to
198/// have this be a separate macro, instead of modifying the standard `spirv` one.
199#[proc_macro_attribute]
200pub fn spirv_recursive_for_testing(attr: TokenStream, item: TokenStream) -> TokenStream {
201    fn recurse(spirv: &Ident, stream: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
202        let mut last_token_hashtag = false;
203        stream.into_iter().map(|tt| {
204            let mut is_token_hashtag = false;
205            let out = match tt {
206                TokenTree::Group(group)
207                if group.delimiter() == Delimiter::Bracket
208                    && last_token_hashtag
209                    && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
210                    {
211                        // group matches [spirv ...]
212                        // group stream doesn't include the brackets
213                        let inner = group
214                            .stream()
215                            .into_iter()
216                            .skip(1)
217                            .collect::<proc_macro2::TokenStream>();
218                        quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] }
219                    },
220                TokenTree::Group(group) => {
221                    let mut out = Group::new(group.delimiter(), recurse(spirv, group.stream()));
222                    out.set_span(group.span());
223                    TokenTree::Group(out).into()
224                },
225                TokenTree::Punct(punct) => {
226                    is_token_hashtag = punct.as_char() == '#';
227                    TokenTree::Punct(punct).into()
228                }
229                tt => tt.into(),
230            };
231            last_token_hashtag = is_token_hashtag;
232            out
233        }).collect()
234    }
235
236    let attr: proc_macro2::TokenStream = attr.into();
237    let item: proc_macro2::TokenStream = item.into();
238
239    // prepend with #[rust_gpu::spirv(..)]
240    let spirv = format_ident!("{}", &spirv_attr_with_version());
241    let inner = recurse(&spirv, item);
242    quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] #inner }.into()
243}
244
245/// Marks a function as runnable only on the GPU, and will panic on
246/// CPU platforms.
247#[proc_macro_attribute]
248pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
249    let syn::ItemFn {
250        attrs,
251        vis,
252        sig,
253        block,
254    } = syn::parse_macro_input!(item as syn::ItemFn);
255
256    let fn_name = sig.ident.clone();
257
258    let sig_cpu = syn::Signature {
259        abi: None,
260        ..sig.clone()
261    };
262
263    let output = quote::quote! {
264        // Don't warn on unused arguments on the CPU side.
265        #[cfg(not(target_arch="spirv"))]
266        #[allow(unused_variables)]
267        #(#attrs)* #vis #sig_cpu {
268            unimplemented!(concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms."))
269        }
270
271        #[cfg(target_arch="spirv")]
272        #(#attrs)* #vis #sig {
273            #block
274        }
275    };
276
277    output.into()
278}
279
280/// Print a formatted string using the debug printf extension.
281///
282/// Examples:
283///
284/// ```rust,ignore
285/// debug_printf!("uv: %v2f\n", uv);
286/// debug_printf!("pos.x: %f, pos.z: %f, int: %i\n", pos.x, pos.z, int);
287/// ```
288///
289/// See <https://github.com/KhronosGroup/Vulkan-ValidationLayers/blob/main/docs/debug_printf.md#debug-printf-format-string> for formatting rules.
290#[proc_macro]
291pub fn debug_printf(input: TokenStream) -> TokenStream {
292    debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
293}
294
295/// Similar to `debug_printf` but appends a newline to the format string.
296#[proc_macro]
297pub fn debug_printfln(input: TokenStream) -> TokenStream {
298    let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
299    input.format_string.push('\n');
300    debug_printf_inner(input)
301}
302
/// Parsed arguments of a `debug_printf!` / `debug_printfln!` invocation.
struct DebugPrintfInput {
    /// Span of the macro input, taken before anything is parsed; used for all
    /// error reporting.
    span: Span,
    /// Value of the leading string literal (empty when the macro is invoked
    /// without arguments).
    format_string: String,
    /// Comma-separated expressions that follow the format string.
    variables: Vec<syn::Expr>,
}
308
309impl syn::parse::Parse for DebugPrintfInput {
310    fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
311        let span = input.span();
312
313        if input.is_empty() {
314            return Ok(Self {
315                span,
316                format_string: Default::default(),
317                variables: Default::default(),
318            });
319        }
320
321        let format_string = input.parse::<syn::LitStr>()?;
322        if !input.is_empty() {
323            input.parse::<syn::token::Comma>()?;
324        }
325        let variables =
326            syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
327
328        Ok(Self {
329            span,
330            format_string: format_string.value(),
331            variables: variables.into_iter().collect(),
332        })
333    }
334}
335
336fn parsing_error(message: &str, span: Span) -> TokenStream {
337    syn::Error::new(span, message).to_compile_error().into()
338}
339
/// The argument type a single `%` specifier in a debug printf format string
/// stands for.
enum FormatType {
    /// A single value of type `ty`.
    Scalar {
        /// Tokens of the Rust type the argument is checked against.
        ty: proc_macro2::TokenStream,
    },
    /// A vector (`%v<width><type>`) with `width` components of type `ty`.
    Vector {
        /// Tokens of the Rust type of each component.
        ty: proc_macro2::TokenStream,
        /// Number of components; the parser only produces 2, 3 or 4.
        width: usize,
    },
}
349
350fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
351    let DebugPrintfInput {
352        format_string,
353        variables,
354        span,
355    } = input;
356
357    fn map_specifier_to_type(
358        specifier: char,
359        chars: &mut std::str::Chars<'_>,
360    ) -> Option<proc_macro2::TokenStream> {
361        let mut peekable = chars.peekable();
362
363        Some(match specifier {
364            'd' | 'i' => quote::quote! { i32 },
365            'o' | 'x' | 'X' => quote::quote! { u32 },
366            'a' | 'A' | 'e' | 'E' | 'f' | 'F' | 'g' | 'G' => quote::quote! { f32 },
367            'u' => {
368                if matches!(peekable.peek(), Some('l')) {
369                    chars.next();
370                    quote::quote! { u64 }
371                } else {
372                    quote::quote! { u32 }
373                }
374            }
375            'l' => {
376                if matches!(peekable.peek(), Some('u' | 'x')) {
377                    chars.next();
378                    quote::quote! { u64 }
379                } else {
380                    return None;
381                }
382            }
383            _ => return None,
384        })
385    }
386
387    let mut chars = format_string.chars();
388    let mut format_arguments = Vec::new();
389
390    while let Some(mut ch) = chars.next() {
391        if ch == '%' {
392            ch = match chars.next() {
393                Some('%') => continue,
394                None => return parsing_error("Unterminated format specifier", span),
395                Some(ch) => ch,
396            };
397
398            let mut has_precision = false;
399
400            while ch.is_ascii_digit() {
401                ch = match chars.next() {
402                    Some(ch) => ch,
403                    None => {
404                        return parsing_error(
405                            "Unterminated format specifier: missing type after precision",
406                            span,
407                        );
408                    }
409                };
410
411                has_precision = true;
412            }
413
414            if has_precision && ch == '.' {
415                ch = match chars.next() {
416                    Some(ch) => ch,
417                    None => {
418                        return parsing_error(
419                            "Unterminated format specifier: missing type after decimal point",
420                            span,
421                        );
422                    }
423                };
424
425                while ch.is_ascii_digit() {
426                    ch = match chars.next() {
427                        Some(ch) => ch,
428                        None => {
429                            return parsing_error(
430                                "Unterminated format specifier: missing type after fraction precision",
431                                span,
432                            );
433                        }
434                    };
435                }
436            }
437
438            if ch == 'v' {
439                let width = match chars.next() {
440                    Some('2') => 2,
441                    Some('3') => 3,
442                    Some('4') => 4,
443                    Some(ch) => {
444                        return parsing_error(&format!("Invalid width for vector: {ch}"), span);
445                    }
446                    None => return parsing_error("Missing vector dimensions specifier", span),
447                };
448
449                ch = match chars.next() {
450                    Some(ch) => ch,
451                    None => return parsing_error("Missing vector type specifier", span),
452                };
453
454                let ty = match map_specifier_to_type(ch, &mut chars) {
455                    Some(ty) => ty,
456                    _ => {
457                        return parsing_error(
458                            &format!("Unrecognised vector type specifier: '{ch}'"),
459                            span,
460                        );
461                    }
462                };
463
464                format_arguments.push(FormatType::Vector { ty, width });
465            } else {
466                let ty = match map_specifier_to_type(ch, &mut chars) {
467                    Some(ty) => ty,
468                    _ => {
469                        return parsing_error(
470                            &format!("Unrecognised format specifier: '{ch}'"),
471                            span,
472                        );
473                    }
474                };
475
476                format_arguments.push(FormatType::Scalar { ty });
477            }
478        }
479    }
480
481    if format_arguments.len() != variables.len() {
482        return syn::Error::new(
483            span,
484            format!(
485                "{} % arguments were found, but {} variables were given",
486                format_arguments.len(),
487                variables.len()
488            ),
489        )
490        .to_compile_error()
491        .into();
492    }
493
494    let mut variable_idents = String::new();
495    let mut input_registers = Vec::new();
496    let mut op_loads = Vec::new();
497
498    for (i, (variable, format_argument)) in variables.into_iter().zip(format_arguments).enumerate()
499    {
500        let ident = quote::format_ident!("_{}", i);
501
502        let _ = write!(variable_idents, "%{ident} ");
503
504        let assert_fn = match format_argument {
505            FormatType::Scalar { ty } => {
506                quote::quote! { spirv_std::debug_printf_assert_is_type::<#ty> }
507            }
508            FormatType::Vector { ty, width } => {
509                quote::quote! { spirv_std::debug_printf_assert_is_vector::<#ty, _, #width> }
510            }
511        };
512
513        input_registers.push(quote::quote! {
514            #ident = in(reg) &#assert_fn(#variable),
515        });
516
517        let op_load = format!("%{ident} = OpLoad _ {{{ident}}}");
518
519        op_loads.push(quote::quote! {
520            #op_load,
521        });
522    }
523
524    let input_registers = input_registers
525        .into_iter()
526        .collect::<proc_macro2::TokenStream>();
527    let op_loads = op_loads.into_iter().collect::<proc_macro2::TokenStream>();
528    // Escapes the '{' and '}' characters in the format string.
529    // Since the `asm!` macro expects '{' '}' to surround its arguments, we have to use '{{' and '}}' instead.
530    // The `asm!` macro will then later turn them back into '{' and '}'.
531    let format_string = format_string.replace('{', "{{").replace('}', "}}");
532
533    let op_string = format!("%string = OpString {format_string:?}");
534
535    let output = quote::quote! {
536        ::core::arch::asm!(
537            "%void = OpTypeVoid",
538            #op_string,
539            "%debug_printf = OpExtInstImport \"NonSemantic.DebugPrintf\"",
540            #op_loads
541            concat!("%result = OpExtInst %void %debug_printf 1 %string ", #variable_idents),
542            #input_registers
543        )
544    };
545
546    output.into()
547}
548
/// Number of optional sample parameters; parameter `i` corresponds to bit `i`
/// of a permutation mask.
const SAMPLE_PARAM_COUNT: usize = 4;
/// Name of the generic type parameter added to the `impl` for each sample parameter.
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];
/// Type wrapped in `SomeTy<..>` for each sample parameter.
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];
/// Image operand name emitted into the assembly for each sample parameter.
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];
/// Field name on `params`, and assembly id, of each sample parameter.
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];
/// Grad requires some special handling because it uses 2 arguments.
const SAMPLE_PARAM_GRAD_INDEX: usize = 2;
/// Which params (lod and grad) require the use of ExplicitLod rather than ImplicitLod.
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = (1 << 1) | (1 << SAMPLE_PARAM_GRAD_INDEX);
556
/// Returns `true` when sample parameter `i` is the gradient parameter.
fn is_grad(i: usize) -> bool {
    matches!(i, SAMPLE_PARAM_GRAD_INDEX)
}
560
/// Rewrites one `impl` block for a single permutation of sample parameters.
///
/// Field `0` is the permutation mask (bit `i` set means sample parameter `i` is
/// present); field `1` is the `SampleParams<...>` type generated for that mask.
struct SampleImplRewriter(usize, syn::Type);
562
563impl SampleImplRewriter {
564    pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
565        let mut new_impl = f.clone();
566        let mut ty_str = String::from("SampleParams<");
567
568        // based on the mask, form a `SampleParams` type string and add the generic parameters to the `impl<>` generics
569        // example type string: `"SampleParams<SomeTy<B>, NoneTy, NoneTy>"`
570        for i in 0..SAMPLE_PARAM_COUNT {
571            if mask & (1 << i) != 0 {
572                new_impl.generics.params.push(syn::GenericParam::Type(
573                    syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
574                ));
575                ty_str.push_str("SomeTy<");
576                ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
577                ty_str.push('>');
578            } else {
579                ty_str.push_str("NoneTy");
580            }
581            ty_str.push(',');
582        }
583        ty_str.push('>');
584        let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
585
586        // use the type to insert it into the generic argument of the trait we're implementing
587        // e.g., `ImageWithMethods<Dummy>` becomes `ImageWithMethods<SampleParams<SomeTy<B>, NoneTy, NoneTy>>`
588        if let Some(t) = &mut new_impl.trait_
589            && let syn::PathArguments::AngleBracketed(a) =
590                &mut t.1.segments.last_mut().unwrap().arguments
591            && let Some(syn::GenericArgument::Type(t)) = a.args.last_mut()
592        {
593            *t = ty.clone();
594        }
595
596        // rewrite the implemented functions
597        SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
598        new_impl
599    }
600
601    // generates an operands string for use in the assembly, e.g. "Bias %bias Lod %lod", based on the mask
602    #[allow(clippy::needless_range_loop)]
603    fn get_operands(&self) -> String {
604        let mut op = String::new();
605        for i in 0..SAMPLE_PARAM_COUNT {
606            if self.0 & (1 << i) != 0 {
607                if is_grad(i) {
608                    op.push_str("Grad %grad_x %grad_y ");
609                } else {
610                    op.push_str(SAMPLE_PARAM_OPERANDS[i]);
611                    op.push_str(" %");
612                    op.push_str(SAMPLE_PARAM_NAMES[i]);
613                    op.push(' ');
614                }
615            }
616        }
617        op
618    }
619
620    // generates list of assembly loads for the data, e.g. "%bias = OpLoad _ {bias}", etc.
621    #[allow(clippy::needless_range_loop)]
622    fn add_loads(&self, t: &mut Vec<TokenTree>) {
623        for i in 0..SAMPLE_PARAM_COUNT {
624            if self.0 & (1 << i) != 0 {
625                if is_grad(i) {
626                    t.push(TokenTree::Literal(proc_macro2::Literal::string(
627                        "%grad_x = OpLoad _ {grad_x}",
628                    )));
629                    t.push(TokenTree::Punct(proc_macro2::Punct::new(
630                        ',',
631                        proc_macro2::Spacing::Alone,
632                    )));
633                    t.push(TokenTree::Literal(proc_macro2::Literal::string(
634                        "%grad_y = OpLoad _ {grad_y}",
635                    )));
636                    t.push(TokenTree::Punct(proc_macro2::Punct::new(
637                        ',',
638                        proc_macro2::Spacing::Alone,
639                    )));
640                } else {
641                    let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
642                    t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
643                    t.push(TokenTree::Punct(proc_macro2::Punct::new(
644                        ',',
645                        proc_macro2::Spacing::Alone,
646                    )));
647                }
648            }
649        }
650    }
651
652    // generates list of register specifications, e.g. `bias = in(reg) &params.bias.0, ...` as separate tokens
653    #[allow(clippy::needless_range_loop)]
654    fn add_regs(&self, t: &mut Vec<TokenTree>) {
655        for i in 0..SAMPLE_PARAM_COUNT {
656            if self.0 & (1 << i) != 0 {
657                // HACK(eddyb) the extra `{...}` force the pointers to be to
658                // fresh variables holding value copies, instead of the originals,
659                // allowing `OpLoad _` inference to pick the appropriate type.
660                let s = if is_grad(i) {
661                    "grad_x=in(reg) &{params.grad.0.0},grad_y=in(reg) &{params.grad.0.1},"
662                        .to_string()
663                } else {
664                    format!("{0} = in(reg) &{{params.{0}.0}},", SAMPLE_PARAM_NAMES[i])
665                };
666                let ts: proc_macro2::TokenStream = s.parse().unwrap();
667                t.extend(ts);
668            }
669        }
670    }
671}
672
673impl VisitMut for SampleImplRewriter {
674    fn visit_impl_item_fn_mut(&mut self, item: &mut ImplItemFn) {
675        // rewrite the last parameter of this method to be of type `SampleParams<...>` we generated earlier
676        if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
677            *p.ty.as_mut() = self.1.clone();
678        }
679        syn::visit_mut::visit_impl_item_fn_mut(self, item);
680    }
681
682    fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
683        if m.path.is_ident("asm") {
684            // this is where the asm! block is manipulated
685            let t = m.tokens.clone();
686            let mut new_t = Vec::new();
687            let mut altered = false;
688
689            for tt in t {
690                match tt {
691                    TokenTree::Literal(l) => {
692                        if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
693                            // found a string literal
694                            let s = l.value();
695                            if s.contains("$PARAMS") {
696                                altered = true;
697                                // add load instructions before the sampling instruction
698                                self.add_loads(&mut new_t);
699                                // and insert image operands
700                                let s = s.replace("$PARAMS", &self.get_operands());
701                                let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
702                                    "ExplicitLod"
703                                } else {
704                                    "ImplicitLod "
705                                };
706                                let s = s.replace("$LOD", lod_type);
707
708                                new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
709                                    s.as_str(),
710                                )));
711                            } else {
712                                new_t.push(TokenTree::Literal(l.token()));
713                            }
714                        } else {
715                            new_t.push(TokenTree::Literal(l));
716                        }
717                    }
718                    _ => {
719                        new_t.push(tt);
720                    }
721                }
722            }
723
724            if altered {
725                // finally, add register specs
726                self.add_regs(&mut new_t);
727            }
728
729            // replace all tokens within the asm! block with our new list
730            m.tokens = new_t.into_iter().collect();
731        }
732    }
733}
734
735/// Generates permutations of an `ImageWithMethods` implementation containing sampling functions
736/// that have asm instruction ending with a placeholder `$PARAMS` operand. The last parameter
737/// of each function must be named `params`, its type will be rewritten. Relevant generic
738/// arguments are added to the impl generics.
739/// See `SAMPLE_PARAM_GENERICS` for a list of names you cannot use as generic arguments.
740#[proc_macro_attribute]
741#[doc(hidden)]
742pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
743    let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
744    let mut fns = Vec::new();
745
746    for m in 1..(1 << SAMPLE_PARAM_COUNT) {
747        fns.push(SampleImplRewriter::rewrite(m, &item_impl));
748    }
749
750    // uncomment to output generated tokenstream to stdout
751    //println!("{}", quote! { #(#fns)* }.to_string());
752    quote! { #(#fns)* }.into()
753}