1use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
9use rustc_hir as hir;
10use rustc_hir::def_id::LocalModDefId;
11use rustc_hir::intravisit::{self, Visitor};
12use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
13use rustc_middle::hir::nested_filter;
14use rustc_middle::query::Providers;
15use rustc_middle::ty::TyCtxt;
16use rustc_span::{Ident, Span, Symbol};
17use smallvec::SmallVec;
18use std::rc::Rc;
19
/// Fixed-capacity storage for the (at most three) literal operands that accompany a
/// SPIR-V execution mode, e.g. the `x, y, z` of `LocalSize`.
///
/// Unused trailing slots are zero; only the first `len` entries are meaningful and
/// only those are exposed through [`AsRef<[u32]>`].
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    args: [u32; 3],
    len: u8,
}

impl ExecutionModeExtra {
    /// Copies `args` into a new operand list.
    ///
    /// # Panics
    /// Panics if more than three operands are supplied.
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let src = args.as_ref();
        // Check explicitly so misuse reports the invariant instead of a bare
        // slice-index-out-of-range panic.
        assert!(
            src.len() <= 3,
            "ExecutionModeExtra holds at most 3 operands, got {}",
            src.len()
        );
        let mut storage = [0; 3];
        storage[..src.len()].copy_from_slice(src);
        Self {
            args: storage,
            // Cannot truncate: `src.len() <= 3` was asserted above.
            len: src.len() as u8,
        }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// The operands that were actually supplied to [`ExecutionModeExtra::new`].
    fn as_ref(&self) -> &[u32] {
        &self.args[..self.len as usize]
    }
}
42
43#[derive(Clone, Debug)]
44pub struct Entry {
45 pub execution_model: ExecutionModel,
46 pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
47 pub name: Option<Symbol>,
48}
49
50impl From<ExecutionModel> for Entry {
51 fn from(execution_model: ExecutionModel) -> Self {
52 Self {
53 execution_model,
54 execution_modes: Vec::new(),
55 name: None,
56 }
57 }
58}
59
/// Kinds of types that the backend treats as intrinsic; they are selected by an
/// intrinsic-type attribute on a struct (or by `#[rust_gpu::vector::v1]` for `Vector`).
#[derive(Clone, Debug)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
    Vector,
}
73
/// Contents of a `#[spirv(spec_constant(id = ..., default = ...))]` attribute.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// The required `id = ...` value.
    pub id: u32,
    /// The optional `default = ...` value.
    pub default: Option<u32>,
    /// Always `None` when produced by the attribute parser in this file;
    /// presumably filled in elsewhere (TODO confirm).
    pub array_count: Option<u32>,
}
80
81#[derive(Debug, Clone)]
84pub enum SpirvAttribute {
85 IntrinsicType(IntrinsicType),
87 Block,
88
89 Entry(Entry),
91
92 StorageClass(StorageClass),
94 Builtin(BuiltIn),
95 DescriptorSet(u32),
96 Binding(u32),
97 Flat,
98 PerPrimitiveExt,
99 Invariant,
100 InputAttachmentIndex(u32),
101 SpecConstant(SpecConstant),
102
103 BufferLoadIntrinsic,
105 BufferStoreIntrinsic,
106}
107
108#[derive(Copy, Clone)]
111pub struct Spanned<T> {
112 pub value: T,
113 pub span: Span,
114}
115
116#[derive(Default)]
120pub struct AggregatedSpirvAttributes {
121 pub intrinsic_type: Option<Spanned<IntrinsicType>>,
123 pub block: Option<Spanned<()>>,
124
125 pub entry: Option<Spanned<Entry>>,
127
128 pub storage_class: Option<Spanned<StorageClass>>,
130 pub builtin: Option<Spanned<BuiltIn>>,
131 pub descriptor_set: Option<Spanned<u32>>,
132 pub binding: Option<Spanned<u32>>,
133 pub flat: Option<Spanned<()>>,
134 pub invariant: Option<Spanned<()>>,
135 pub per_primitive_ext: Option<Spanned<()>>,
136 pub input_attachment_index: Option<Spanned<u32>>,
137 pub spec_constant: Option<Spanned<SpecConstant>>,
138
139 pub buffer_load_intrinsic: Option<Spanned<()>>,
141 pub buffer_store_intrinsic: Option<Spanned<()>>,
142}
143
144struct MultipleAttrs {
145 prev_span: Span,
146 category: &'static str,
147}
148
149impl AggregatedSpirvAttributes {
150 pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
155 let mut aggregated_attrs = Self::default();
156
157 for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
160 let (span, parsed_attr) = match parse_attr_result {
161 Ok(span_and_parsed_attr) => span_and_parsed_attr,
162 Err((span, msg)) => {
163 cx.tcx.dcx().span_delayed_bug(span, msg);
164 continue;
165 }
166 };
167 match aggregated_attrs.try_insert_attr(parsed_attr, span) {
168 Ok(()) => {}
169 Err(MultipleAttrs {
170 prev_span: _,
171 category,
172 }) => {
173 cx.tcx
174 .dcx()
175 .span_delayed_bug(span, format!("multiple {category} attributes"));
176 }
177 }
178 }
179
180 aggregated_attrs
181 }
182
183 fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
184 fn try_insert<T>(
185 slot: &mut Option<Spanned<T>>,
186 value: T,
187 span: Span,
188 category: &'static str,
189 ) -> Result<(), MultipleAttrs> {
190 if let Some(prev) = slot {
191 Err(MultipleAttrs {
192 prev_span: prev.span,
193 category,
194 })
195 } else {
196 *slot = Some(Spanned { value, span });
197 Ok(())
198 }
199 }
200
201 use SpirvAttribute::*;
202 match attr {
203 IntrinsicType(value) => {
204 try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
205 }
206 Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
207 Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
208 StorageClass(value) => {
209 try_insert(&mut self.storage_class, value, span, "storage class")
210 }
211 Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
212 DescriptorSet(value) => try_insert(
213 &mut self.descriptor_set,
214 value,
215 span,
216 "#[spirv(descriptor_set)]",
217 ),
218 Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
219 Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
220 Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
221 PerPrimitiveExt => try_insert(
222 &mut self.per_primitive_ext,
223 (),
224 span,
225 "#[spirv(per_primitive_ext)]",
226 ),
227 InputAttachmentIndex(value) => try_insert(
228 &mut self.input_attachment_index,
229 value,
230 span,
231 "#[spirv(attachment_index)]",
232 ),
233 SpecConstant(value) => try_insert(
234 &mut self.spec_constant,
235 value,
236 span,
237 "#[spirv(spec_constant)]",
238 ),
239 BufferLoadIntrinsic => try_insert(
240 &mut self.buffer_load_intrinsic,
241 (),
242 span,
243 "#[spirv(buffer_load_intrinsic)]",
244 ),
245 BufferStoreIntrinsic => try_insert(
246 &mut self.buffer_store_intrinsic,
247 (),
248 span,
249 "#[spirv(buffer_store_intrinsic)]",
250 ),
251 }
252 }
253}
254
255fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
257 match impl_item.kind {
258 hir::ImplItemKind::Const(..) => Target::AssocConst,
259 hir::ImplItemKind::Fn(..) => {
260 let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
261 let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
262 let containing_impl_is_for_trait = match &containing_item.kind {
263 hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
264 _ => unreachable!("parent of an ImplItem must be an Impl"),
265 };
266 if containing_impl_is_for_trait {
267 Target::Method(MethodKind::Trait { body: true })
268 } else {
269 Target::Method(MethodKind::Inherent)
270 }
271 }
272 hir::ImplItemKind::Type(..) => Target::AssocTy,
273 }
274}
275
276struct CheckSpirvAttrVisitor<'tcx> {
277 tcx: TyCtxt<'tcx>,
278 sym: Rc<Symbols>,
279}
280
281impl CheckSpirvAttrVisitor<'_> {
282 fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
283 let mut aggregated_attrs = AggregatedSpirvAttributes::default();
284
285 let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);
286
287 let attrs = self.tcx.hir_attrs(hir_id);
288 for parse_attr_result in parse_attrs(attrs) {
289 let (span, parsed_attr) = match parse_attr_result {
290 Ok(span_and_parsed_attr) => span_and_parsed_attr,
291 Err((span, msg)) => {
292 self.tcx.dcx().span_err(span, msg);
293 continue;
294 }
295 };
296
297 struct Expected<T>(T);
299
300 let valid_target = match parsed_attr {
301 SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
302 Target::Struct => {
303 Ok(())
306 }
307
308 _ => Err(Expected("struct")),
309 },
310
311 SpirvAttribute::Entry(_) => match target {
312 Target::Fn
313 | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
314 Ok(())
317 }
318
319 _ => Err(Expected("function")),
320 },
321
322 SpirvAttribute::StorageClass(_)
323 | SpirvAttribute::Builtin(_)
324 | SpirvAttribute::DescriptorSet(_)
325 | SpirvAttribute::Binding(_)
326 | SpirvAttribute::Flat
327 | SpirvAttribute::Invariant
328 | SpirvAttribute::PerPrimitiveExt
329 | SpirvAttribute::InputAttachmentIndex(_)
330 | SpirvAttribute::SpecConstant(_) => match target {
331 Target::Param => {
332 let parent_hir_id = self.tcx.parent_hir_id(hir_id);
333 let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
334 .filter_map(|r| r.ok())
335 .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
336 if !parent_is_entry_point {
337 self.tcx.dcx().span_err(
338 span,
339 "attribute is only valid on a parameter of an entry-point function",
340 );
341 } else {
342 if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
345 let valid = match storage_class {
346 StorageClass::Input | StorageClass::Output => {
347 Err("is the default and should not be explicitly specified")
348 }
349
350 StorageClass::Private
351 | StorageClass::Function
352 | StorageClass::Generic => {
353 Err("can not be used as part of an entry's interface")
354 }
355
356 _ => Ok(()),
357 };
358
359 if let Err(msg) = valid {
360 self.tcx.dcx().span_err(
361 span,
362 format!("`{storage_class:?}` storage class {msg}"),
363 );
364 }
365 }
366 }
367 Ok(())
368 }
369
370 _ => Err(Expected("function parameter")),
371 },
372 SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
373 match target {
374 Target::Fn => Ok(()),
375 _ => Err(Expected("function")),
376 }
377 }
378 };
379 match valid_target {
380 Err(Expected(expected_target)) => {
381 self.tcx.dcx().span_err(
382 span,
383 format!(
384 "attribute is only valid on a {expected_target}, not on a {target}"
385 ),
386 );
387 }
388 Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
389 Ok(()) => {}
390 Err(MultipleAttrs {
391 prev_span,
392 category,
393 }) => {
394 self.tcx
395 .dcx()
396 .struct_span_err(
397 span,
398 format!("only one {category} attribute is allowed on a {target}"),
399 )
400 .with_span_note(prev_span, format!("previous {category} attribute"))
401 .emit();
402 }
403 },
404 }
405 }
406
407 if let Some(block_attr) = aggregated_attrs.block {
411 self.tcx.dcx().span_warn(
412 block_attr.span,
413 "#[spirv(block)] is no longer needed and should be removed",
414 );
415 }
416 }
417}
418
419impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
421 type NestedFilter = nested_filter::OnlyBodies;
422
423 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
424 self.tcx
425 }
426
427 fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
428 let target = Target::from_item(item);
429 self.check_spirv_attributes(item.hir_id(), target);
430 intravisit::walk_item(self, item);
431 }
432
433 fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
434 let target = Target::from_generic_param(generic_param);
435 self.check_spirv_attributes(generic_param.hir_id, target);
436 intravisit::walk_generic_param(self, generic_param);
437 }
438
439 fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
440 let target = Target::from_trait_item(trait_item);
441 self.check_spirv_attributes(trait_item.hir_id(), target);
442 intravisit::walk_trait_item(self, trait_item);
443 }
444
445 fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
446 self.check_spirv_attributes(field.hir_id, Target::Field);
447 intravisit::walk_field_def(self, field);
448 }
449
450 fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
451 self.check_spirv_attributes(arm.hir_id, Target::Arm);
452 intravisit::walk_arm(self, arm);
453 }
454
455 fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
456 let target = Target::from_foreign_item(f_item);
457 self.check_spirv_attributes(f_item.hir_id(), target);
458 intravisit::walk_foreign_item(self, f_item);
459 }
460
461 fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
462 let target = target_from_impl_item(self.tcx, impl_item);
463 self.check_spirv_attributes(impl_item.hir_id(), target);
464 intravisit::walk_impl_item(self, impl_item);
465 }
466
467 fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
468 if let hir::StmtKind::Let(l) = stmt.kind {
470 self.check_spirv_attributes(l.hir_id, Target::Statement);
471 }
472 intravisit::walk_stmt(self, stmt);
473 }
474
475 fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
476 let target = match expr.kind {
477 hir::ExprKind::Closure { .. } => Target::Closure,
478 _ => Target::Expression,
479 };
480
481 self.check_spirv_attributes(expr.hir_id, target);
482 intravisit::walk_expr(self, expr);
483 }
484
485 fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
486 self.check_spirv_attributes(variant.hir_id, Target::Variant);
487 intravisit::walk_variant(self, variant);
488 }
489
490 fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
491 self.check_spirv_attributes(param.hir_id, Target::Param);
492
493 intravisit::walk_param(self, param);
494 }
495}
496
497fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
499 let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
500 tcx,
501 sym: Symbols::get(),
502 };
503 tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
504 if module_def_id.is_top_level_module() {
505 check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
506 }
507}
508
509pub(crate) fn provide(providers: &mut Providers) {
510 *providers = Providers {
511 check_mod_attrs: |tcx, module_def_id| {
512 (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, module_def_id);
514 check_mod_attrs(tcx, module_def_id);
515 },
516 ..*providers
517 };
518}
519
520type ParseAttrError = (Span, String);
522
523#[allow(clippy::get_first)]
524fn parse_attrs_for_checking<'a>(
525 sym: &'a Symbols,
526 attrs: &'a [Attribute],
527) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
528 attrs
529 .iter()
530 .map(move |attr| {
531 match attr {
533 Attribute::Unparsed(item) => {
534 let s = &item.path.segments;
536 if let Some(rust_gpu) = s.get(0) && rust_gpu.name == sym.rust_gpu {
537 match s.get(1) {
539 Some(command) if command.name == sym.spirv_attr_with_version => {
540 if let Some(args) = attr.meta_item_list() {
542 Ok(parse_spirv_attr(sym, args.iter()))
544 } else {
545 Err((
547 attr.span(),
548 "#[spirv(..)] attribute must have at least one argument"
549 .to_string(),
550 ))
551 }
552 }
553 Some(command) if command.name == sym.vector => {
554 match s.get(2) {
556 Some(version) if version.name == sym.v1 => {
558 Ok(SmallVec::from_iter([
559 Ok((attr.span(), SpirvAttribute::IntrinsicType(IntrinsicType::Vector)))
560 ]))
561 },
562 _ => Err((
563 attr.span(),
564 "unknown `rust_gpu::vector` version, expected `rust_gpu::vector::v1`"
565 .to_string(),
566 )),
567 }
568 }
569 _ => {
570 let spirv = sym.spirv_attr_with_version.as_str();
572 Err((
573 attr.span(),
574 format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
575 Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
576 ))
577 }
578 }
579 } else {
580 Ok(Default::default())
582 }
583 }
584 Attribute::Parsed(_) => Ok(Default::default()),
585 }
586 })
587 .flat_map(|result| {
588 result
589 .unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
590 .into_iter()
591 })
592}
593
594fn parse_spirv_attr<'a>(
595 sym: &Symbols,
596 iter: impl Iterator<Item = &'a MetaItemInner>,
597) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
598 iter.map(|arg| {
599 let span = arg.span();
600 let parsed_attr =
601 if arg.has_name(sym.descriptor_set) {
602 SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
603 } else if arg.has_name(sym.binding) {
604 SpirvAttribute::Binding(parse_attr_int_value(arg)?)
605 } else if arg.has_name(sym.input_attachment_index) {
606 SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
607 } else if arg.has_name(sym.spec_constant) {
608 SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
609 } else {
610 let name = match arg.ident() {
611 Some(i) => i,
612 None => {
613 return Err((
614 span,
615 "#[spirv(..)] attribute argument must be single identifier".to_string(),
616 ));
617 }
618 };
619 sym.attributes.get(&name.name).map_or_else(
620 || Err((name.span, "unknown argument to spirv attribute".to_string())),
621 |a| {
622 Ok(match a {
623 SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
624 parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
625 ),
626 _ => a.clone(),
627 })
628 },
629 )?
630 };
631 Ok((span, parsed_attr))
632 })
633 .collect()
634}
635
636fn parse_spec_constant_attr(
637 sym: &Symbols,
638 arg: &MetaItemInner,
639) -> Result<SpecConstant, ParseAttrError> {
640 let mut id = None;
641 let mut default = None;
642
643 if let Some(attrs) = arg.meta_item_list() {
644 for attr in attrs {
645 if attr.has_name(sym.id) {
646 if id.is_none() {
647 id = Some(parse_attr_int_value(attr)?);
648 } else {
649 return Err((attr.span(), "`id` may only be specified once".into()));
650 }
651 } else if attr.has_name(sym.default) {
652 if default.is_none() {
653 default = Some(parse_attr_int_value(attr)?);
654 } else {
655 return Err((attr.span(), "`default` may only be specified once".into()));
656 }
657 } else {
658 return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
659 }
660 }
661 }
662 Ok(SpecConstant {
663 id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
664 default,
665 array_count: None,
667 })
668}
669
670fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
671 let arg = match arg.meta_item() {
672 Some(arg) => arg,
673 None => return Err((arg.span(), "attribute must have value".to_string())),
674 };
675 match arg.name_value_literal() {
676 Some(&MetaItemLit {
677 kind: LitKind::Int(x, ..),
678 ..
679 }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
680 _ => Err((arg.span, "attribute value must be integer".to_string())),
681 }
682}
683
684fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
685 let arg = match arg.meta_item() {
686 Some(arg) => arg,
687 None => return Err((arg.span(), "attribute must have value".to_string())),
688 };
689 match arg.meta_item_list() {
690 Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
691 let mut local_size = [1; 3];
692 for (idx, lit) in tuple.iter().enumerate() {
693 match lit {
694 MetaItemInner::Lit(MetaItemLit {
695 kind: LitKind::Int(x, ..),
696 ..
697 }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
698 _ => return Err((lit.span(), "must be a u32 literal".to_string())),
699 }
700 }
701 Ok(local_size)
702 }
703 Some([]) => Err((
704 arg.span,
705 "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
706 )),
707 Some(tuple) if tuple.len() > 3 => Err((
708 arg.span,
709 "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
710 )),
711 _ => Err((
712 arg.span,
713 "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
714 )),
715 }
716}
717
718fn parse_entry_attrs(
723 sym: &Symbols,
724 arg: &MetaItemInner,
725 name: &Ident,
726 execution_model: ExecutionModel,
727) -> Result<Entry, ParseAttrError> {
728 use ExecutionMode::*;
729 use ExecutionModel::*;
730 let mut entry = Entry::from(execution_model);
731 let mut origin_mode: Option<ExecutionMode> = None;
732 let mut local_size: Option<[u32; 3]> = None;
733 let mut local_size_hint: Option<[u32; 3]> = None;
734 if let Some(attrs) = arg.meta_item_list() {
737 for attr in attrs {
738 if let Some(attr_name) = attr.ident() {
739 if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
740 {
741 use crate::symbols::ExecutionModeExtraDim::*;
742 let val = match extra_dim {
743 None | Tuple => Option::None,
744 _ => Some(parse_attr_int_value(attr)?),
745 };
746 match execution_mode {
747 OriginUpperLeft | OriginLowerLeft => {
748 origin_mode.replace(*execution_mode);
749 }
750 LocalSize => {
751 if local_size.is_none() {
752 local_size.replace(parse_local_size_attr(attr)?);
753 } else {
754 return Err((
755 attr_name.span,
756 String::from(
757 "`#[spirv(compute(threads))]` may only be specified once",
758 ),
759 ));
760 }
761 }
762 LocalSizeHint => {
763 let val = val.unwrap();
764 if local_size_hint.is_none() {
765 local_size_hint.replace([1, 1, 1]);
766 }
767 let local_size_hint = local_size_hint.as_mut().unwrap();
768 match extra_dim {
769 X => {
770 local_size_hint[0] = val;
771 }
772 Y => {
773 local_size_hint[1] = val;
774 }
775 Z => {
776 local_size_hint[2] = val;
777 }
778 _ => unreachable!(),
779 }
780 }
781 _ => {
803 if let Some(val) = val {
804 entry
805 .execution_modes
806 .push((*execution_mode, ExecutionModeExtra::new([val])));
807 } else {
808 entry
809 .execution_modes
810 .push((*execution_mode, ExecutionModeExtra::new([])));
811 }
812 }
813 }
814 } else if attr_name.name == sym.entry_point_name {
815 match attr.value_str() {
816 Some(sym) => {
817 entry.name = Some(sym);
818 }
819 None => {
820 return Err((
821 attr_name.span,
822 format!(
823 "#[spirv({name}(..))] unknown attribute argument {attr_name}"
824 ),
825 ));
826 }
827 }
828 } else {
829 return Err((
830 attr_name.span,
831 format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
832 ));
833 }
834 } else {
835 return Err((
836 arg.span(),
837 format!("#[spirv({name}(..))] attribute argument must be single identifier"),
838 ));
839 }
840 }
841 }
842 match entry.execution_model {
843 Fragment => {
844 let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
845 entry
846 .execution_modes
847 .push((origin_mode, ExecutionModeExtra::new([])));
848 }
849 GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
850 if let Some(local_size) = local_size {
851 entry
852 .execution_modes
853 .push((LocalSize, ExecutionModeExtra::new(local_size)));
854 } else {
855 return Err((
856 arg.span(),
857 String::from(
858 "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
859 ),
860 ));
861 }
862 }
863 _ => {}
865 }
866 Ok(entry)
867}