1#![deny(unsafe_code)]
7#![warn(
8 clippy::all,
9 clippy::await_holding_lock,
10 clippy::char_lit_as_u8,
11 clippy::checked_conversions,
12 clippy::dbg_macro,
13 clippy::debug_assert_with_mut_call,
14 clippy::doc_markdown,
15 clippy::empty_enum,
16 clippy::enum_glob_use,
17 clippy::exit,
18 clippy::expl_impl_clone_on_copy,
19 clippy::explicit_deref_methods,
20 clippy::explicit_into_iter_loop,
21 clippy::fallible_impl_from,
22 clippy::filter_map_next,
23 clippy::float_cmp_const,
24 clippy::fn_params_excessive_bools,
25 clippy::if_let_mutex,
26 clippy::implicit_clone,
27 clippy::imprecise_flops,
28 clippy::inefficient_to_string,
29 clippy::invalid_upcast_comparisons,
30 clippy::large_types_passed_by_value,
31 clippy::let_unit_value,
32 clippy::linkedlist,
33 clippy::lossy_float_literal,
34 clippy::macro_use_imports,
35 clippy::manual_ok_or,
36 clippy::map_err_ignore,
37 clippy::map_flatten,
38 clippy::map_unwrap_or,
39 clippy::match_same_arms,
40 clippy::match_wildcard_for_single_variants,
41 clippy::mem_forget,
42 clippy::mut_mut,
43 clippy::mutex_integer,
44 clippy::needless_borrow,
45 clippy::needless_continue,
46 clippy::option_option,
47 clippy::path_buf_push_overwrite,
48 clippy::ptr_as_ptr,
49 clippy::ref_option_ref,
50 clippy::rest_pat_in_fully_bound_structs,
51 clippy::same_functions_in_if_condition,
52 clippy::semicolon_if_nothing_returned,
53 clippy::string_add_assign,
54 clippy::string_add,
55 clippy::string_lit_as_bytes,
56 clippy::string_to_string,
57 clippy::todo,
58 clippy::trait_duplication_in_bounds,
59 clippy::unimplemented,
60 clippy::unnested_or_patterns,
61 clippy::unused_self,
62 clippy::useless_transmute,
63 clippy::verbose_file_reads,
64 clippy::zero_sized_map_values,
65 future_incompatible,
66 nonstandard_style,
67 rust_2018_idioms
68)]
69#![doc = include_str!("../README.md")]
73
74mod image;
75
76use proc_macro::TokenStream;
77use proc_macro2::{Delimiter, Group, Span, TokenTree};
78
79use syn::{ImplItemFn, visit_mut::VisitMut};
80
81use quote::{ToTokens, quote};
82use std::fmt::Write;
83
84#[proc_macro]
133#[allow(nonstandard_style)]
136pub fn Image(item: TokenStream) -> TokenStream {
137 let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
138
139 output.into()
140}
141
142#[proc_macro_attribute]
145pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
146 let mut tokens: Vec<TokenTree> = Vec::new();
147
148 let attr: proc_macro2::TokenStream = attr.into();
150 tokens.extend(quote! { #[cfg_attr(target_arch="spirv", rust_gpu::spirv(#attr))] });
151
152 let item: proc_macro2::TokenStream = item.into();
153 for tt in item {
154 match tt {
155 TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
156 let mut sub_tokens = Vec::new();
157 for tt in group.stream() {
158 match tt {
159 TokenTree::Group(group)
160 if group.delimiter() == Delimiter::Bracket
161 && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv")
162 && matches!(sub_tokens.last(), Some(TokenTree::Punct(p)) if p.as_char() == '#') =>
163 {
164 let inner = group.stream(); sub_tokens.extend(
167 quote! { [cfg_attr(target_arch="spirv", rust_gpu::#inner)] },
168 );
169 }
170 _ => sub_tokens.push(tt),
171 }
172 }
173 tokens.push(TokenTree::from(Group::new(
174 Delimiter::Parenthesis,
175 sub_tokens.into_iter().collect(),
176 )));
177 }
178 _ => tokens.push(tt),
179 }
180 }
181 tokens
182 .into_iter()
183 .collect::<proc_macro2::TokenStream>()
184 .into()
185}
186
187#[proc_macro_attribute]
190pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
191 let syn::ItemFn {
192 attrs,
193 vis,
194 sig,
195 block,
196 } = syn::parse_macro_input!(item as syn::ItemFn);
197
198 #[allow(clippy::redundant_clone)]
200 let fn_name = sig.ident.clone();
201
202 let sig_cpu = syn::Signature {
203 abi: None,
204 ..sig.clone()
205 };
206
207 let output = quote::quote! {
208 #[cfg(not(target_arch="spirv"))]
210 #[allow(unused_variables)]
211 #(#attrs)* #vis #sig_cpu {
212 unimplemented!(concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms."))
213 }
214
215 #[cfg(target_arch="spirv")]
216 #(#attrs)* #vis #sig {
217 #block
218 }
219 };
220
221 output.into()
222}
223
224#[proc_macro]
235pub fn debug_printf(input: TokenStream) -> TokenStream {
236 debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
237}
238
239#[proc_macro]
241pub fn debug_printfln(input: TokenStream) -> TokenStream {
242 let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
243 input.format_string.push('\n');
244 debug_printf_inner(input)
245}
246
247struct DebugPrintfInput {
248 span: Span,
249 format_string: String,
250 variables: Vec<syn::Expr>,
251}
252
253impl syn::parse::Parse for DebugPrintfInput {
254 fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
255 let span = input.span();
256
257 if input.is_empty() {
258 return Ok(Self {
259 span,
260 format_string: Default::default(),
261 variables: Default::default(),
262 });
263 }
264
265 let format_string = input.parse::<syn::LitStr>()?;
266 if !input.is_empty() {
267 input.parse::<syn::token::Comma>()?;
268 }
269 let variables =
270 syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
271
272 Ok(Self {
273 span,
274 format_string: format_string.value(),
275 variables: variables.into_iter().collect(),
276 })
277 }
278}
279
280fn parsing_error(message: &str, span: Span) -> TokenStream {
281 syn::Error::new(span, message).to_compile_error().into()
282}
283
284enum FormatType {
285 Scalar {
286 ty: proc_macro2::TokenStream,
287 },
288 Vector {
289 ty: proc_macro2::TokenStream,
290 width: usize,
291 },
292}
293
294fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
295 let DebugPrintfInput {
296 format_string,
297 variables,
298 span,
299 } = input;
300
301 fn map_specifier_to_type(
302 specifier: char,
303 chars: &mut std::str::Chars<'_>,
304 ) -> Option<proc_macro2::TokenStream> {
305 let mut peekable = chars.peekable();
306
307 Some(match specifier {
308 'd' | 'i' => quote::quote! { i32 },
309 'o' | 'x' | 'X' => quote::quote! { u32 },
310 'a' | 'A' | 'e' | 'E' | 'f' | 'F' | 'g' | 'G' => quote::quote! { f32 },
311 'u' => {
312 if matches!(peekable.peek(), Some('l')) {
313 chars.next();
314 quote::quote! { u64 }
315 } else {
316 quote::quote! { u32 }
317 }
318 }
319 'l' => {
320 if matches!(peekable.peek(), Some('u' | 'x')) {
321 chars.next();
322 quote::quote! { u64 }
323 } else {
324 return None;
325 }
326 }
327 _ => return None,
328 })
329 }
330
331 let mut chars = format_string.chars();
332 let mut format_arguments = Vec::new();
333
334 while let Some(mut ch) = chars.next() {
335 if ch == '%' {
336 ch = match chars.next() {
337 Some('%') => continue,
338 None => return parsing_error("Unterminated format specifier", span),
339 Some(ch) => ch,
340 };
341
342 let mut has_precision = false;
343
344 while ch.is_ascii_digit() {
345 ch = match chars.next() {
346 Some(ch) => ch,
347 None => {
348 return parsing_error(
349 "Unterminated format specifier: missing type after precision",
350 span,
351 );
352 }
353 };
354
355 has_precision = true;
356 }
357
358 if has_precision && ch == '.' {
359 ch = match chars.next() {
360 Some(ch) => ch,
361 None => {
362 return parsing_error(
363 "Unterminated format specifier: missing type after decimal point",
364 span,
365 );
366 }
367 };
368
369 while ch.is_ascii_digit() {
370 ch = match chars.next() {
371 Some(ch) => ch,
372 None => {
373 return parsing_error(
374 "Unterminated format specifier: missing type after fraction precision",
375 span,
376 );
377 }
378 };
379 }
380 }
381
382 if ch == 'v' {
383 let width = match chars.next() {
384 Some('2') => 2,
385 Some('3') => 3,
386 Some('4') => 4,
387 Some(ch) => {
388 return parsing_error(&format!("Invalid width for vector: {ch}"), span);
389 }
390 None => return parsing_error("Missing vector dimensions specifier", span),
391 };
392
393 ch = match chars.next() {
394 Some(ch) => ch,
395 None => return parsing_error("Missing vector type specifier", span),
396 };
397
398 let ty = match map_specifier_to_type(ch, &mut chars) {
399 Some(ty) => ty,
400 _ => {
401 return parsing_error(
402 &format!("Unrecognised vector type specifier: '{ch}'"),
403 span,
404 );
405 }
406 };
407
408 format_arguments.push(FormatType::Vector { ty, width });
409 } else {
410 let ty = match map_specifier_to_type(ch, &mut chars) {
411 Some(ty) => ty,
412 _ => {
413 return parsing_error(
414 &format!("Unrecognised format specifier: '{ch}'"),
415 span,
416 );
417 }
418 };
419
420 format_arguments.push(FormatType::Scalar { ty });
421 }
422 }
423 }
424
425 if format_arguments.len() != variables.len() {
426 return syn::Error::new(
427 span,
428 format!(
429 "{} % arguments were found, but {} variables were given",
430 format_arguments.len(),
431 variables.len()
432 ),
433 )
434 .to_compile_error()
435 .into();
436 }
437
438 let mut variable_idents = String::new();
439 let mut input_registers = Vec::new();
440 let mut op_loads = Vec::new();
441
442 for (i, (variable, format_argument)) in variables.into_iter().zip(format_arguments).enumerate()
443 {
444 let ident = quote::format_ident!("_{}", i);
445
446 let _ = write!(variable_idents, "%{ident} ");
447
448 let assert_fn = match format_argument {
449 FormatType::Scalar { ty } => {
450 quote::quote! { spirv_std::debug_printf_assert_is_type::<#ty> }
451 }
452 FormatType::Vector { ty, width } => {
453 quote::quote! { spirv_std::debug_printf_assert_is_vector::<#ty, _, #width> }
454 }
455 };
456
457 input_registers.push(quote::quote! {
458 #ident = in(reg) &#assert_fn(#variable),
459 });
460
461 let op_load = format!("%{ident} = OpLoad _ {{{ident}}}");
462
463 op_loads.push(quote::quote! {
464 #op_load,
465 });
466 }
467
468 let input_registers = input_registers
469 .into_iter()
470 .collect::<proc_macro2::TokenStream>();
471 let op_loads = op_loads.into_iter().collect::<proc_macro2::TokenStream>();
472 let format_string = format_string.replace('{', "{{").replace('}', "}}");
476
477 let op_string = format!("%string = OpString {format_string:?}");
478
479 let output = quote::quote! {
480 ::core::arch::asm!(
481 "%void = OpTypeVoid",
482 #op_string,
483 "%debug_printf = OpExtInstImport \"NonSemantic.DebugPrintf\"",
484 #op_loads
485 concat!("%result = OpExtInst %void %debug_printf 1 %string ", #variable_idents),
486 #input_registers
487 )
488 };
489
490 output.into()
491}
492
/// Number of optional image-sampling operands (bias, lod, grad, sample index);
/// bit `i` of a permutation mask refers to index `i` of the tables below.
const SAMPLE_PARAM_COUNT: usize = 4;
/// Generic type parameter added to the `impl` for each enabled operand.
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];
/// Type wrapped in `SomeTy<..>` for each enabled operand (grad is a pair).
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];
/// Image-operand keyword written into the `asm!` template for each operand.
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];
/// Field name on `params` and `asm!` operand name for each operand.
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];
/// Index of the grad operand, which expands to two values (`grad_x`, `grad_y`).
const SAMPLE_PARAM_GRAD_INDEX: usize = 2;
/// Operands whose presence selects the `ExplicitLod` form: lod (bit 1) and grad.
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = (1 << 1) | (1 << SAMPLE_PARAM_GRAD_INDEX);

