1use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
9use rustc_hir as hir;
10use rustc_hir::def_id::LocalModDefId;
11use rustc_hir::intravisit::{self, Visitor};
12use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
13use rustc_middle::hir::nested_filter;
14use rustc_middle::query::Providers;
15use rustc_middle::ty::TyCtxt;
16use rustc_span::{Ident, Span, Symbol};
17use smallvec::SmallVec;
18use std::rc::Rc;
19
/// Up to three extra `u32` operands attached to an execution mode (for example the
/// x/y/z sizes of `LocalSize` pushed by `parse_entry_attrs`), stored inline so the
/// value stays `Copy`.
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    // Operand storage; only the first `len` entries are meaningful.
    args: [u32; 3],
    // Number of valid operands in `args` (0..=3).
    len: u8,
}

impl ExecutionModeExtra {
    /// Copies the operands in `args` into inline storage.
    ///
    /// # Panics
    ///
    /// Panics if `args` holds more than three operands.
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let operands = args.as_ref();
        assert!(
            operands.len() <= 3,
            "an execution mode takes at most 3 extra operands, got {}",
            operands.len()
        );
        let mut args = [0; 3];
        args[..operands.len()].copy_from_slice(operands);
        Self {
            args,
            // Cannot truncate: the assert above bounds the length by 3.
            len: operands.len() as u8,
        }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were actually provided.
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
42
43#[derive(Clone, Debug)]
44pub struct Entry {
45 pub execution_model: ExecutionModel,
46 pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
47 pub name: Option<Symbol>,
48}
49
50impl From<ExecutionModel> for Entry {
51 fn from(execution_model: ExecutionModel) -> Self {
52 Self {
53 execution_model,
54 execution_modes: Vec::new(),
55 name: None,
56 }
57 }
58}
59
/// The kind of intrinsic type a struct is declared as.
///
/// Carried by `SpirvAttribute::IntrinsicType`, which `check_spirv_attributes` only
/// accepts on structs. `Vector` is produced by `#[rust_gpu::vector::v1]`; the other
/// variants presumably come from the identifier table `Symbols::attributes`
/// (NOTE(review): confirm against `symbols.rs`).
#[derive(Debug, Clone)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
    Vector,
}
73
/// Arguments of `#[spirv(spec_constant(id = ..., default = ...))]`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// The required `id = ...` value.
    pub id: u32,
    /// The optional `default = ...` value.
    pub default: Option<u32>,
    /// Always `None` when produced by `parse_spec_constant_attr`; presumably filled in
    /// elsewhere for array-typed spec constants (TODO confirm).
    pub array_count: Option<u32>,
}
80
/// A single parsed rust-gpu attribute argument: one argument of a
/// `#[rust_gpu::spirv..(..)]` attribute, or `#[rust_gpu::vector::v1]`.
///
/// `check_spirv_attributes` restricts which kind of HIR node each variant may appear on.
#[derive(Debug, Clone)]
pub enum SpirvAttribute {
    /// Only valid on structs.
    IntrinsicType(IntrinsicType),
    /// Only valid on structs; accepted but reported with a "no longer needed" warning.
    Block,

    /// Only valid on functions and methods with bodies (free fns, inherent methods,
    /// trait methods with a body).
    Entry(Entry),

    // The following are only valid on function parameters, and are reported as errors
    // when the parameter's parent is not an entry point.
    StorageClass(StorageClass),
    Builtin(BuiltIn),
    DescriptorSet(u32),
    Binding(u32),
    Location(u32),
    Flat,
    PerPrimitiveExt,
    Invariant,
    InputAttachmentIndex(u32),
    SpecConstant(SpecConstant),

    /// Only valid on free functions (`Target::Fn`).
    BufferLoadIntrinsic,
    /// Only valid on free functions (`Target::Fn`).
    BufferStoreIntrinsic,
}
108
/// A value together with the span of the attribute it was parsed from, so later
/// diagnostics (e.g. the "previous attribute" note on duplicates) can point back at it.
#[derive(Copy, Clone)]
pub struct Spanned<T> {
    pub value: T,
    pub span: Span,
}
116
/// All rust-gpu attributes found on a single HIR node, with at most one attribute per
/// category: `try_insert_attr` fills each slot once and rejects later attributes of the
/// same category. Each present attribute keeps the span it came from.
#[derive(Default)]
pub struct AggregatedSpirvAttributes {
    // Struct-level attributes.
    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
    pub block: Option<Spanned<()>>,

