1#![deny(unsafe_code)]
7#![warn(
8 clippy::all,
9 clippy::await_holding_lock,
10 clippy::char_lit_as_u8,
11 clippy::checked_conversions,
12 clippy::dbg_macro,
13 clippy::debug_assert_with_mut_call,
14 clippy::doc_markdown,
15 clippy::empty_enum,
16 clippy::enum_glob_use,
17 clippy::exit,
18 clippy::expl_impl_clone_on_copy,
19 clippy::explicit_deref_methods,
20 clippy::explicit_into_iter_loop,
21 clippy::fallible_impl_from,
22 clippy::filter_map_next,
23 clippy::float_cmp_const,
24 clippy::fn_params_excessive_bools,
25 clippy::if_let_mutex,
26 clippy::implicit_clone,
27 clippy::imprecise_flops,
28 clippy::inefficient_to_string,
29 clippy::invalid_upcast_comparisons,
30 clippy::large_types_passed_by_value,
31 clippy::let_unit_value,
32 clippy::linkedlist,
33 clippy::lossy_float_literal,
34 clippy::macro_use_imports,
35 clippy::manual_ok_or,
36 clippy::map_err_ignore,
37 clippy::map_flatten,
38 clippy::map_unwrap_or,
39 clippy::match_same_arms,
40 clippy::match_wildcard_for_single_variants,
41 clippy::mem_forget,
42 clippy::mut_mut,
43 clippy::mutex_integer,
44 clippy::needless_borrow,
45 clippy::needless_continue,
46 clippy::option_option,
47 clippy::path_buf_push_overwrite,
48 clippy::ptr_as_ptr,
49 clippy::ref_option_ref,
50 clippy::rest_pat_in_fully_bound_structs,
51 clippy::same_functions_in_if_condition,
52 clippy::semicolon_if_nothing_returned,
53 clippy::string_add_assign,
54 clippy::string_add,
55 clippy::string_lit_as_bytes,
56 clippy::string_to_string,
57 clippy::todo,
58 clippy::trait_duplication_in_bounds,
59 clippy::unimplemented,
60 clippy::unnested_or_patterns,
61 clippy::unused_self,
62 clippy::useless_transmute,
63 clippy::verbose_file_reads,
64 clippy::zero_sized_map_values,
65 future_incompatible,
66 nonstandard_style,
67 rust_2018_idioms
68)]
69#![doc = include_str!("../README.md")]
73
74mod image;
75
76use proc_macro::TokenStream;
77use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
78
79use syn::{ImplItemFn, visit_mut::VisitMut};
80
81use quote::{ToTokens, TokenStreamExt, format_ident, quote};
82use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
83use std::fmt::Write;
84
85#[proc_macro]
135#[allow(nonstandard_style)]
138pub fn Image(item: TokenStream) -> TokenStream {
139 let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
140
141 output.into()
142}
143
144#[proc_macro_attribute]
147pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
148 let spirv = format_ident!("{}", &spirv_attr_with_version());
149
150 let attr: proc_macro2::TokenStream = attr.into();
152 let mut tokens = quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] };
153
154 let item: proc_macro2::TokenStream = item.into();
155 for tt in item {
156 match tt {
157 TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
158 let mut group_tokens = proc_macro2::TokenStream::new();
159 let mut last_token_hashtag = false;
160 for tt in group.stream() {
161 let is_token_hashtag =
162 matches!(&tt, TokenTree::Punct(punct) if punct.as_char() == '#');
163 match tt {
164 TokenTree::Group(group)
165 if group.delimiter() == Delimiter::Bracket
166 && last_token_hashtag
167 && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
168 {
169 let inner = group
172 .stream()
173 .into_iter()
174 .skip(1)
175 .collect::<proc_macro2::TokenStream>();
176 group_tokens.extend(
177 quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] },
178 );
179 }
180 _ => group_tokens.append(tt),
181 }
182 last_token_hashtag = is_token_hashtag;
183 }
184 let mut out = Group::new(Delimiter::Parenthesis, group_tokens);
185 out.set_span(group.span());
186 tokens.append(out);
187 }
188 _ => tokens.append(tt),
189 }
190 }
191 tokens.into()
192}
193
194#[proc_macro_attribute]
200pub fn spirv_recursive_for_testing(attr: TokenStream, item: TokenStream) -> TokenStream {
201 fn recurse(spirv: &Ident, stream: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
202 let mut last_token_hashtag = false;
203 stream.into_iter().map(|tt| {
204 let mut is_token_hashtag = false;
205 let out = match tt {
206 TokenTree::Group(group)
207 if group.delimiter() == Delimiter::Bracket
208 && last_token_hashtag
209 && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
210 {
211 let inner = group
214 .stream()
215 .into_iter()
216 .skip(1)
217 .collect::<proc_macro2::TokenStream>();
218 quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] }
219 },
220 TokenTree::Group(group) => {
221 let mut out = Group::new(group.delimiter(), recurse(spirv, group.stream()));
222 out.set_span(group.span());
223 TokenTree::Group(out).into()
224 },
225 TokenTree::Punct(punct) => {
226 is_token_hashtag = punct.as_char() == '#';
227 TokenTree::Punct(punct).into()
228 }
229 tt => tt.into(),
230 };
231 last_token_hashtag = is_token_hashtag;
232 out
233 }).collect()
234 }
235
236 let attr: proc_macro2::TokenStream = attr.into();
237 let item: proc_macro2::TokenStream = item.into();
238
239 let spirv = format_ident!("{}", &spirv_attr_with_version());
241 let inner = recurse(&spirv, item);
242 quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] #inner }.into()
243}
244
245#[proc_macro_attribute]
248pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
249 let syn::ItemFn {
250 attrs,
251 vis,
252 sig,
253 block,
254 } = syn::parse_macro_input!(item as syn::ItemFn);
255
256 let fn_name = sig.ident.clone();
257
258 let sig_cpu = syn::Signature {
259 abi: None,
260 ..sig.clone()
261 };
262
263 let output = quote::quote! {
264 #[cfg(not(target_arch="spirv"))]
266 #[allow(unused_variables)]
267 #(#attrs)* #vis #sig_cpu {
268 unimplemented!(concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms."))
269 }
270
271 #[cfg(target_arch="spirv")]
272 #(#attrs)* #vis #sig {
273 #block
274 }
275 };
276
277 output.into()
278}
279
280#[proc_macro]
291pub fn debug_printf(input: TokenStream) -> TokenStream {
292 debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
293}
294
295#[proc_macro]
297pub fn debug_printfln(input: TokenStream) -> TokenStream {
298 let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
299 input.format_string.push('\n');
300 debug_printf_inner(input)
301}
302
303struct DebugPrintfInput {
304 span: Span,
305 format_string: String,
306 variables: Vec<syn::Expr>,
307}
308
309impl syn::parse::Parse for DebugPrintfInput {
310 fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
311 let span = input.span();
312
313 if input.is_empty() {
314 return Ok(Self {
315 span,
316 format_string: Default::default(),
317 variables: Default::default(),
318 });
319 }
320
321 let format_string = input.parse::<syn::LitStr>()?;
322 if !input.is_empty() {
323 input.parse::<syn::token::Comma>()?;
324 }
325 let variables =
326 syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
327
328 Ok(Self {
329 span,
330 format_string: format_string.value(),
331 variables: variables.into_iter().collect(),
332 })
333 }
334}
335
336fn parsing_error(message: &str, span: Span) -> TokenStream {
337 syn::Error::new(span, message).to_compile_error().into()
338}
339
340enum FormatType {
341 Scalar {
342 ty: proc_macro2::TokenStream,
343 },
344 Vector {
345 ty: proc_macro2::TokenStream,
346 width: usize,
347 },
348}
349
350fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
351 let DebugPrintfInput {
352 format_string,
353 variables,
354 span,
355 } = input;
356
357 fn map_specifier_to_type(
358 specifier: char,
359 chars: &mut std::str::Chars<'_>,
360 ) -> Option<proc_macro2::TokenStream> {
361 let mut peekable = chars.peekable();
362
363 Some(match specifier {
364 'd' | 'i' => quote::quote! { i32 },
365 'o' | 'x' | 'X' => quote::quote! { u32 },
366 'a' | 'A' | 'e' | 'E' | 'f' | 'F' | 'g' | 'G' => quote::quote! { f32 },
367 'u' => {
368 if matches!(peekable.peek(), Some('l')) {
369 chars.next();
370 quote::quote! { u64 }
371 } else {
372 quote::quote! { u32 }
373 }
374 }
375 'l' => {
376 if matches!(peekable.peek(), Some('u' | 'x')) {
377 chars.next();
378 quote::quote! { u64 }
379 } else {
380 return None;
381 }
382 }
383 _ => return None,
384 })
385 }
386
387 let mut chars = format_string.chars();
388 let mut format_arguments = Vec::new();
389
390 while let Some(mut ch) = chars.next() {
391 if ch == '%' {
392 ch = match chars.next() {
393 Some('%') => continue,
394 None => return parsing_error("Unterminated format specifier", span),
395 Some(ch) => ch,
396 };
397
398 let mut has_precision = false;
399
400 while ch.is_ascii_digit() {
401 ch = match chars.next() {
402 Some(ch) => ch,
403 None => {
404 return parsing_error(
405 "Unterminated format specifier: missing type after precision",
406 span,
407 );
408 }
409 };
410
411 has_precision = true;
412 }
413
414 if has_precision && ch == '.' {
415 ch = match chars.next() {
416 Some(ch) => ch,
417 None => {
418 return parsing_error(
419 "Unterminated format specifier: missing type after decimal point",
420 span,
421 );
422 }
423 };
424
425 while ch.is_ascii_digit() {
426 ch = match chars.next() {
427 Some(ch) => ch,
428 None => {
429 return parsing_error(
430 "Unterminated format specifier: missing type after fraction precision",
431 span,
432 );
433 }
434 };
435 }
436 }
437
438 if ch == 'v' {
439 let width = match chars.next() {
440 Some('2') => 2,
441 Some('3') => 3,
442 Some('4') => 4,
443 Some(ch) => {
444 return parsing_error(&format!("Invalid width for vector: {ch}"), span);
445 }
446 None => return parsing_error("Missing vector dimensions specifier", span),
447 };
448
449 ch = match chars.next() {
450 Some(ch) => ch,
451 None => return parsing_error("Missing vector type specifier", span),
452 };
453
454 let ty = match map_specifier_to_type(ch, &mut chars) {
455 Some(ty) => ty,
456 _ => {
457 return parsing_error(
458 &format!("Unrecognised vector type specifier: '{ch}'"),
459 span,
460 );
461 }
462 };
463
464 format_arguments.push(FormatType::Vector { ty, width });
465 } else {
466 let ty = match map_specifier_to_type(ch, &mut chars) {
467 Some(ty) => ty,
468 _ => {
469 return parsing_error(
470 &format!("Unrecognised format specifier: '{ch}'"),
471 span,
472 );
473 }
474 };
475
476 format_arguments.push(FormatType::Scalar { ty });
477 }
478 }
479 }
480
481 if format_arguments.len() != variables.len() {
482 return syn::Error::new(
483 span,
484 format!(
485 "{} % arguments were found, but {} variables were given",
486 format_arguments.len(),
487 variables.len()
488 ),
489 )
490 .to_compile_error()
491 .into();
492 }
493
494 let mut variable_idents = String::new();
495 let mut input_registers = Vec::new();
496 let mut op_loads = Vec::new();
497
498 for (i, (variable, format_argument)) in variables.into_iter().zip(format_arguments).enumerate()
499 {
500 let ident = quote::format_ident!("_{}", i);
501
502 let _ = write!(variable_idents, "%{ident} ");
503
504 let assert_fn = match format_argument {
505 FormatType::Scalar { ty } => {
506 quote::quote! { spirv_std::debug_printf_assert_is_type::<#ty> }
507 }
508 FormatType::Vector { ty, width } => {
509 quote::quote! { spirv_std::debug_printf_assert_is_vector::<#ty, _, #width> }
510 }
511 };
512
513 input_registers.push(quote::quote! {
514 #ident = in(reg) &#assert_fn(#variable),
515 });
516
517 let op_load = format!("%{ident} = OpLoad _ {{{ident}}}");
518
519 op_loads.push(quote::quote! {
520 #op_load,
521 });
522 }
523
524 let input_registers = input_registers
525 .into_iter()
526 .collect::<proc_macro2::TokenStream>();
527 let op_loads = op_loads.into_iter().collect::<proc_macro2::TokenStream>();
528 let format_string = format_string.replace('{', "{{").replace('}', "}}");
532
533 let op_string = format!("%string = OpString {format_string:?}");
534
535 let output = quote::quote! {
536 ::core::arch::asm!(
537 "%void = OpTypeVoid",
538 #op_string,
539 "%debug_printf = OpExtInstImport \"NonSemantic.DebugPrintf\"",
540 #op_loads
541 concat!("%result = OpExtInst %void %debug_printf 1 %string ", #variable_idents),
542 #input_registers
543 )
544 };
545
546 output.into()
547}
548
/// Number of optional sample parameters; masks range over `0..1 << SAMPLE_PARAM_COUNT`.
const SAMPLE_PARAM_COUNT: usize = 4;

/// Generic type-parameter name added to the `impl` for each enabled parameter,
/// indexed by bit position.
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];

/// Type wrapped in `SomeTy<...>` for each enabled parameter; the gradient is a pair.
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];

