1use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
9use rustc_hir as hir;
10use rustc_hir::def_id::LocalModDefId;
11use rustc_hir::intravisit::{self, Visitor};
12use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
13use rustc_middle::hir::nested_filter;
14use rustc_middle::query::Providers;
15use rustc_middle::ty::TyCtxt;
16use rustc_span::{Ident, Span, Symbol};
17use smallvec::SmallVec;
18use std::rc::Rc;
19
/// Up to three literal operands that accompany an execution mode of an entry point
/// (see `Entry::execution_modes`).
///
/// Operands are stored inline in a fixed-size array; `len` records how many of
/// `args` are actually in use.
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    args: [u32; 3],
    len: u8,
}

impl ExecutionModeExtra {
    /// Copies `args` into a new `ExecutionModeExtra`.
    ///
    /// # Panics
    /// Panics if `args` has more than three elements (the inline capacity).
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let src = args.as_ref();
        // Make the capacity invariant explicit instead of relying on an
        // out-of-range slice-index panic with no context.
        assert!(
            src.len() <= 3,
            "ExecutionModeExtra holds at most 3 operands, got {}",
            src.len()
        );
        let mut args = [0; 3];
        args[..src.len()].copy_from_slice(src);
        // Cannot truncate: the assert above bounds the length by 3.
        let len = src.len() as u8;
        Self { args, len }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were supplied to [`ExecutionModeExtra::new`].
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
42
43#[derive(Clone, Debug)]
44pub struct Entry {
45 pub execution_model: ExecutionModel,
46 pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
47 pub name: Option<Symbol>,
48}
49
50impl From<ExecutionModel> for Entry {
51 fn from(execution_model: ExecutionModel) -> Self {
52 Self {
53 execution_model,
54 execution_modes: Vec::new(),
55 name: None,
56 }
57 }
58}
59
/// The intrinsic SPIR-V type a Rust struct has been marked as
/// (via `#[spirv(..)]` or, for `Vector`, `rust_gpu::vector::v1`).
#[derive(Clone, Debug)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
    Vector,
}
73
/// Data parsed from `#[spirv(spec_constant(id = .., default = ..))]`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// The specialization constant id (`id = ...`, required).
    pub id: u32,
    /// The default value (`default = ...`), if given.
    pub default: Option<u32>,
    /// Not set by the attribute parser (always `None` there); NOTE(review):
    /// presumably filled in later by the caller — confirm.
    pub array_count: Option<u32>,
}
80
81#[derive(Debug, Clone)]
84pub enum SpirvAttribute {
85 IntrinsicType(IntrinsicType),
87 Block,
88
89 Entry(Entry),
91
92 StorageClass(StorageClass),
94 Builtin(BuiltIn),
95 DescriptorSet(u32),
96 Binding(u32),
97 Location(u32),
98 Flat,
99 PerPrimitiveExt,
100 Invariant,
101 InputAttachmentIndex(u32),
102 SpecConstant(SpecConstant),
103
104 BufferLoadIntrinsic,
106 BufferStoreIntrinsic,
107}
108
109#[derive(Copy, Clone)]
112pub struct Spanned<T> {
113 pub value: T,
114 pub span: Span,
115}
116
117#[derive(Default)]
121pub struct AggregatedSpirvAttributes {
122 pub intrinsic_type: Option<Spanned<IntrinsicType>>,
124 pub block: Option<Spanned<()>>,
125
126 pub entry: Option<Spanned<Entry>>,
128
129 pub storage_class: Option<Spanned<StorageClass>>,
131 pub builtin: Option<Spanned<BuiltIn>>,
132 pub descriptor_set: Option<Spanned<u32>>,
133 pub binding: Option<Spanned<u32>>,
134 pub location: Option<Spanned<u32>>,
135 pub flat: Option<Spanned<()>>,
136 pub invariant: Option<Spanned<()>>,
137 pub per_primitive_ext: Option<Spanned<()>>,
138 pub input_attachment_index: Option<Spanned<u32>>,
139 pub spec_constant: Option<Spanned<SpecConstant>>,
140
141 pub buffer_load_intrinsic: Option<Spanned<()>>,
143 pub buffer_store_intrinsic: Option<Spanned<()>>,
144}
145
146struct MultipleAttrs {
147 prev_span: Span,
148 category: &'static str,
149}
150
151impl AggregatedSpirvAttributes {
152 pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
157 let mut aggregated_attrs = Self::default();
158
159 for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
162 let (span, parsed_attr) = match parse_attr_result {
163 Ok(span_and_parsed_attr) => span_and_parsed_attr,
164 Err((span, msg)) => {
165 cx.tcx.dcx().span_delayed_bug(span, msg);
166 continue;
167 }
168 };
169 match aggregated_attrs.try_insert_attr(parsed_attr, span) {
170 Ok(()) => {}
171 Err(MultipleAttrs {
172 prev_span: _,
173 category,
174 }) => {
175 cx.tcx
176 .dcx()
177 .span_delayed_bug(span, format!("multiple {category} attributes"));
178 }
179 }
180 }
181
182 aggregated_attrs
183 }
184
185 fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
186 fn try_insert<T>(
187 slot: &mut Option<Spanned<T>>,
188 value: T,
189 span: Span,
190 category: &'static str,
191 ) -> Result<(), MultipleAttrs> {
192 if let Some(prev) = slot {
193 Err(MultipleAttrs {
194 prev_span: prev.span,
195 category,
196 })
197 } else {
198 *slot = Some(Spanned { value, span });
199 Ok(())
200 }
201 }
202
203 use SpirvAttribute::*;
204 match attr {
205 IntrinsicType(value) => {
206 try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
207 }
208 Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
209 Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
210 StorageClass(value) => {
211 try_insert(&mut self.storage_class, value, span, "storage class")
212 }
213 Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
214 DescriptorSet(value) => try_insert(
215 &mut self.descriptor_set,
216 value,
217 span,
218 "#[spirv(descriptor_set)]",
219 ),
220 Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
221 Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
222 Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
223 Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
224 PerPrimitiveExt => try_insert(
225 &mut self.per_primitive_ext,
226 (),
227 span,
228 "#[spirv(per_primitive_ext)]",
229 ),
230 InputAttachmentIndex(value) => try_insert(
231 &mut self.input_attachment_index,
232 value,
233 span,
234 "#[spirv(attachment_index)]",
235 ),
236 SpecConstant(value) => try_insert(
237 &mut self.spec_constant,
238 value,
239 span,
240 "#[spirv(spec_constant)]",
241 ),
242 BufferLoadIntrinsic => try_insert(
243 &mut self.buffer_load_intrinsic,
244 (),
245 span,
246 "#[spirv(buffer_load_intrinsic)]",
247 ),
248 BufferStoreIntrinsic => try_insert(
249 &mut self.buffer_store_intrinsic,
250 (),
251 span,
252 "#[spirv(buffer_store_intrinsic)]",
253 ),
254 }
255 }
256}
257
258fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
260 match impl_item.kind {
261 hir::ImplItemKind::Const(..) => Target::AssocConst,
262 hir::ImplItemKind::Fn(..) => {
263 let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
264 let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
265 let containing_impl_is_for_trait = match &containing_item.kind {
266 hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
267 _ => unreachable!("parent of an ImplItem must be an Impl"),
268 };
269 if containing_impl_is_for_trait {
270 Target::Method(MethodKind::Trait { body: true })
271 } else {
272 Target::Method(MethodKind::Inherent)
273 }
274 }
275 hir::ImplItemKind::Type(..) => Target::AssocTy,
276 }
277}
278
279struct CheckSpirvAttrVisitor<'tcx> {
280 tcx: TyCtxt<'tcx>,
281 sym: Rc<Symbols>,
282}
283
impl CheckSpirvAttrVisitor<'_> {
    /// Validates the rust-gpu attributes attached to `hir_id`, which is of kind `target`.
    ///
    /// Emits user-facing errors for:
    /// - attributes that fail to parse;
    /// - attributes placed on the wrong kind of node;
    /// - parameter attributes on non-entry-point functions;
    /// - disallowed storage classes on entry-point parameters;
    /// - repeated attributes of one category.
    ///
    /// Also warns that `#[spirv(block)]` is no longer needed.
    fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
        // Used here only to detect duplicates and find `#[spirv(block)]`.
        let mut aggregated_attrs = AggregatedSpirvAttributes::default();

        let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);

        let attrs = self.tcx.hir_attrs(hir_id);
        for parse_attr_result in parse_attrs(attrs) {
            let (span, parsed_attr) = match parse_attr_result {
                Ok(span_and_parsed_attr) => span_and_parsed_attr,
                Err((span, msg)) => {
                    self.tcx.dcx().span_err(span, msg);
                    continue;
                }
            };

            /// Describes the kind of node an attribute would have been valid on.
            struct Expected<T>(T);

            let valid_target = match parsed_attr {
                SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
                    Target::Struct => {
                        Ok(())
                    }

                    _ => Err(Expected("struct")),
                },

                SpirvAttribute::Entry(_) => match target {
                    Target::Fn
                    | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
                        Ok(())
                    }

                    _ => Err(Expected("function")),
                },

                SpirvAttribute::StorageClass(_)
                | SpirvAttribute::Builtin(_)
                | SpirvAttribute::DescriptorSet(_)
                | SpirvAttribute::Binding(_)
                | SpirvAttribute::Location(_)
                | SpirvAttribute::Flat
                | SpirvAttribute::Invariant
                | SpirvAttribute::PerPrimitiveExt
                | SpirvAttribute::InputAttachmentIndex(_)
                | SpirvAttribute::SpecConstant(_) => match target {
                    Target::Param => {
                        // A parameter's attributes only make sense if the enclosing
                        // function carries an entry-point attribute.
                        let parent_hir_id = self.tcx.parent_hir_id(hir_id);
                        let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
                            .filter_map(|r| r.ok())
                            .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
                        if !parent_is_entry_point {
                            self.tcx.dcx().span_err(
                                span,
                                "attribute is only valid on a parameter of an entry-point function",
                            );
                        } else {
                            // Some storage classes are implied or not allowed on an
                            // entry point's interface.
                            if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
                                let valid = match storage_class {
                                    StorageClass::Input | StorageClass::Output => {
                                        Err("is the default and should not be explicitly specified")
                                    }

                                    StorageClass::Private
                                    | StorageClass::Function
                                    | StorageClass::Generic => {
                                        Err("can not be used as part of an entry's interface")
                                    }

                                    _ => Ok(()),
                                };

                                if let Err(msg) = valid {
                                    self.tcx.dcx().span_err(
                                        span,
                                        format!("`{storage_class:?}` storage class {msg}"),
                                    );
                                }
                            }
                        }
                        // The target kind itself is right even if an error was emitted
                        // above, so the attribute still goes through duplicate detection.
                        Ok(())
                    }

                    _ => Err(Expected("function parameter")),
                },
                SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
                    match target {
                        Target::Fn => Ok(()),
                        _ => Err(Expected("function")),
                    }
                }
            };
            match valid_target {
                Err(Expected(expected_target)) => {
                    self.tcx.dcx().span_err(
                        span,
                        format!(
                            "attribute is only valid on a {expected_target}, not on a {target}"
                        ),
                    );
                }
                // Only attributes on a valid target are checked for duplicates.
                Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
                    Ok(()) => {}
                    Err(MultipleAttrs {
                        prev_span,
                        category,
                    }) => {
                        self.tcx
                            .dcx()
                            .struct_span_err(
                                span,
                                format!("only one {category} attribute is allowed on a {target}"),
                            )
                            .with_span_note(prev_span, format!("previous {category} attribute"))
                            .emit();
                    }
                },
            }
        }

        if let Some(block_attr) = aggregated_attrs.block {
            self.tcx.dcx().span_warn(
                block_attr.span,
                "#[spirv(block)] is no longer needed and should be removed",
            );
        }
    }
}
422
423impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
425 type NestedFilter = nested_filter::OnlyBodies;
426
427 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
428 self.tcx
429 }
430
431 fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
432 let target = Target::from_item(item);
433 self.check_spirv_attributes(item.hir_id(), target);
434 intravisit::walk_item(self, item);
435 }
436
437 fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
438 let target = Target::from_generic_param(generic_param);
439 self.check_spirv_attributes(generic_param.hir_id, target);
440 intravisit::walk_generic_param(self, generic_param);
441 }
442
443 fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
444 let target = Target::from_trait_item(trait_item);
445 self.check_spirv_attributes(trait_item.hir_id(), target);
446 intravisit::walk_trait_item(self, trait_item);
447 }
448
449 fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
450 self.check_spirv_attributes(field.hir_id, Target::Field);
451 intravisit::walk_field_def(self, field);
452 }
453
454 fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
455 self.check_spirv_attributes(arm.hir_id, Target::Arm);
456 intravisit::walk_arm(self, arm);
457 }
458
459 fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
460 let target = Target::from_foreign_item(f_item);
461 self.check_spirv_attributes(f_item.hir_id(), target);
462 intravisit::walk_foreign_item(self, f_item);
463 }
464
465 fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
466 let target = target_from_impl_item(self.tcx, impl_item);
467 self.check_spirv_attributes(impl_item.hir_id(), target);
468 intravisit::walk_impl_item(self, impl_item);
469 }
470
471 fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
472 if let hir::StmtKind::Let(l) = stmt.kind {
474 self.check_spirv_attributes(l.hir_id, Target::Statement);
475 }
476 intravisit::walk_stmt(self, stmt);
477 }
478
479 fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
480 let target = match expr.kind {
481 hir::ExprKind::Closure { .. } => Target::Closure,
482 _ => Target::Expression,
483 };
484
485 self.check_spirv_attributes(expr.hir_id, target);
486 intravisit::walk_expr(self, expr);
487 }
488
489 fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
490 self.check_spirv_attributes(variant.hir_id, Target::Variant);
491 intravisit::walk_variant(self, variant);
492 }
493
494 fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
495 self.check_spirv_attributes(param.hir_id, Target::Param);
496
497 intravisit::walk_param(self, param);
498 }
499}
500
501fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
503 let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
504 tcx,
505 sym: Symbols::get(),
506 };
507 tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
508 if module_def_id.is_top_level_module() {
509 check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
510 }
511}
512
513pub(crate) fn provide(providers: &mut Providers) {
514 *providers = Providers {
515 check_mod_attrs: |tcx, module_def_id| {
516 (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, module_def_id);
518 check_mod_attrs(tcx, module_def_id);
519 },
520 ..*providers
521 };
522}
523
524type ParseAttrError = (Span, String);
526
527#[allow(clippy::get_first)]
528fn parse_attrs_for_checking<'a>(
529 sym: &'a Symbols,
530 attrs: &'a [Attribute],
531) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
532 attrs
533 .iter()
534 .map(move |attr| {
535 match attr {
537 Attribute::Unparsed(item) => {
538 let s = &item.path.segments;
540 if let Some(rust_gpu) = s.get(0) && rust_gpu.name == sym.rust_gpu {
541 match s.get(1) {
543 Some(command) if command.name == sym.spirv_attr_with_version => {
544 if let Some(args) = attr.meta_item_list() {
546 Ok(parse_spirv_attr(sym, args.iter()))
548 } else {
549 Err((
551 attr.span(),
552 "#[spirv(..)] attribute must have at least one argument"
553 .to_string(),
554 ))
555 }
556 }
557 Some(command) if command.name == sym.vector => {
558 match s.get(2) {
560 Some(version) if version.name == sym.v1 => {
562 Ok(SmallVec::from_iter([
563 Ok((attr.span(), SpirvAttribute::IntrinsicType(IntrinsicType::Vector)))
564 ]))
565 },
566 _ => Err((
567 attr.span(),
568 "unknown `rust_gpu::vector` version, expected `rust_gpu::vector::v1`"
569 .to_string(),
570 )),
571 }
572 }
573 _ => {
574 let spirv = sym.spirv_attr_with_version.as_str();
576 Err((
577 attr.span(),
578 format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
579 Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
580 ))
581 }
582 }
583 } else {
584 Ok(Default::default())
586 }
587 }
588 Attribute::Parsed(_) => Ok(Default::default()),
589 }
590 })
591 .flat_map(|result| {
592 result
593 .unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
594 .into_iter()
595 })
596}
597
598fn parse_spirv_attr<'a>(
599 sym: &Symbols,
600 iter: impl Iterator<Item = &'a MetaItemInner>,
601) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
602 iter.map(|arg| {
603 let span = arg.span();
604 let parsed_attr =
605 if arg.has_name(sym.descriptor_set) {
606 SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
607 } else if arg.has_name(sym.binding) {
608 SpirvAttribute::Binding(parse_attr_int_value(arg)?)
609 } else if arg.has_name(sym.location) {
610 SpirvAttribute::Location(parse_attr_int_value(arg)?)
611 } else if arg.has_name(sym.input_attachment_index) {
612 SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
613 } else if arg.has_name(sym.spec_constant) {
614 SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
615 } else {
616 let name = match arg.ident() {
617 Some(i) => i,
618 None => {
619 return Err((
620 span,
621 "#[spirv(..)] attribute argument must be single identifier".to_string(),
622 ));
623 }
624 };
625 sym.attributes.get(&name.name).map_or_else(
626 || Err((name.span, "unknown argument to spirv attribute".to_string())),
627 |a| {
628 Ok(match a {
629 SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
630 parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
631 ),
632 _ => a.clone(),
633 })
634 },
635 )?
636 };
637 Ok((span, parsed_attr))
638 })
639 .collect()
640}
641
642fn parse_spec_constant_attr(
643 sym: &Symbols,
644 arg: &MetaItemInner,
645) -> Result<SpecConstant, ParseAttrError> {
646 let mut id = None;
647 let mut default = None;
648
649 if let Some(attrs) = arg.meta_item_list() {
650 for attr in attrs {
651 if attr.has_name(sym.id) {
652 if id.is_none() {
653 id = Some(parse_attr_int_value(attr)?);
654 } else {
655 return Err((attr.span(), "`id` may only be specified once".into()));
656 }
657 } else if attr.has_name(sym.default) {
658 if default.is_none() {
659 default = Some(parse_attr_int_value(attr)?);
660 } else {
661 return Err((attr.span(), "`default` may only be specified once".into()));
662 }
663 } else {
664 return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
665 }
666 }
667 }
668 Ok(SpecConstant {
669 id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
670 default,
671 array_count: None,
673 })
674}
675
676fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
677 let arg = match arg.meta_item() {
678 Some(arg) => arg,
679 None => return Err((arg.span(), "attribute must have value".to_string())),
680 };
681 match arg.name_value_literal() {
682 Some(&MetaItemLit {
683 kind: LitKind::Int(x, ..),
684 ..
685 }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
686 _ => Err((arg.span, "attribute value must be integer".to_string())),
687 }
688}
689
690fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
691 let arg = match arg.meta_item() {
692 Some(arg) => arg,
693 None => return Err((arg.span(), "attribute must have value".to_string())),
694 };
695 match arg.meta_item_list() {
696 Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
697 let mut local_size = [1; 3];
698 for (idx, lit) in tuple.iter().enumerate() {
699 match lit {
700 MetaItemInner::Lit(MetaItemLit {
701 kind: LitKind::Int(x, ..),
702 ..
703 }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
704 _ => return Err((lit.span(), "must be a u32 literal".to_string())),
705 }
706 }
707 Ok(local_size)
708 }
709 Some([]) => Err((
710 arg.span,
711 "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
712 )),
713 Some(tuple) if tuple.len() > 3 => Err((
714 arg.span,
715 "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
716 )),
717 _ => Err((
718 arg.span,
719 "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
720 )),
721 }
722}
723
724fn parse_entry_attrs(
729 sym: &Symbols,
730 arg: &MetaItemInner,
731 name: &Ident,
732 execution_model: ExecutionModel,
733) -> Result<Entry, ParseAttrError> {
734 use ExecutionMode::*;
735 use ExecutionModel::*;
736 let mut entry = Entry::from(execution_model);
737 let mut origin_mode: Option<ExecutionMode> = None;
738 let mut local_size: Option<[u32; 3]> = None;
739 let mut local_size_hint: Option<[u32; 3]> = None;
740 if let Some(attrs) = arg.meta_item_list() {
743 for attr in attrs {
744 if let Some(attr_name) = attr.ident() {
745 if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
746 {
747 use crate::symbols::ExecutionModeExtraDim::*;
748 let val = match extra_dim {
749 None | Tuple => Option::None,
750 _ => Some(parse_attr_int_value(attr)?),
751 };
752 match execution_mode {
753 OriginUpperLeft | OriginLowerLeft => {
754 origin_mode.replace(*execution_mode);
755 }
756 LocalSize => {
757 if local_size.is_none() {
758 local_size.replace(parse_local_size_attr(attr)?);
759 } else {
760 return Err((
761 attr_name.span,
762 String::from(
763 "`#[spirv(compute(threads))]` may only be specified once",
764 ),
765 ));
766 }
767 }
768 LocalSizeHint => {
769 let val = val.unwrap();
770 if local_size_hint.is_none() {
771 local_size_hint.replace([1, 1, 1]);
772 }
773 let local_size_hint = local_size_hint.as_mut().unwrap();
774 match extra_dim {
775 X => {
776 local_size_hint[0] = val;
777 }
778 Y => {
779 local_size_hint[1] = val;
780 }
781 Z => {
782 local_size_hint[2] = val;
783 }
784 _ => unreachable!(),
785 }
786 }
787 _ => {
809 if let Some(val) = val {
810 entry
811 .execution_modes
812 .push((*execution_mode, ExecutionModeExtra::new([val])));
813 } else {
814 entry
815 .execution_modes
816 .push((*execution_mode, ExecutionModeExtra::new([])));
817 }
818 }
819 }
820 } else if attr_name.name == sym.entry_point_name {
821 match attr.value_str() {
822 Some(sym) => {
823 entry.name = Some(sym);
824 }
825 None => {
826 return Err((
827 attr_name.span,
828 format!(
829 "#[spirv({name}(..))] unknown attribute argument {attr_name}"
830 ),
831 ));
832 }
833 }
834 } else {
835 return Err((
836 attr_name.span,
837 format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
838 ));
839 }
840 } else {
841 return Err((
842 arg.span(),
843 format!("#[spirv({name}(..))] attribute argument must be single identifier"),
844 ));
845 }
846 }
847 }
848 match entry.execution_model {
849 Fragment => {
850 let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
851 entry
852 .execution_modes
853 .push((origin_mode, ExecutionModeExtra::new([])));
854 }
855 GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
856 if let Some(local_size) = local_size {
857 entry
858 .execution_modes
859 .push((LocalSize, ExecutionModeExtra::new(local_size)));
860 } else {
861 return Err((
862 arg.span(),
863 String::from(
864 "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
865 ),
866 ));
867 }
868 }
869 _ => {}
871 }
872 Ok(entry)
873}