rustc_codegen_spirv/
attr.rs

1//! `#[spirv(...)]` attribute support.
2//!
3//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.
4
5use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
9use rustc_hir as hir;
10use rustc_hir::def_id::LocalModDefId;
11use rustc_hir::intravisit::{self, Visitor};
12use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
13use rustc_middle::hir::nested_filter;
14use rustc_middle::query::Providers;
15use rustc_middle::ty::TyCtxt;
16use rustc_span::{Ident, Span, Symbol};
17use smallvec::SmallVec;
18use std::rc::Rc;
19
// FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
/// Up to three extra `u32` operands attached to an `ExecutionMode`, stored
/// inline (no heap allocation).
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    /// Operand storage; only the first `len` entries are meaningful.
    args: [u32; 3],
    /// Number of valid entries in `args` (never more than 3).
    len: u8,
}

impl ExecutionModeExtra {
    /// Builds an `ExecutionModeExtra` holding a copy of `args`.
    ///
    /// # Panics
    ///
    /// Panics if `args` contains more than three operands.
    #[track_caller]
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let operands = args.as_ref();
        let mut storage = [0; 3];
        // Fail with a descriptive message instead of an opaque slice-index panic.
        assert!(
            operands.len() <= storage.len(),
            "ExecutionModeExtra holds at most 3 operands, got {}",
            operands.len()
        );
        storage[..operands.len()].copy_from_slice(operands);
        // Cannot truncate: the length was just checked to be <= 3.
        let len = operands.len() as u8;
        Self { args: storage, len }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were actually provided to `new`.
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
42
43#[derive(Clone, Debug)]
44pub struct Entry {
45    pub execution_model: ExecutionModel,
46    pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
47    pub name: Option<Symbol>,
48}
49
50impl From<ExecutionModel> for Entry {
51    fn from(execution_model: ExecutionModel) -> Self {
52        Self {
53            execution_model,
54            execution_modes: Vec::new(),
55            name: None,
56        }
57    }
58}
59
/// `struct` types that are used to represent special SPIR-V types.
///
/// Carried by `SpirvAttribute::IntrinsicType`, which the attribute check pass
/// only accepts on `struct`s. `Vector` is produced by `#[rust_gpu::vector::v1]`.
#[derive(Debug, Clone)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
    Vector,
}
73
/// Parsed form of a `spec_constant(id = ..., default = ...)` attribute argument.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// The required `id = ...` value.
    pub id: u32,
    /// The optional `default = ...` value.
    pub default: Option<u32>,
    /// Always `None` straight out of attribute parsing; set later by other code
    /// (presumably for array-typed spec constants — confirm at the use site).
    pub array_count: Option<u32>,
}
80
// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
// `tests/ui/spirv-attr` should be updated (and new ones added if necessary).
/// One parsed `#[spirv(...)]` argument (or a `#[rust_gpu::vector::v1]` attribute).
///
/// Which kind of item each variant may be placed on is enforced by the attribute
/// check pass. The `u32` payloads come from `name = <integer literal>` arguments.
#[derive(Debug, Clone)]
pub enum SpirvAttribute {
    // `struct` attributes:
    IntrinsicType(IntrinsicType),
    /// Still accepted, but the check pass warns that it is no longer needed.
    Block,

    // `fn` attributes:
    /// Valid on free functions and on inherent or trait-impl methods.
    Entry(Entry),

    // (entry) `fn` parameter attributes:
    // The check pass rejects these unless the parameter's parent has an `Entry`.
    StorageClass(StorageClass),
    Builtin(BuiltIn),
    DescriptorSet(u32),
    Binding(u32),
    Location(u32),
    Flat,
    PerPrimitiveExt,
    Invariant,
    InputAttachmentIndex(u32),
    SpecConstant(SpecConstant),