    // Function-level: the entry-point declaration.
    pub entry: Option<Spanned<Entry>>,

    // Parameter-level attributes (only accepted on function parameters).
    pub storage_class: Option<Spanned<StorageClass>>,
    pub builtin: Option<Spanned<BuiltIn>>,
    pub descriptor_set: Option<Spanned<u32>>,
    pub binding: Option<Spanned<u32>>,
    pub location: Option<Spanned<u32>>,
    pub flat: Option<Spanned<()>>,
    pub invariant: Option<Spanned<()>>,
    pub per_primitive_ext: Option<Spanned<()>>,
    pub input_attachment_index: Option<Spanned<u32>>,
    pub spec_constant: Option<Spanned<SpecConstant>>,

    // Function-level markers for buffer load/store intrinsics.
    pub buffer_load_intrinsic: Option<Spanned<()>>,
    pub buffer_store_intrinsic: Option<Spanned<()>>,
}
145
/// Error returned by `try_insert_attr` when an attribute of the same category is already
/// present: carries the earlier attribute's span and a human-readable category name
/// used in diagnostics.
struct MultipleAttrs {
    prev_span: Span,
    category: &'static str,
}
150
151impl AggregatedSpirvAttributes {
152 pub fn parse<'tcx>(
157 cx: &CodegenCx<'tcx>,
158 attrs: impl IntoIterator<Item = &'tcx Attribute>,
159 ) -> Self {
160 let mut aggregated_attrs = Self::default();
161
162 for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
165 let (span, parsed_attr) = match parse_attr_result {
166 Ok(span_and_parsed_attr) => span_and_parsed_attr,
167 Err((span, msg)) => {
168 cx.tcx.dcx().span_delayed_bug(span, msg);
169 continue;
170 }
171 };
172 match aggregated_attrs.try_insert_attr(parsed_attr, span) {
173 Ok(()) => {}
174 Err(MultipleAttrs {
175 prev_span: _,
176 category,
177 }) => {
178 cx.tcx
179 .dcx()
180 .span_delayed_bug(span, format!("multiple {category} attributes"));
181 }
182 }
183 }
184
185 aggregated_attrs
186 }
187
188 fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
189 fn try_insert<T>(
190 slot: &mut Option<Spanned<T>>,
191 value: T,
192 span: Span,
193 category: &'static str,
194 ) -> Result<(), MultipleAttrs> {
195 if let Some(prev) = slot {
196 Err(MultipleAttrs {
197 prev_span: prev.span,
198 category,
199 })
200 } else {
201 *slot = Some(Spanned { value, span });
202 Ok(())
203 }
204 }
205
206 use SpirvAttribute::*;
207 match attr {
208 IntrinsicType(value) => {
209 try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
210 }
211 Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
212 Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
213 StorageClass(value) => {
214 try_insert(&mut self.storage_class, value, span, "storage class")
215 }
216 Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
217 DescriptorSet(value) => try_insert(
218 &mut self.descriptor_set,
219 value,
220 span,
221 "#[spirv(descriptor_set)]",
222 ),
223 Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
224 Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
225 Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
226 Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
227 PerPrimitiveExt => try_insert(
228 &mut self.per_primitive_ext,
229 (),
230 span,
231 "#[spirv(per_primitive_ext)]",
232 ),
233 InputAttachmentIndex(value) => try_insert(
234 &mut self.input_attachment_index,
235 value,
236 span,
237 "#[spirv(attachment_index)]",
238 ),
239 SpecConstant(value) => try_insert(
240 &mut self.spec_constant,
241 value,
242 span,
243 "#[spirv(spec_constant)]",
244 ),
245 BufferLoadIntrinsic => try_insert(
246 &mut self.buffer_load_intrinsic,
247 (),
248 span,
249 "#[spirv(buffer_load_intrinsic)]",
250 ),
251 BufferStoreIntrinsic => try_insert(
252 &mut self.buffer_store_intrinsic,
253 (),
254 span,
255 "#[spirv(buffer_store_intrinsic)]",
256 ),
257 }
258 }
259}
260
261fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
263 match impl_item.kind {
264 hir::ImplItemKind::Const(..) => Target::AssocConst,
265 hir::ImplItemKind::Fn(..) => {
266 let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
267 let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
268 let containing_impl_is_for_trait = match &containing_item.kind {
269 hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
270 _ => unreachable!("parent of an ImplItem must be an Impl"),
271 };
272 if containing_impl_is_for_trait {
273 Target::Method(MethodKind::Trait { body: true })
274 } else {
275 Target::Method(MethodKind::Inherent)
276 }
277 }
278 hir::ImplItemKind::Type(..) => Target::AssocTy,
279 }
280}
281
/// HIR visitor that validates rust-gpu attributes on the nodes it visits, reporting
/// problems through `tcx`'s diagnostic context.
struct CheckSpirvAttrVisitor<'tcx> {
    tcx: TyCtxt<'tcx>,
    // Interned symbols used to recognise and parse rust-gpu attributes.
    sym: Rc<Symbols>,
}
286
287impl CheckSpirvAttrVisitor<'_> {
288 fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
289 let mut aggregated_attrs = AggregatedSpirvAttributes::default();
290
291 let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);
292
293 let attrs = self.tcx.hir_attrs(hir_id);
294 for parse_attr_result in parse_attrs(attrs) {
295 let (span, parsed_attr) = match parse_attr_result {
296 Ok(span_and_parsed_attr) => span_and_parsed_attr,
297 Err((span, msg)) => {
298 self.tcx.dcx().span_err(span, msg);
299 continue;
300 }
301 };
302
303 struct Expected<T>(T);
305
306 let valid_target = match parsed_attr {
307 SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
308 Target::Struct => {
309 Ok(())
312 }
313
314 _ => Err(Expected("struct")),
315 },
316
317 SpirvAttribute::Entry(_) => match target {
318 Target::Fn
319 | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
320 Ok(())
323 }
324
325 _ => Err(Expected("function")),
326 },
327
328 SpirvAttribute::StorageClass(_)
329 | SpirvAttribute::Builtin(_)
330 | SpirvAttribute::DescriptorSet(_)
331 | SpirvAttribute::Binding(_)
332 | SpirvAttribute::Location(_)
333 | SpirvAttribute::Flat
334 | SpirvAttribute::Invariant
335 | SpirvAttribute::PerPrimitiveExt
336 | SpirvAttribute::InputAttachmentIndex(_)
337 | SpirvAttribute::SpecConstant(_) => match target {
338 Target::Param => {
339 let parent_hir_id = self.tcx.parent_hir_id(hir_id);
340 let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
341 .filter_map(|r| r.ok())
342 .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
343 if !parent_is_entry_point {
344 self.tcx.dcx().span_err(
345 span,
346 "attribute is only valid on a parameter of an entry-point function",
347 );
348 } else {
349 if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
352 let valid = match storage_class {
353 StorageClass::Input | StorageClass::Output => {
354 Err("is the default and should not be explicitly specified")
355 }
356
357 StorageClass::Private
358 | StorageClass::Function
359 | StorageClass::Generic => {
360 Err("can not be used as part of an entry's interface")
361 }
362
363 _ => Ok(()),
364 };
365
366 if let Err(msg) = valid {
367 self.tcx.dcx().span_err(
368 span,
369 format!("`{storage_class:?}` storage class {msg}"),
370 );
371 }
372 }
373 }
374 Ok(())
375 }
376
377 _ => Err(Expected("function parameter")),
378 },
379 SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
380 match target {
381 Target::Fn => Ok(()),
382 _ => Err(Expected("function")),
383 }
384 }
385 };
386 match valid_target {
387 Err(Expected(expected_target)) => {
388 self.tcx.dcx().span_err(
389 span,
390 format!(
391 "attribute is only valid on a {expected_target}, not on a {target}"
392 ),
393 );
394 }
395 Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
396 Ok(()) => {}
397 Err(MultipleAttrs {
398 prev_span,
399 category,
400 }) => {
401 self.tcx
402 .dcx()
403 .struct_span_err(
404 span,
405 format!("only one {category} attribute is allowed on a {target}"),
406 )
407 .with_span_note(prev_span, format!("previous {category} attribute"))
408 .emit();
409 }
410 },
411 }
412 }
413
414 if let Some(block_attr) = aggregated_attrs.block {
418 self.tcx.dcx().span_warn(
419 block_attr.span,
420 "#[spirv(block)] is no longer needed and should be removed",
421 );
422 }
423 }
424}
425
426impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
428 type NestedFilter = nested_filter::OnlyBodies;
429
430 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
431 self.tcx
432 }
433
434 fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
435 let target = Target::from_item(item);
436 self.check_spirv_attributes(item.hir_id(), target);
437 intravisit::walk_item(self, item);
438 }
439
440 fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
441 let target = Target::from_generic_param(generic_param);
442 self.check_spirv_attributes(generic_param.hir_id, target);
443 intravisit::walk_generic_param(self, generic_param);
444 }
445
446 fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
447 let target = Target::from_trait_item(trait_item);
448 self.check_spirv_attributes(trait_item.hir_id(), target);
449 intravisit::walk_trait_item(self, trait_item);
450 }
451
452 fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
453 self.check_spirv_attributes(field.hir_id, Target::Field);
454 intravisit::walk_field_def(self, field);
455 }
456
457 fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
458 self.check_spirv_attributes(arm.hir_id, Target::Arm);
459 intravisit::walk_arm(self, arm);
460 }
461
462 fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
463 let target = Target::from_foreign_item(f_item);
464 self.check_spirv_attributes(f_item.hir_id(), target);
465 intravisit::walk_foreign_item(self, f_item);
466 }
467
468 fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
469 let target = target_from_impl_item(self.tcx, impl_item);
470 self.check_spirv_attributes(impl_item.hir_id(), target);
471 intravisit::walk_impl_item(self, impl_item);
472 }
473
474 fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
475 if let hir::StmtKind::Let(l) = stmt.kind {
477 self.check_spirv_attributes(l.hir_id, Target::Statement);
478 }
479 intravisit::walk_stmt(self, stmt);
480 }
481
482 fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
483 let target = match expr.kind {
484 hir::ExprKind::Closure { .. } => Target::Closure,
485 _ => Target::Expression,
486 };
487
488 self.check_spirv_attributes(expr.hir_id, target);
489 intravisit::walk_expr(self, expr);
490 }
491
492 fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
493 self.check_spirv_attributes(variant.hir_id, Target::Variant);
494 intravisit::walk_variant(self, variant);
495 }
496
497 fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
498 self.check_spirv_attributes(param.hir_id, Target::Param);
499
500 intravisit::walk_param(self, param);
501 }
502}
503
504fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
506 let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
507 tcx,
508 sym: Symbols::get(),
509 };
510 tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
511 if module_def_id.is_top_level_module() {
512 check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
513 }
514}
515
516pub(crate) fn provide(providers: &mut Providers) {
517 *providers = Providers {
518 check_mod_attrs: |tcx, module_def_id| {
519 (rustc_interface::DEFAULT_QUERY_PROVIDERS
521 .queries
522 .check_mod_attrs)(tcx, module_def_id);
523 check_mod_attrs(tcx, module_def_id);
524 },
525 ..*providers
526 };
527}
528
/// An attribute parse failure: the span to report the error at, and the error message.
type ParseAttrError = (Span, String);
531
532#[allow(clippy::get_first)]
533fn parse_attrs_for_checking<'sym, 'attr, I>(
534 sym: &'sym Symbols,
535 attrs: I,
536) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'sym
537where
538 I: IntoIterator<Item = &'attr Attribute> + 'sym,
539 I::IntoIter: 'sym,
540 'attr: 'sym,
541{
542 attrs
543 .into_iter()
544 .map(move |attr| {
545 match attr {
547 Attribute::Unparsed(item) => {
548 let s = &item.path.segments;
550 if let Some(rust_gpu) = s.get(0) && *rust_gpu == sym.rust_gpu {
551 match s.get(1) {
553 Some(command) if *command == sym.spirv_attr_with_version => {
554 if let Some(args) = attr.meta_item_list() {
556 Ok(parse_spirv_attr(sym, args.iter()))
558 } else {
559 Err((
561 attr.span(),
562 "#[spirv(..)] attribute must have at least one argument"
563 .to_string(),
564 ))
565 }
566 }
567 Some(command) if *command == sym.vector => {
568 match s.get(2) {
570 Some(version) if *version == sym.v1 => {
572 Ok(SmallVec::from_iter([
573 Ok((attr.span(), SpirvAttribute::IntrinsicType(IntrinsicType::Vector)))
574 ]))
575 },
576 _ => Err((
577 attr.span(),
578 "unknown `rust_gpu::vector` version, expected `rust_gpu::vector::v1`"
579 .to_string(),
580 )),
581 }
582 }
583 _ => {
584 let spirv = sym.spirv_attr_with_version.as_str();
586 Err((
587 attr.span(),
588 format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
589 Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
590 ))
591 }
592 }
593 } else {
594 Ok(Default::default())
596 }
597 }
598 Attribute::Parsed(_) => Ok(Default::default()),
599 }
600 })
601 .flat_map(|result| {
602 result
603 .unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
604 .into_iter()
605 })
606}
607
608fn parse_spirv_attr<'a>(
609 sym: &Symbols,
610 iter: impl Iterator<Item = &'a MetaItemInner>,
611) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
612 iter.map(|arg| {
613 let span = arg.span();
614 let parsed_attr =
615 if arg.has_name(sym.descriptor_set) {
616 SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
617 } else if arg.has_name(sym.binding) {
618 SpirvAttribute::Binding(parse_attr_int_value(arg)?)
619 } else if arg.has_name(sym.location) {
620 SpirvAttribute::Location(parse_attr_int_value(arg)?)
621 } else if arg.has_name(sym.input_attachment_index) {
622 SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
623 } else if arg.has_name(sym.spec_constant) {
624 SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
625 } else {
626 let name = match arg.ident() {
627 Some(i) => i,
628 None => {
629 return Err((
630 span,
631 "#[spirv(..)] attribute argument must be single identifier".to_string(),
632 ));
633 }
634 };
635 sym.attributes.get(&name.name).map_or_else(
636 || Err((name.span, "unknown argument to spirv attribute".to_string())),
637 |a| {
638 Ok(match a {
639 SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
640 parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
641 ),
642 _ => a.clone(),
643 })
644 },
645 )?
646 };
647 Ok((span, parsed_attr))
648 })
649 .collect()
650}
651
652fn parse_spec_constant_attr(
653 sym: &Symbols,
654 arg: &MetaItemInner,
655) -> Result<SpecConstant, ParseAttrError> {
656 let mut id = None;
657 let mut default = None;
658
659 if let Some(attrs) = arg.meta_item_list() {
660 for attr in attrs {
661 if attr.has_name(sym.id) {
662 if id.is_none() {
663 id = Some(parse_attr_int_value(attr)?);
664 } else {
665 return Err((attr.span(), "`id` may only be specified once".into()));
666 }
667 } else if attr.has_name(sym.default) {
668 if default.is_none() {
669 default = Some(parse_attr_int_value(attr)?);
670 } else {
671 return Err((attr.span(), "`default` may only be specified once".into()));
672 }
673 } else {
674 return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
675 }
676 }
677 }
678 Ok(SpecConstant {
679 id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
680 default,
681 array_count: None,
683 })
684}
685
686fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
687 let arg = match arg.meta_item() {
688 Some(arg) => arg,
689 None => return Err((arg.span(), "attribute must have value".to_string())),
690 };
691 match arg.name_value_literal() {
692 Some(&MetaItemLit {
693 kind: LitKind::Int(x, ..),
694 ..
695 }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
696 _ => Err((arg.span, "attribute value must be integer".to_string())),
697 }
698}
699
700fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
701 let arg = match arg.meta_item() {
702 Some(arg) => arg,
703 None => return Err((arg.span(), "attribute must have value".to_string())),
704 };
705 match arg.meta_item_list() {
706 Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
707 let mut local_size = [1; 3];
708 for (idx, lit) in tuple.iter().enumerate() {
709 match lit {
710 MetaItemInner::Lit(MetaItemLit {
711 kind: LitKind::Int(x, ..),
712 ..
713 }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
714 _ => return Err((lit.span(), "must be a u32 literal".to_string())),
715 }
716 }
717 Ok(local_size)
718 }
719 Some([]) => Err((
720 arg.span,
721 "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
722 )),
723 Some(tuple) if tuple.len() > 3 => Err((
724 arg.span,
725 "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
726 )),
727 _ => Err((
728 arg.span,
729 "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
730 )),
731 }
732}
733
734fn parse_entry_attrs(
739 sym: &Symbols,
740 arg: &MetaItemInner,
741 name: &Ident,
742 execution_model: ExecutionModel,
743) -> Result<Entry, ParseAttrError> {
744 use ExecutionMode::*;
745 use ExecutionModel::*;
746 let mut entry = Entry::from(execution_model);
747 let mut origin_mode: Option<ExecutionMode> = None;
748 let mut local_size: Option<[u32; 3]> = None;
749 let mut local_size_hint: Option<[u32; 3]> = None;
750 if let Some(attrs) = arg.meta_item_list() {
753 for attr in attrs {
754 if let Some(attr_name) = attr.ident() {
755 if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
756 {
757 use crate::symbols::ExecutionModeExtraDim::*;
758 let val = match extra_dim {
759 None | Tuple => Option::None,
760 _ => Some(parse_attr_int_value(attr)?),
761 };
762 match execution_mode {
763 OriginUpperLeft | OriginLowerLeft => {
764 origin_mode.replace(*execution_mode);
765 }
766 LocalSize => {
767 if local_size.is_none() {
768 local_size.replace(parse_local_size_attr(attr)?);
769 } else {
770 return Err((
771 attr_name.span,
772 String::from(
773 "`#[spirv(compute(threads))]` may only be specified once",
774 ),
775 ));
776 }
777 }
778 LocalSizeHint => {
779 let val = val.unwrap();
780 if local_size_hint.is_none() {
781 local_size_hint.replace([1, 1, 1]);
782 }
783 let local_size_hint = local_size_hint.as_mut().unwrap();
784 match extra_dim {
785 X => {
786 local_size_hint[0] = val;
787 }
788 Y => {
789 local_size_hint[1] = val;
790 }
791 Z => {
792 local_size_hint[2] = val;
793 }
794 _ => unreachable!(),
795 }
796 }
797 _ => {
819 if let Some(val) = val {
820 entry
821 .execution_modes
822 .push((*execution_mode, ExecutionModeExtra::new([val])));
823 } else {
824 entry
825 .execution_modes
826 .push((*execution_mode, ExecutionModeExtra::new([])));
827 }
828 }
829 }
830 } else if attr_name.name == sym.entry_point_name {
831 match attr.value_str() {
832 Some(sym) => {
833 entry.name = Some(sym);
834 }
835 None => {
836 return Err((
837 attr_name.span,
838 format!(
839 "#[spirv({name}(..))] unknown attribute argument {attr_name}"
840 ),
841 ));
842 }
843 }
844 } else {
845 return Err((
846 attr_name.span,
847 format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
848 ));
849 }
850 } else {
851 return Err((
852 arg.span(),
853 format!("#[spirv({name}(..))] attribute argument must be single identifier"),
854 ));
855 }
856 }
857 }
858 match entry.execution_model {
859 Fragment => {
860 let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
861 entry
862 .execution_modes
863 .push((origin_mode, ExecutionModeExtra::new([])));
864 }
865 GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
866 if let Some(local_size) = local_size {
867 entry
868 .execution_modes
869 .push((LocalSize, ExecutionModeExtra::new(local_size)));
870 } else {
871 return Err((
872 arg.span(),
873 String::from(
874 "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
875 ),
876 ));
877 }
878 }
879 _ => {}
881 }
882 Ok(entry)
883}