/// SPIR-V image-operand keyword emitted for each parameter.
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];

/// Field name of each parameter, also used to name its asm operand.
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];

/// Bit index of the gradient parameter, which expands to two operands.
const SAMPLE_PARAM_GRAD_INDEX: usize = 2;

/// Parameters whose presence selects `ExplicitLod`: `lod` (bit 1) and `grad` (bit 2).
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = (1 << 1) | (1 << SAMPLE_PARAM_GRAD_INDEX);

/// Whether bit index `i` is the gradient parameter.
fn is_grad(i: usize) -> bool {
    i == SAMPLE_PARAM_GRAD_INDEX
}
560
561struct SampleImplRewriter(usize, syn::Type);
562
563impl SampleImplRewriter {
564 pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
565 let mut new_impl = f.clone();
566 let mut ty_str = String::from("SampleParams<");
567
568 for i in 0..SAMPLE_PARAM_COUNT {
571 if mask & (1 << i) != 0 {
572 new_impl.generics.params.push(syn::GenericParam::Type(
573 syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
574 ));
575 ty_str.push_str("SomeTy<");
576 ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
577 ty_str.push('>');
578 } else {
579 ty_str.push_str("NoneTy");
580 }
581 ty_str.push(',');
582 }
583 ty_str.push('>');
584 let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
585
586 if let Some(t) = &mut new_impl.trait_
589 && let syn::PathArguments::AngleBracketed(a) =
590 &mut t.1.segments.last_mut().unwrap().arguments
591 && let Some(syn::GenericArgument::Type(t)) = a.args.last_mut()
592 {
593 *t = ty.clone();
594 }
595
596 SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
598 new_impl
599 }
600
601 #[allow(clippy::needless_range_loop)]
603 fn get_operands(&self) -> String {
604 let mut op = String::new();
605 for i in 0..SAMPLE_PARAM_COUNT {
606 if self.0 & (1 << i) != 0 {
607 if is_grad(i) {
608 op.push_str("Grad %grad_x %grad_y ");
609 } else {
610 op.push_str(SAMPLE_PARAM_OPERANDS[i]);
611 op.push_str(" %");
612 op.push_str(SAMPLE_PARAM_NAMES[i]);
613 op.push(' ');
614 }
615 }
616 }
617 op
618 }
619
620 #[allow(clippy::needless_range_loop)]
622 fn add_loads(&self, t: &mut Vec<TokenTree>) {
623 for i in 0..SAMPLE_PARAM_COUNT {
624 if self.0 & (1 << i) != 0 {
625 if is_grad(i) {
626 t.push(TokenTree::Literal(proc_macro2::Literal::string(
627 "%grad_x = OpLoad _ {grad_x}",
628 )));
629 t.push(TokenTree::Punct(proc_macro2::Punct::new(
630 ',',
631 proc_macro2::Spacing::Alone,
632 )));
633 t.push(TokenTree::Literal(proc_macro2::Literal::string(
634 "%grad_y = OpLoad _ {grad_y}",
635 )));
636 t.push(TokenTree::Punct(proc_macro2::Punct::new(
637 ',',
638 proc_macro2::Spacing::Alone,
639 )));
640 } else {
641 let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
642 t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
643 t.push(TokenTree::Punct(proc_macro2::Punct::new(
644 ',',
645 proc_macro2::Spacing::Alone,
646 )));
647 }
648 }
649 }
650 }
651
652 #[allow(clippy::needless_range_loop)]
654 fn add_regs(&self, t: &mut Vec<TokenTree>) {
655 for i in 0..SAMPLE_PARAM_COUNT {
656 if self.0 & (1 << i) != 0 {
657 let s = if is_grad(i) {
661 "grad_x=in(reg) &{params.grad.0.0},grad_y=in(reg) &{params.grad.0.1},"
662 .to_string()
663 } else {
664 format!("{0} = in(reg) &{{params.{0}.0}},", SAMPLE_PARAM_NAMES[i])
665 };
666 let ts: proc_macro2::TokenStream = s.parse().unwrap();
667 t.extend(ts);
668 }
669 }
670 }
671}
672
673impl VisitMut for SampleImplRewriter {
674 fn visit_impl_item_fn_mut(&mut self, item: &mut ImplItemFn) {
675 if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
677 *p.ty.as_mut() = self.1.clone();
678 }
679 syn::visit_mut::visit_impl_item_fn_mut(self, item);
680 }
681
682 fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
683 if m.path.is_ident("asm") {
684 let t = m.tokens.clone();
686 let mut new_t = Vec::new();
687 let mut altered = false;
688
689 for tt in t {
690 match tt {
691 TokenTree::Literal(l) => {
692 if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
693 let s = l.value();
695 if s.contains("$PARAMS") {
696 altered = true;
697 self.add_loads(&mut new_t);
699 let s = s.replace("$PARAMS", &self.get_operands());
701 let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
702 "ExplicitLod"
703 } else {
704 "ImplicitLod "
705 };
706 let s = s.replace("$LOD", lod_type);
707
708 new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
709 s.as_str(),
710 )));
711 } else {
712 new_t.push(TokenTree::Literal(l.token()));
713 }
714 } else {
715 new_t.push(TokenTree::Literal(l));
716 }
717 }
718 _ => {
719 new_t.push(tt);
720 }
721 }
722 }
723
724 if altered {
725 self.add_regs(&mut new_t);
727 }
728
729 m.tokens = new_t.into_iter().collect();
731 }
732 }
733}
734
735#[proc_macro_attribute]
741#[doc(hidden)]
742pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
743 let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
744 let mut fns = Vec::new();
745
746 for m in 1..(1 << SAMPLE_PARAM_COUNT) {
747 fns.push(SampleImplRewriter::rewrite(m, &item_impl));
748 }
749
750 quote! { #(#fns)* }.into()
753}