    // `fn`/closure attributes:
    // NOTE(review): the check pass currently accepts these only on `Target::Fn`,
    // not on closures — confirm whether this comment or the check is stale.
    BufferLoadIntrinsic,
    BufferStoreIntrinsic,
}
108
// HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the
// field name instead of `node` (which feels inadequate in this context).
/// A value paired with the source span of the attribute it came from.
#[derive(Copy, Clone)]
pub struct Spanned<T> {
    /// The parsed value.
    pub value: T,
    /// Where the value was written in the source.
    pub span: Span,
}
116
/// Condensed version of a `SpirvAttribute` list, but only keeping one value per
/// variant of `SpirvAttribute`, and treating multiple such attributes as an error.
///
/// Each field is `None` when the corresponding attribute is absent, and otherwise
/// holds its value together with the span of the attribute.
// FIXME(eddyb) should this and `fn try_insert_attr` below be generated by a macro?
#[derive(Default)]
pub struct AggregatedSpirvAttributes {
    // `struct` attributes:
    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
    pub block: Option<Spanned<()>>,

    // `fn` attributes:
    pub entry: Option<Spanned<Entry>>,

    // (entry) `fn` parameter attributes:
    pub storage_class: Option<Spanned<StorageClass>>,
    pub builtin: Option<Spanned<BuiltIn>>,
    pub descriptor_set: Option<Spanned<u32>>,
    pub binding: Option<Spanned<u32>>,
    pub location: Option<Spanned<u32>>,
    pub flat: Option<Spanned<()>>,
    pub invariant: Option<Spanned<()>>,
    pub per_primitive_ext: Option<Spanned<()>>,
    pub input_attachment_index: Option<Spanned<u32>>,
    pub spec_constant: Option<Spanned<SpecConstant>>,