/// Returns `true` when operand index `i` is the grad operand.
fn is_grad(i: usize) -> bool {
    SAMPLE_PARAM_GRAD_INDEX == i
}
504
505struct SampleImplRewriter(usize, syn::Type);
506
507impl SampleImplRewriter {
508 pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
509 let mut new_impl = f.clone();
510 let mut ty_str = String::from("SampleParams<");
511
512 for i in 0..SAMPLE_PARAM_COUNT {
515 if mask & (1 << i) != 0 {
516 new_impl.generics.params.push(syn::GenericParam::Type(
517 syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
518 ));
519 ty_str.push_str("SomeTy<");
520 ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
521 ty_str.push('>');
522 } else {
523 ty_str.push_str("NoneTy");
524 }
525 ty_str.push(',');
526 }
527 ty_str.push('>');
528 let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
529
530 if let Some(t) = &mut new_impl.trait_
533 && let syn::PathArguments::AngleBracketed(a) =
534 &mut t.1.segments.last_mut().unwrap().arguments
535 && let Some(syn::GenericArgument::Type(t)) = a.args.last_mut()
536 {
537 *t = ty.clone();
538 }
539
540 SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
542 new_impl
543 }
544
545 #[allow(clippy::needless_range_loop)]
547 fn get_operands(&self) -> String {
548 let mut op = String::new();
549 for i in 0..SAMPLE_PARAM_COUNT {
550 if self.0 & (1 << i) != 0 {
551 if is_grad(i) {
552 op.push_str("Grad %grad_x %grad_y ");
553 } else {
554 op.push_str(SAMPLE_PARAM_OPERANDS[i]);
555 op.push_str(" %");
556 op.push_str(SAMPLE_PARAM_NAMES[i]);
557 op.push(' ');
558 }
559 }
560 }
561 op
562 }
563
564 #[allow(clippy::needless_range_loop)]
566 fn add_loads(&self, t: &mut Vec<TokenTree>) {
567 for i in 0..SAMPLE_PARAM_COUNT {
568 if self.0 & (1 << i) != 0 {
569 if is_grad(i) {
570 t.push(TokenTree::Literal(proc_macro2::Literal::string(
571 "%grad_x = OpLoad _ {grad_x}",
572 )));
573 t.push(TokenTree::Punct(proc_macro2::Punct::new(
574 ',',
575 proc_macro2::Spacing::Alone,
576 )));
577 t.push(TokenTree::Literal(proc_macro2::Literal::string(
578 "%grad_y = OpLoad _ {grad_y}",
579 )));
580 t.push(TokenTree::Punct(proc_macro2::Punct::new(
581 ',',
582 proc_macro2::Spacing::Alone,
583 )));
584 } else {
585 let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
586 t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
587 t.push(TokenTree::Punct(proc_macro2::Punct::new(
588 ',',
589 proc_macro2::Spacing::Alone,
590 )));
591 }
592 }
593 }
594 }
595
596 #[allow(clippy::needless_range_loop)]
598 fn add_regs(&self, t: &mut Vec<TokenTree>) {
599 for i in 0..SAMPLE_PARAM_COUNT {
600 if self.0 & (1 << i) != 0 {
601 let s = if is_grad(i) {
605 "grad_x=in(reg) &{params.grad.0.0},grad_y=in(reg) &{params.grad.0.1},"
606 .to_string()
607 } else {
608 format!("{0} = in(reg) &{{params.{0}.0}},", SAMPLE_PARAM_NAMES[i])
609 };
610 let ts: proc_macro2::TokenStream = s.parse().unwrap();
611 t.extend(ts);
612 }
613 }
614 }
615}
616
617impl VisitMut for SampleImplRewriter {
618 fn visit_impl_item_fn_mut(&mut self, item: &mut ImplItemFn) {
619 if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
621 *p.ty.as_mut() = self.1.clone();
622 }
623 syn::visit_mut::visit_impl_item_fn_mut(self, item);
624 }
625
626 fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
627 if m.path.is_ident("asm") {
628 let t = m.tokens.clone();
630 let mut new_t = Vec::new();
631 let mut altered = false;
632
633 for tt in t {
634 match tt {
635 TokenTree::Literal(l) => {
636 if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
637 let s = l.value();
639 if s.contains("$PARAMS") {
640 altered = true;
641 self.add_loads(&mut new_t);
643 let s = s.replace("$PARAMS", &self.get_operands());
645 let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
646 "ExplicitLod"
647 } else {
648 "ImplicitLod "
649 };
650 let s = s.replace("$LOD", lod_type);
651
652 new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
653 s.as_str(),
654 )));
655 } else {
656 new_t.push(TokenTree::Literal(l.token()));
657 }
658 } else {
659 new_t.push(TokenTree::Literal(l));
660 }
661 }
662 _ => {
663 new_t.push(tt);
664 }
665 }
666 }
667
668 if altered {
669 self.add_regs(&mut new_t);
671 }
672
673 m.tokens = new_t.into_iter().collect();
675 }
676 }
677}
678
679#[proc_macro_attribute]
685#[doc(hidden)]
686pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
687 let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
688 let mut fns = Vec::new();
689
690 for m in 1..(1 << SAMPLE_PARAM_COUNT) {
691 fns.push(SampleImplRewriter::rewrite(m, &item_impl));
692 }
693
694 quote! { #(#fns)* }.into()
697}