    // `fn`/closure attributes:
    pub buffer_load_intrinsic: Option<Spanned<()>>,
    pub buffer_store_intrinsic: Option<Spanned<()>>,
}
145
/// Error returned by `AggregatedSpirvAttributes::try_insert_attr` when an
/// attribute of the same kind has already been recorded.
struct MultipleAttrs {
    /// Span of the earlier attribute of the same kind.
    prev_span: Span,
    /// Human-readable name of the attribute kind, used in diagnostics.
    category: &'static str,
}
150
151impl AggregatedSpirvAttributes {
152    /// Compute `AggregatedSpirvAttributes` for use during codegen.
153    ///
154    /// Any errors for malformed/duplicate attributes will have been reported
155    /// prior to codegen, by the `attr` check pass.
156    pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
157        let mut aggregated_attrs = Self::default();
158
159        // NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
160        // to see an attribute error, it will cause an ICE instead.
161        for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
162            let (span, parsed_attr) = match parse_attr_result {
163                Ok(span_and_parsed_attr) => span_and_parsed_attr,
164                Err((span, msg)) => {
165                    cx.tcx.dcx().span_delayed_bug(span, msg);
166                    continue;
167                }
168            };
169            match aggregated_attrs.try_insert_attr(parsed_attr, span) {
170                Ok(()) => {}
171                Err(MultipleAttrs {
172                    prev_span: _,
173                    category,
174                }) => {
175                    cx.tcx
176                        .dcx()
177                        .span_delayed_bug(span, format!("multiple {category} attributes"));
178                }
179            }
180        }
181
182        aggregated_attrs
183    }
184
185    fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
186        fn try_insert<T>(
187            slot: &mut Option<Spanned<T>>,
188            value: T,
189            span: Span,
190            category: &'static str,
191        ) -> Result<(), MultipleAttrs> {
192            if let Some(prev) = slot {
193                Err(MultipleAttrs {
194                    prev_span: prev.span,
195                    category,
196                })
197            } else {
198                *slot = Some(Spanned { value, span });
199                Ok(())
200            }
201        }
202
203        use SpirvAttribute::*;
204        match attr {
205            IntrinsicType(value) => {
206                try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
207            }
208            Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
209            Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
210            StorageClass(value) => {
211                try_insert(&mut self.storage_class, value, span, "storage class")
212            }
213            Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
214            DescriptorSet(value) => try_insert(
215                &mut self.descriptor_set,
216                value,
217                span,
218                "#[spirv(descriptor_set)]",
219            ),
220            Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
221            Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
222            Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
223            Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
224            PerPrimitiveExt => try_insert(
225                &mut self.per_primitive_ext,
226                (),
227                span,
228                "#[spirv(per_primitive_ext)]",
229            ),
230            InputAttachmentIndex(value) => try_insert(
231                &mut self.input_attachment_index,
232                value,
233                span,
234                "#[spirv(attachment_index)]",
235            ),
236            SpecConstant(value) => try_insert(
237                &mut self.spec_constant,
238                value,
239                span,
240                "#[spirv(spec_constant)]",
241            ),
242            BufferLoadIntrinsic => try_insert(
243                &mut self.buffer_load_intrinsic,
244                (),
245                span,
246                "#[spirv(buffer_load_intrinsic)]",
247            ),
248            BufferStoreIntrinsic => try_insert(
249                &mut self.buffer_store_intrinsic,
250                (),
251                span,
252                "#[spirv(buffer_store_intrinsic)]",
253            ),
254        }
255    }
256}
257
258// FIXME(eddyb) make this reusable from somewhere in `rustc`.
259fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
260    match impl_item.kind {
261        hir::ImplItemKind::Const(..) => Target::AssocConst,
262        hir::ImplItemKind::Fn(..) => {
263            let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
264            let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
265            let containing_impl_is_for_trait = match &containing_item.kind {
266                hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
267                _ => unreachable!("parent of an ImplItem must be an Impl"),
268            };
269            if containing_impl_is_for_trait {
270                Target::Method(MethodKind::Trait { body: true })
271            } else {
272                Target::Method(MethodKind::Inherent)
273            }
274        }
275        hir::ImplItemKind::Type(..) => Target::AssocTy,
276    }
277}
278
/// HIR visitor that validates the `#[spirv(...)]` attributes of the nodes it visits.
struct CheckSpirvAttrVisitor<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Symbol table used when parsing `rust_gpu`/`#[spirv(...)]` attributes.
    sym: Rc<Symbols>,
}
283
impl CheckSpirvAttrVisitor<'_> {
    /// Parses and validates every `#[spirv(...)]` attribute on `hir_id`.
    ///
    /// Emits errors for malformed attributes, for attributes placed on the wrong
    /// kind of `target`, for parameter attributes outside an entry point, for
    /// disallowed storage classes, and for duplicates. Also warns about the
    /// no-longer-needed `#[spirv(block)]`.
    fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
        // Only attributes valid for `target` are aggregated, to detect duplicates.
        let mut aggregated_attrs = AggregatedSpirvAttributes::default();

        // Also reused below to inspect the attributes of a parameter's parent.
        let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);

        let attrs = self.tcx.hir_attrs(hir_id);
        for parse_attr_result in parse_attrs(attrs) {
            let (span, parsed_attr) = match parse_attr_result {
                Ok(span_and_parsed_attr) => span_and_parsed_attr,
                Err((span, msg)) => {
                    self.tcx.dcx().span_err(span, msg);
                    continue;
                }
            };

            /// Error newtype marker used below for readability.
            struct Expected<T>(T);

            let valid_target = match parsed_attr {
                SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
                    Target::Struct => {
                        // FIXME(eddyb) further check type attribute validity,
                        // e.g. layout, generics, other attributes, etc.
                        Ok(())
                    }

                    _ => Err(Expected("struct")),
                },

                SpirvAttribute::Entry(_) => match target {
                    Target::Fn
                    | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
                        // FIXME(eddyb) further check entry-point attribute validity,
                        // e.g. signature, shouldn't have `#[inline]` or generics, etc.
                        Ok(())
                    }

                    _ => Err(Expected("function")),
                },

                SpirvAttribute::StorageClass(_)
                | SpirvAttribute::Builtin(_)
                | SpirvAttribute::DescriptorSet(_)
                | SpirvAttribute::Binding(_)
                | SpirvAttribute::Location(_)
                | SpirvAttribute::Flat
                | SpirvAttribute::Invariant
                | SpirvAttribute::PerPrimitiveExt
                | SpirvAttribute::InputAttachmentIndex(_)
                | SpirvAttribute::SpecConstant(_) => match target {
                    Target::Param => {
                        // The parameter's HIR parent (presumably the function that
                        // owns it) must itself carry an entry-point attribute.
                        let parent_hir_id = self.tcx.parent_hir_id(hir_id);
                        let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
                            .filter_map(|r| r.ok())
                            .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
                        if !parent_is_entry_point {
                            self.tcx.dcx().span_err(
                                span,
                                "attribute is only valid on a parameter of an entry-point function",
                            );
                        } else {
                            // FIXME(eddyb) should we just remove all 5 of these storage class
                            // attributes, instead of disallowing them here?
                            if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
                                let valid = match storage_class {
                                    StorageClass::Input | StorageClass::Output => {
                                        Err("is the default and should not be explicitly specified")
                                    }

                                    StorageClass::Private
                                    | StorageClass::Function
                                    | StorageClass::Generic => {
                                        Err("can not be used as part of an entry's interface")
                                    }

                                    _ => Ok(()),
                                };

                                if let Err(msg) = valid {
                                    self.tcx.dcx().span_err(
                                        span,
                                        format!("`{storage_class:?}` storage class {msg}"),
                                    );
                                }
                            }
                        }
                        // Any parameter-specific error was already reported above; the
                        // attribute is still on the right kind of node, so it goes on to
                        // duplicate detection.
                        Ok(())
                    }

                    _ => Err(Expected("function parameter")),
                },
                SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
                    match target {
                        Target::Fn => Ok(()),
                        _ => Err(Expected("function")),
                    }
                }
            };
            match valid_target {
                Err(Expected(expected_target)) => {
                    self.tcx.dcx().span_err(
                        span,
                        format!(
                            "attribute is only valid on a {expected_target}, not on a {target}"
                        ),
                    );
                }
                // Only correctly-placed attributes take part in duplicate detection.
                Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
                    Ok(()) => {}
                    Err(MultipleAttrs {
                        prev_span,
                        category,
                    }) => {
                        self.tcx
                            .dcx()
                            .struct_span_err(
                                span,
                                format!("only one {category} attribute is allowed on a {target}"),
                            )
                            .with_span_note(prev_span, format!("previous {category} attribute"))
                            .emit();
                    }
                },
            }
        }

        // At this point we have all of the attributes (valid for this target),
        // so we can perform further checks, emit warnings, etc.

        if let Some(block_attr) = aggregated_attrs.block {
            self.tcx.dcx().span_warn(
                block_attr.span,
                "#[spirv(block)] is no longer needed and should be removed",
            );
        }
    }
}
422
423// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
424impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
425    type NestedFilter = nested_filter::OnlyBodies;
426
427    fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
428        self.tcx
429    }
430
431    fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
432        let target = Target::from_item(item);
433        self.check_spirv_attributes(item.hir_id(), target);
434        intravisit::walk_item(self, item);
435    }
436
437    fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
438        let target = Target::from_generic_param(generic_param);
439        self.check_spirv_attributes(generic_param.hir_id, target);
440        intravisit::walk_generic_param(self, generic_param);
441    }
442
443    fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
444        let target = Target::from_trait_item(trait_item);
445        self.check_spirv_attributes(trait_item.hir_id(), target);
446        intravisit::walk_trait_item(self, trait_item);
447    }
448
449    fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
450        self.check_spirv_attributes(field.hir_id, Target::Field);
451        intravisit::walk_field_def(self, field);
452    }
453
454    fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
455        self.check_spirv_attributes(arm.hir_id, Target::Arm);
456        intravisit::walk_arm(self, arm);
457    }
458
459    fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
460        let target = Target::from_foreign_item(f_item);
461        self.check_spirv_attributes(f_item.hir_id(), target);
462        intravisit::walk_foreign_item(self, f_item);
463    }
464
465    fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
466        let target = target_from_impl_item(self.tcx, impl_item);
467        self.check_spirv_attributes(impl_item.hir_id(), target);
468        intravisit::walk_impl_item(self, impl_item);
469    }
470
471    fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
472        // When checking statements ignore expressions, they will be checked later.
473        if let hir::StmtKind::Let(l) = stmt.kind {
474            self.check_spirv_attributes(l.hir_id, Target::Statement);
475        }
476        intravisit::walk_stmt(self, stmt);
477    }
478
479    fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
480        let target = match expr.kind {
481            hir::ExprKind::Closure { .. } => Target::Closure,
482            _ => Target::Expression,
483        };
484
485        self.check_spirv_attributes(expr.hir_id, target);
486        intravisit::walk_expr(self, expr);
487    }
488
489    fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
490        self.check_spirv_attributes(variant.hir_id, Target::Variant);
491        intravisit::walk_variant(self, variant);
492    }
493
494    fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
495        self.check_spirv_attributes(param.hir_id, Target::Param);
496
497        intravisit::walk_param(self, param);
498    }
499}
500
501// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
502fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
503    let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
504        tcx,
505        sym: Symbols::get(),
506    };
507    tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
508    if module_def_id.is_top_level_module() {
509        check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
510    }
511}
512
513pub(crate) fn provide(providers: &mut Providers) {
514    *providers = Providers {
515        check_mod_attrs: |tcx, module_def_id| {
516            // Run both the default checks, and our `#[spirv(...)]` ones.
517            (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, module_def_id);
518            check_mod_attrs(tcx, module_def_id);
519        },
520        ..*providers
521    };
522}
523
// FIXME(eddyb) find something nicer for the error type.
/// An attribute parse error: the span to report at, and the error message.
type ParseAttrError = (Span, String);
526
527#[allow(clippy::get_first)]
528fn parse_attrs_for_checking<'a>(
529    sym: &'a Symbols,
530    attrs: &'a [Attribute],
531) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
532    attrs
533        .iter()
534        .map(move |attr| {
535            // parse the #[rust_gpu::spirv(...)] attr and return the inner list
536            match attr {
537                Attribute::Unparsed(item) => {
538                    // #[...]
539                    let s = &item.path.segments;
540                    if let Some(rust_gpu) = s.get(0) && rust_gpu.name == sym.rust_gpu {
541                        // #[rust_gpu ...]
542                        match s.get(1) {
543                            Some(command) if command.name == sym.spirv_attr_with_version => {
544                                // #[rust_gpu::spirv ...]
545                                if let Some(args) = attr.meta_item_list() {
546                                    // #[rust_gpu::spirv(...)]
547                                    Ok(parse_spirv_attr(sym, args.iter()))
548                                } else {
549                                    // #[rust_gpu::spirv]
550                                    Err((
551                                        attr.span(),
552                                        "#[spirv(..)] attribute must have at least one argument"
553                                            .to_string(),
554                                    ))
555                                }
556                            }
557                            Some(command) if command.name == sym.vector => {
558                                // #[rust_gpu::vector ...]
559                                match s.get(2) {
560                                    // #[rust_gpu::vector::v1]
561                                    Some(version) if version.name == sym.v1 => {
562                                        Ok(SmallVec::from_iter([
563                                            Ok((attr.span(), SpirvAttribute::IntrinsicType(IntrinsicType::Vector)))
564                                        ]))
565                                    },
566                                    _ => Err((
567                                        attr.span(),
568                                        "unknown `rust_gpu::vector` version, expected `rust_gpu::vector::v1`"
569                                            .to_string(),
570                                    )),
571                                }
572                            }
573                            _ => {
574                                // #[rust_gpu::...] but not a know version
575                                let spirv = sym.spirv_attr_with_version.as_str();
576                                Err((
577                                    attr.span(),
578                                    format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
579                                Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
580                                ))
581                            }
582                        }
583                    } else {
584                        // #[...] but not #[rust_gpu ...]
585                        Ok(Default::default())
586                    }
587                }
588                Attribute::Parsed(_) => Ok(Default::default()),
589            }
590        })
591        .flat_map(|result| {
592            result
593                .unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
594                .into_iter()
595        })
596}
597
598fn parse_spirv_attr<'a>(
599    sym: &Symbols,
600    iter: impl Iterator<Item = &'a MetaItemInner>,
601) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
602    iter.map(|arg| {
603        let span = arg.span();
604        let parsed_attr =
605            if arg.has_name(sym.descriptor_set) {
606                SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
607            } else if arg.has_name(sym.binding) {
608                SpirvAttribute::Binding(parse_attr_int_value(arg)?)
609            } else if arg.has_name(sym.location) {
610                SpirvAttribute::Location(parse_attr_int_value(arg)?)
611            } else if arg.has_name(sym.input_attachment_index) {
612                SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
613            } else if arg.has_name(sym.spec_constant) {
614                SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
615            } else {
616                let name = match arg.ident() {
617                    Some(i) => i,
618                    None => {
619                        return Err((
620                            span,
621                            "#[spirv(..)] attribute argument must be single identifier".to_string(),
622                        ));
623                    }
624                };
625                sym.attributes.get(&name.name).map_or_else(
626                    || Err((name.span, "unknown argument to spirv attribute".to_string())),
627                    |a| {
628                        Ok(match a {
629                            SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
630                                parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
631                            ),
632                            _ => a.clone(),
633                        })
634                    },
635                )?
636            };
637        Ok((span, parsed_attr))
638    })
639    .collect()
640}
641
642fn parse_spec_constant_attr(
643    sym: &Symbols,
644    arg: &MetaItemInner,
645) -> Result<SpecConstant, ParseAttrError> {
646    let mut id = None;
647    let mut default = None;
648
649    if let Some(attrs) = arg.meta_item_list() {
650        for attr in attrs {
651            if attr.has_name(sym.id) {
652                if id.is_none() {
653                    id = Some(parse_attr_int_value(attr)?);
654                } else {
655                    return Err((attr.span(), "`id` may only be specified once".into()));
656                }
657            } else if attr.has_name(sym.default) {
658                if default.is_none() {
659                    default = Some(parse_attr_int_value(attr)?);
660                } else {
661                    return Err((attr.span(), "`default` may only be specified once".into()));
662                }
663            } else {
664                return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
665            }
666        }
667    }
668    Ok(SpecConstant {
669        id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
670        default,
671        // to be set later
672        array_count: None,
673    })
674}
675
676fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
677    let arg = match arg.meta_item() {
678        Some(arg) => arg,
679        None => return Err((arg.span(), "attribute must have value".to_string())),
680    };
681    match arg.name_value_literal() {
682        Some(&MetaItemLit {
683            kind: LitKind::Int(x, ..),
684            ..
685        }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
686        _ => Err((arg.span, "attribute value must be integer".to_string())),
687    }
688}
689
690fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
691    let arg = match arg.meta_item() {
692        Some(arg) => arg,
693        None => return Err((arg.span(), "attribute must have value".to_string())),
694    };
695    match arg.meta_item_list() {
696        Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
697            let mut local_size = [1; 3];
698            for (idx, lit) in tuple.iter().enumerate() {
699                match lit {
700                    MetaItemInner::Lit(MetaItemLit {
701                                           kind: LitKind::Int(x, ..),
702                                           ..
703                                       }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
704                    _ => return Err((lit.span(), "must be a u32 literal".to_string())),
705                }
706            }
707            Ok(local_size)
708        }
709        Some([]) => Err((
710            arg.span,
711            "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
712        )),
713        Some(tuple) if tuple.len() > 3 => Err((
714            arg.span,
715            "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
716        )),
717        _ => Err((
718            arg.span,
719            "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
720        )),
721    }
722}
723
724// for a given entry, gather up the additional attributes
725// in this case ExecutionMode's, some have extra arguments
726// others are specified with x, y, or z components
727// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))]
728fn parse_entry_attrs(
729    sym: &Symbols,
730    arg: &MetaItemInner,
731    name: &Ident,
732    execution_model: ExecutionModel,
733) -> Result<Entry, ParseAttrError> {
734    use ExecutionMode::*;
735    use ExecutionModel::*;
736    let mut entry = Entry::from(execution_model);
737    let mut origin_mode: Option<ExecutionMode> = None;
738    let mut local_size: Option<[u32; 3]> = None;
739    let mut local_size_hint: Option<[u32; 3]> = None;
740    // Reserved
741    //let mut max_workgroup_size_intel: Option<[u32; 3]> = None;
742    if let Some(attrs) = arg.meta_item_list() {
743        for attr in attrs {
744            if let Some(attr_name) = attr.ident() {
745                if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
746                {
747                    use crate::symbols::ExecutionModeExtraDim::*;
748                    let val = match extra_dim {
749                        None | Tuple => Option::None,
750                        _ => Some(parse_attr_int_value(attr)?),
751                    };
752                    match execution_mode {
753                        OriginUpperLeft | OriginLowerLeft => {
754                            origin_mode.replace(*execution_mode);
755                        }
756                        LocalSize => {
757                            if local_size.is_none() {
758                                local_size.replace(parse_local_size_attr(attr)?);
759                            } else {
760                                return Err((
761                                    attr_name.span,
762                                    String::from(
763                                        "`#[spirv(compute(threads))]` may only be specified once",
764                                    ),
765                                ));
766                            }
767                        }
768                        LocalSizeHint => {
769                            let val = val.unwrap();
770                            if local_size_hint.is_none() {
771                                local_size_hint.replace([1, 1, 1]);
772                            }
773                            let local_size_hint = local_size_hint.as_mut().unwrap();
774                            match extra_dim {
775                                X => {
776                                    local_size_hint[0] = val;
777                                }
778                                Y => {
779                                    local_size_hint[1] = val;
780                                }
781                                Z => {
782                                    local_size_hint[2] = val;
783                                }
784                                _ => unreachable!(),
785                            }
786                        }
787                        // Reserved
788                        /*MaxWorkgroupSizeINTEL => {
789                            let val = val.unwrap();
790                            if max_workgroup_size_intel.is_none() {
791                                max_workgroup_size_intel.replace([1, 1, 1]);
792                            }
793                            let max_workgroup_size_intel = max_workgroup_size_intel.as_mut()
794                                .unwrap();
795                            match extra_dim {
796                                X => {
797                                    max_workgroup_size_intel[0] = val;
798                                },
799                                Y => {
800                                    max_workgroup_size_intel[1] = val;
801                                },
802                                Z => {
803                                    max_workgroup_size_intel[2] = val;
804                                },
805                                _ => unreachable!(),
806                            }
807                        },*/
808                        _ => {
809                            if let Some(val) = val {
810                                entry
811                                    .execution_modes
812                                    .push((*execution_mode, ExecutionModeExtra::new([val])));
813                            } else {
814                                entry
815                                    .execution_modes
816                                    .push((*execution_mode, ExecutionModeExtra::new([])));
817                            }
818                        }
819                    }
820                } else if attr_name.name == sym.entry_point_name {
821                    match attr.value_str() {
822                        Some(sym) => {
823                            entry.name = Some(sym);
824                        }
825                        None => {
826                            return Err((
827                                attr_name.span,
828                                format!(
829                                    "#[spirv({name}(..))] unknown attribute argument {attr_name}"
830                                ),
831                            ));
832                        }
833                    }
834                } else {
835                    return Err((
836                        attr_name.span,
837                        format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
838                    ));
839                }
840            } else {
841                return Err((
842                    arg.span(),
843                    format!("#[spirv({name}(..))] attribute argument must be single identifier"),
844                ));
845            }
846        }
847    }
848    match entry.execution_model {
849        Fragment => {
850            let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
851            entry
852                .execution_modes
853                .push((origin_mode, ExecutionModeExtra::new([])));
854        }
855        GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
856            if let Some(local_size) = local_size {
857                entry
858                    .execution_modes
859                    .push((LocalSize, ExecutionModeExtra::new(local_size)));
860            } else {
861                return Err((
862                    arg.span(),
863                    String::from(
864                        "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
865                    ),
866                ));
867            }
868        }
869        //TODO: Cover more defaults
870        _ => {}
871    }
872    Ok(entry)
873}