Skip to main content

rustc_codegen_spirv/
attr.rs

1//! `#[spirv(...)]` attribute support.
2//!
3//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.
4
5use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
9use rustc_hir as hir;
10use rustc_hir::def_id::LocalModDefId;
11use rustc_hir::intravisit::{self, Visitor};
12use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
13use rustc_middle::hir::nested_filter;
14use rustc_middle::query::Providers;
15use rustc_middle::ty::TyCtxt;
16use rustc_span::{Ident, Span, Symbol};
17use smallvec::SmallVec;
18use std::rc::Rc;
19
// FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
/// Up to three extra `u32` operands attached to an `ExecutionMode`, stored
/// inline in a fixed-size array together with the number of operands in use.
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    args: [u32; 3],
    len: u8,
}

impl ExecutionModeExtra {
    /// Maximum number of operands an `ExecutionModeExtra` can hold.
    const CAPACITY: usize = 3;

    /// Copies `args` into a new `ExecutionModeExtra`.
    ///
    /// # Panics
    ///
    /// Panics if `args` holds more than three operands.
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let operands = args.as_ref();
        assert!(
            operands.len() <= Self::CAPACITY,
            "`ExecutionModeExtra` holds at most {} operands, got {}",
            Self::CAPACITY,
            operands.len(),
        );
        let mut buf = [0; Self::CAPACITY];
        buf[..operands.len()].copy_from_slice(operands);
        Self {
            args: buf,
            // Cannot truncate: the length is bounded by `CAPACITY` (checked above).
            len: operands.len() as u8,
        }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were actually supplied to `new`.
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
42
43#[derive(Clone, Debug)]
44pub struct Entry {
45    pub execution_model: ExecutionModel,
46    pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
47    pub name: Option<Symbol>,
48}
49
50impl From<ExecutionModel> for Entry {
51    fn from(execution_model: ExecutionModel) -> Self {
52        Self {
53            execution_model,
54            execution_modes: Vec::new(),
55            name: None,
56        }
57    }
58}
59
/// `struct` types that are used to represent special SPIR-V types.
///
/// A `struct` is tagged with one of these through
/// `SpirvAttribute::IntrinsicType`.
#[derive(Clone, Debug)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
    /// Also produced by the `#[rust_gpu::vector::v1]` attribute.
    Vector,
}
73
/// Parsed form of `#[spirv(spec_constant(id = ..., default = ...))]`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SpecConstant {
    /// The required `id = ...` value.
    pub id: u32,
    /// The optional `default = ...` value.
    pub default: Option<u32>,
    /// Left as `None` by attribute parsing and filled in at a later stage.
    pub array_count: Option<u32>,
}
80
81// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
82// `tests/ui/spirv-attr` should be updated (and new ones added if necessary).
83#[derive(Debug, Clone)]
84pub enum SpirvAttribute {
85    // `struct` attributes:
86    IntrinsicType(IntrinsicType),
87    Block,
88
89    // `fn` attributes:
90    Entry(Entry),
91
92    // (entry) `fn` parameter attributes:
93    StorageClass(StorageClass),
94    Builtin(BuiltIn),
95    DescriptorSet(u32),
96    Binding(u32),
97    Location(u32),
98    Flat,
99    PerPrimitiveExt,
100    Invariant,
101    InputAttachmentIndex(u32),
102    SpecConstant(SpecConstant),
103
104    // `fn`/closure attributes:
105    BufferLoadIntrinsic,
106    BufferStoreIntrinsic,
107}
108
109// HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the
110// field name instead of `node` (which feels inadequate in this context).
111#[derive(Copy, Clone)]
112pub struct Spanned<T> {
113    pub value: T,
114    pub span: Span,
115}
116
117/// Condensed version of a `SpirvAttribute` list, but only keeping one value per
118/// variant of `SpirvAttribute`, and treating multiple such attributes an error.
119// FIXME(eddyb) should this and `fn try_insert_attr` below be generated by a macro?
120#[derive(Default)]
121pub struct AggregatedSpirvAttributes {
122    // `struct` attributes:
123    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
124    pub block: Option<Spanned<()>>,
125
126    // `fn` attributes:
127    pub entry: Option<Spanned<Entry>>,
128
129    // (entry) `fn` parameter attributes:
130    pub storage_class: Option<Spanned<StorageClass>>,
131    pub builtin: Option<Spanned<BuiltIn>>,
132    pub descriptor_set: Option<Spanned<u32>>,
133    pub binding: Option<Spanned<u32>>,
134    pub location: Option<Spanned<u32>>,
135    pub flat: Option<Spanned<()>>,
136    pub invariant: Option<Spanned<()>>,
137    pub per_primitive_ext: Option<Spanned<()>>,
138    pub input_attachment_index: Option<Spanned<u32>>,
139    pub spec_constant: Option<Spanned<SpecConstant>>,
140
141    // `fn`/closure attributes:
142    pub buffer_load_intrinsic: Option<Spanned<()>>,
143    pub buffer_store_intrinsic: Option<Spanned<()>>,
144}
145
146struct MultipleAttrs {
147    prev_span: Span,
148    category: &'static str,
149}
150
151impl AggregatedSpirvAttributes {
152    /// Compute `AggregatedSpirvAttributes` for use during codegen.
153    ///
154    /// Any errors for malformed/duplicate attributes will have been reported
155    /// prior to codegen, by the `attr` check pass.
156    pub fn parse<'tcx>(
157        cx: &CodegenCx<'tcx>,
158        attrs: impl IntoIterator<Item = &'tcx Attribute>,
159    ) -> Self {
160        let mut aggregated_attrs = Self::default();
161
162        // NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
163        // to see an attribute error, it will cause an ICE instead.
164        for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
165            let (span, parsed_attr) = match parse_attr_result {
166                Ok(span_and_parsed_attr) => span_and_parsed_attr,
167                Err((span, msg)) => {
168                    cx.tcx.dcx().span_delayed_bug(span, msg);
169                    continue;
170                }
171            };
172            match aggregated_attrs.try_insert_attr(parsed_attr, span) {
173                Ok(()) => {}
174                Err(MultipleAttrs {
175                    prev_span: _,
176                    category,
177                }) => {
178                    cx.tcx
179                        .dcx()
180                        .span_delayed_bug(span, format!("multiple {category} attributes"));
181                }
182            }
183        }
184
185        aggregated_attrs
186    }
187
188    fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
189        fn try_insert<T>(
190            slot: &mut Option<Spanned<T>>,
191            value: T,
192            span: Span,
193            category: &'static str,
194        ) -> Result<(), MultipleAttrs> {
195            if let Some(prev) = slot {
196                Err(MultipleAttrs {
197                    prev_span: prev.span,
198                    category,
199                })
200            } else {
201                *slot = Some(Spanned { value, span });
202                Ok(())
203            }
204        }
205
206        use SpirvAttribute::*;
207        match attr {
208            IntrinsicType(value) => {
209                try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
210            }
211            Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
212            Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
213            StorageClass(value) => {
214                try_insert(&mut self.storage_class, value, span, "storage class")
215            }
216            Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
217            DescriptorSet(value) => try_insert(
218                &mut self.descriptor_set,
219                value,
220                span,
221                "#[spirv(descriptor_set)]",
222            ),
223            Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
224            Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
225            Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
226            Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
227            PerPrimitiveExt => try_insert(
228                &mut self.per_primitive_ext,
229                (),
230                span,
231                "#[spirv(per_primitive_ext)]",
232            ),
233            InputAttachmentIndex(value) => try_insert(
234                &mut self.input_attachment_index,
235                value,
236                span,
237                "#[spirv(attachment_index)]",
238            ),
239            SpecConstant(value) => try_insert(
240                &mut self.spec_constant,
241                value,
242                span,
243                "#[spirv(spec_constant)]",
244            ),
245            BufferLoadIntrinsic => try_insert(
246                &mut self.buffer_load_intrinsic,
247                (),
248                span,
249                "#[spirv(buffer_load_intrinsic)]",
250            ),
251            BufferStoreIntrinsic => try_insert(
252                &mut self.buffer_store_intrinsic,
253                (),
254                span,
255                "#[spirv(buffer_store_intrinsic)]",
256            ),
257        }
258    }
259}
260
261// FIXME(eddyb) make this reusable from somewhere in `rustc`.
262fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
263    match impl_item.kind {
264        hir::ImplItemKind::Const(..) => Target::AssocConst,
265        hir::ImplItemKind::Fn(..) => {
266            let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
267            let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
268            let containing_impl_is_for_trait = match &containing_item.kind {
269                hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
270                _ => unreachable!("parent of an ImplItem must be an Impl"),
271            };
272            if containing_impl_is_for_trait {
273                Target::Method(MethodKind::Trait { body: true })
274            } else {
275                Target::Method(MethodKind::Inherent)
276            }
277        }
278        hir::ImplItemKind::Type(..) => Target::AssocTy,
279    }
280}
281
282struct CheckSpirvAttrVisitor<'tcx> {
283    tcx: TyCtxt<'tcx>,
284    sym: Rc<Symbols>,
285}
286
287impl CheckSpirvAttrVisitor<'_> {
288    fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
289        let mut aggregated_attrs = AggregatedSpirvAttributes::default();
290
291        let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);
292
293        let attrs = self.tcx.hir_attrs(hir_id);
294        for parse_attr_result in parse_attrs(attrs) {
295            let (span, parsed_attr) = match parse_attr_result {
296                Ok(span_and_parsed_attr) => span_and_parsed_attr,
297                Err((span, msg)) => {
298                    self.tcx.dcx().span_err(span, msg);
299                    continue;
300                }
301            };
302
303            /// Error newtype marker used below for readability.
304            struct Expected<T>(T);
305
306            let valid_target = match parsed_attr {
307                SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
308                    Target::Struct => {
309                        // FIXME(eddyb) further check type attribute validity,
310                        // e.g. layout, generics, other attributes, etc.
311                        Ok(())
312                    }
313
314                    _ => Err(Expected("struct")),
315                },
316
317                SpirvAttribute::Entry(_) => match target {
318                    Target::Fn
319                    | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
320                        // FIXME(eddyb) further check entry-point attribute validity,
321                        // e.g. signature, shouldn't have `#[inline]` or generics, etc.
322                        Ok(())
323                    }
324
325                    _ => Err(Expected("function")),
326                },
327
328                SpirvAttribute::StorageClass(_)
329                | SpirvAttribute::Builtin(_)
330                | SpirvAttribute::DescriptorSet(_)
331                | SpirvAttribute::Binding(_)
332                | SpirvAttribute::Location(_)
333                | SpirvAttribute::Flat
334                | SpirvAttribute::Invariant
335                | SpirvAttribute::PerPrimitiveExt
336                | SpirvAttribute::InputAttachmentIndex(_)
337                | SpirvAttribute::SpecConstant(_) => match target {
338                    Target::Param => {
339                        let parent_hir_id = self.tcx.parent_hir_id(hir_id);
340                        let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
341                            .filter_map(|r| r.ok())
342                            .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
343                        if !parent_is_entry_point {
344                            self.tcx.dcx().span_err(
345                                span,
346                                "attribute is only valid on a parameter of an entry-point function",
347                            );
348                        } else {
349                            // FIXME(eddyb) should we just remove all 5 of these storage class
350                            // attributes, instead of disallowing them here?
351                            if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
352                                let valid = match storage_class {
353                                    StorageClass::Input | StorageClass::Output => {
354                                        Err("is the default and should not be explicitly specified")
355                                    }
356
357                                    StorageClass::Private
358                                    | StorageClass::Function
359                                    | StorageClass::Generic => {
360                                        Err("can not be used as part of an entry's interface")
361                                    }
362
363                                    _ => Ok(()),
364                                };
365
366                                if let Err(msg) = valid {
367                                    self.tcx.dcx().span_err(
368                                        span,
369                                        format!("`{storage_class:?}` storage class {msg}"),
370                                    );
371                                }
372                            }
373                        }
374                        Ok(())
375                    }
376
377                    _ => Err(Expected("function parameter")),
378                },
379                SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
380                    match target {
381                        Target::Fn => Ok(()),
382                        _ => Err(Expected("function")),
383                    }
384                }
385            };
386            match valid_target {
387                Err(Expected(expected_target)) => {
388                    self.tcx.dcx().span_err(
389                        span,
390                        format!(
391                            "attribute is only valid on a {expected_target}, not on a {target}"
392                        ),
393                    );
394                }
395                Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
396                    Ok(()) => {}
397                    Err(MultipleAttrs {
398                        prev_span,
399                        category,
400                    }) => {
401                        self.tcx
402                            .dcx()
403                            .struct_span_err(
404                                span,
405                                format!("only one {category} attribute is allowed on a {target}"),
406                            )
407                            .with_span_note(prev_span, format!("previous {category} attribute"))
408                            .emit();
409                    }
410                },
411            }
412        }
413
414        // At this point we have all of the attributes (valid for this target),
415        // so we can perform further checks, emit warnings, etc.
416
417        if let Some(block_attr) = aggregated_attrs.block {
418            self.tcx.dcx().span_warn(
419                block_attr.span,
420                "#[spirv(block)] is no longer needed and should be removed",
421            );
422        }
423    }
424}
425
426// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
427impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
428    type NestedFilter = nested_filter::OnlyBodies;
429
430    fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
431        self.tcx
432    }
433
434    fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
435        let target = Target::from_item(item);
436        self.check_spirv_attributes(item.hir_id(), target);
437        intravisit::walk_item(self, item);
438    }
439
440    fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
441        let target = Target::from_generic_param(generic_param);
442        self.check_spirv_attributes(generic_param.hir_id, target);
443        intravisit::walk_generic_param(self, generic_param);
444    }
445
446    fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
447        let target = Target::from_trait_item(trait_item);
448        self.check_spirv_attributes(trait_item.hir_id(), target);
449        intravisit::walk_trait_item(self, trait_item);
450    }
451
452    fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
453        self.check_spirv_attributes(field.hir_id, Target::Field);
454        intravisit::walk_field_def(self, field);
455    }
456
457    fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
458        self.check_spirv_attributes(arm.hir_id, Target::Arm);
459        intravisit::walk_arm(self, arm);
460    }
461
462    fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
463        let target = Target::from_foreign_item(f_item);
464        self.check_spirv_attributes(f_item.hir_id(), target);
465        intravisit::walk_foreign_item(self, f_item);
466    }
467
468    fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
469        let target = target_from_impl_item(self.tcx, impl_item);
470        self.check_spirv_attributes(impl_item.hir_id(), target);
471        intravisit::walk_impl_item(self, impl_item);
472    }
473
474    fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
475        // When checking statements ignore expressions, they will be checked later.
476        if let hir::StmtKind::Let(l) = stmt.kind {
477            self.check_spirv_attributes(l.hir_id, Target::Statement);
478        }
479        intravisit::walk_stmt(self, stmt);
480    }
481
482    fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
483        let target = match expr.kind {
484            hir::ExprKind::Closure { .. } => Target::Closure,
485            _ => Target::Expression,
486        };
487
488        self.check_spirv_attributes(expr.hir_id, target);
489        intravisit::walk_expr(self, expr);
490    }
491
492    fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
493        self.check_spirv_attributes(variant.hir_id, Target::Variant);
494        intravisit::walk_variant(self, variant);
495    }
496
497    fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
498        self.check_spirv_attributes(param.hir_id, Target::Param);
499
500        intravisit::walk_param(self, param);
501    }
502}
503
504// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
505fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
506    let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
507        tcx,
508        sym: Symbols::get(),
509    };
510    tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
511    if module_def_id.is_top_level_module() {
512        check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
513    }
514}
515
516pub(crate) fn provide(providers: &mut Providers) {
517    *providers = Providers {
518        check_mod_attrs: |tcx, module_def_id| {
519            // Run both the default checks, and our `#[spirv(...)]` ones.
520            (rustc_interface::DEFAULT_QUERY_PROVIDERS
521                .queries
522                .check_mod_attrs)(tcx, module_def_id);
523            check_mod_attrs(tcx, module_def_id);
524        },
525        // HACK(LegNeato) keep `#[spirv(...)]` entry-points from getting
526        // internalized and DCE'd before codegen. See issue #590.
527        cross_crate_inlinable: |tcx, def_id| {
528            let sym = Symbols::get();
529            let path = [sym.rust_gpu, sym.spirv_attr_with_version];
530            let attrs: Vec<_> = tcx.get_attrs_by_path(def_id.to_def_id(), &path).collect();
531            let is_entry_point = parse_attrs_for_checking(&sym, attrs)
532                .any(|result| matches!(result, Ok((_, SpirvAttribute::Entry(_)))));
533            if is_entry_point {
534                false
535            } else {
536                (rustc_interface::DEFAULT_QUERY_PROVIDERS
537                    .queries
538                    .cross_crate_inlinable)(tcx, def_id)
539            }
540        },
541        ..*providers
542    };
543}
544
545// FIXME(eddyb) find something nicer for the error type.
546type ParseAttrError = (Span, String);
547
548#[allow(clippy::get_first)]
549fn parse_attrs_for_checking<'sym, 'attr, I>(
550    sym: &'sym Symbols,
551    attrs: I,
552) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'sym
553where
554    I: IntoIterator<Item = &'attr Attribute> + 'sym,
555    I::IntoIter: 'sym,
556    'attr: 'sym,
557{
558    attrs
559        .into_iter()
560        .map(move |attr| {
561            // parse the #[rust_gpu::spirv(...)] attr and return the inner list
562            match attr {
563                Attribute::Unparsed(item) => {
564                    // #[...]
565                    let s = &item.path.segments;
566                    if let Some(rust_gpu) = s.get(0) && *rust_gpu == sym.rust_gpu {
567                        // #[rust_gpu ...]
568                        match s.get(1) {
569                            Some(command) if *command == sym.spirv_attr_with_version => {
570                                // #[rust_gpu::spirv ...]
571                                if let Some(args) = attr.meta_item_list() {
572                                    // #[rust_gpu::spirv(...)]
573                                    Ok(parse_spirv_attr(sym, args.iter()))
574                                } else {
575                                    // #[rust_gpu::spirv]
576                                    Err((
577                                        attr.span(),
578                                        "#[spirv(..)] attribute must have at least one argument"
579                                            .to_string(),
580                                    ))
581                                }
582                            }
583                            Some(command) if *command == sym.vector => {
584                                // #[rust_gpu::vector ...]
585                                match s.get(2) {
586                                    // #[rust_gpu::vector::v1]
587                                    Some(version) if *version == sym.v1 => {
588                                        Ok(SmallVec::from_iter([
589                                            Ok((attr.span(), SpirvAttribute::IntrinsicType(IntrinsicType::Vector)))
590                                        ]))
591                                    },
592                                    _ => Err((
593                                        attr.span(),
594                                        "unknown `rust_gpu::vector` version, expected `rust_gpu::vector::v1`"
595                                            .to_string(),
596                                    )),
597                                }
598                            }
599                            _ => {
600                                // #[rust_gpu::...] but not a know version
601                                let spirv = sym.spirv_attr_with_version.as_str();
602                                Err((
603                                    attr.span(),
604                                    format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
605                                Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
606                                ))
607                            }
608                        }
609                    } else {
610                        // #[...] but not #[rust_gpu ...]
611                        Ok(Default::default())
612                    }
613                }
614                Attribute::Parsed(_) => Ok(Default::default()),
615            }
616        })
617        .flat_map(|result| {
618            result
619                .unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
620                .into_iter()
621        })
622}
623
624fn parse_spirv_attr<'a>(
625    sym: &Symbols,
626    iter: impl Iterator<Item = &'a MetaItemInner>,
627) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
628    iter.map(|arg| {
629        let span = arg.span();
630        let parsed_attr =
631            if arg.has_name(sym.descriptor_set) {
632                SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
633            } else if arg.has_name(sym.binding) {
634                SpirvAttribute::Binding(parse_attr_int_value(arg)?)
635            } else if arg.has_name(sym.location) {
636                SpirvAttribute::Location(parse_attr_int_value(arg)?)
637            } else if arg.has_name(sym.input_attachment_index) {
638                SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
639            } else if arg.has_name(sym.spec_constant) {
640                SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
641            } else {
642                let name = match arg.ident() {
643                    Some(i) => i,
644                    None => {
645                        return Err((
646                            span,
647                            "#[spirv(..)] attribute argument must be single identifier".to_string(),
648                        ));
649                    }
650                };
651                sym.attributes.get(&name.name).map_or_else(
652                    || Err((name.span, "unknown argument to spirv attribute".to_string())),
653                    |a| {
654                        Ok(match a {
655                            SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
656                                parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
657                            ),
658                            _ => a.clone(),
659                        })
660                    },
661                )?
662            };
663        Ok((span, parsed_attr))
664    })
665    .collect()
666}
667
668fn parse_spec_constant_attr(
669    sym: &Symbols,
670    arg: &MetaItemInner,
671) -> Result<SpecConstant, ParseAttrError> {
672    let mut id = None;
673    let mut default = None;
674
675    if let Some(attrs) = arg.meta_item_list() {
676        for attr in attrs {
677            if attr.has_name(sym.id) {
678                if id.is_none() {
679                    id = Some(parse_attr_int_value(attr)?);
680                } else {
681                    return Err((attr.span(), "`id` may only be specified once".into()));
682                }
683            } else if attr.has_name(sym.default) {
684                if default.is_none() {
685                    default = Some(parse_attr_int_value(attr)?);
686                } else {
687                    return Err((attr.span(), "`default` may only be specified once".into()));
688                }
689            } else {
690                return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
691            }
692        }
693    }
694    Ok(SpecConstant {
695        id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
696        default,
697        // to be set later
698        array_count: None,
699    })
700}
701
702fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
703    let arg = match arg.meta_item() {
704        Some(arg) => arg,
705        None => return Err((arg.span(), "attribute must have value".to_string())),
706    };
707    match arg.name_value_literal() {
708        Some(&MetaItemLit {
709            kind: LitKind::Int(x, ..),
710            ..
711        }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
712        _ => Err((arg.span, "attribute value must be integer".to_string())),
713    }
714}
715
716fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
717    let arg = match arg.meta_item() {
718        Some(arg) => arg,
719        None => return Err((arg.span(), "attribute must have value".to_string())),
720    };
721    match arg.meta_item_list() {
722        Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
723            let mut local_size = [1; 3];
724            for (idx, lit) in tuple.iter().enumerate() {
725                match lit {
726                    MetaItemInner::Lit(MetaItemLit {
727                                           kind: LitKind::Int(x, ..),
728                                           ..
729                                       }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
730                    _ => return Err((lit.span(), "must be a u32 literal".to_string())),
731                }
732            }
733            Ok(local_size)
734        }
735        Some([]) => Err((
736            arg.span,
737            "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
738        )),
739        Some(tuple) if tuple.len() > 3 => Err((
740            arg.span,
741            "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
742        )),
743        _ => Err((
744            arg.span,
745            "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
746        )),
747    }
748}
749
750// for a given entry, gather up the additional attributes
751// in this case ExecutionMode's, some have extra arguments
752// others are specified with x, y, or z components
753// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))]
754fn parse_entry_attrs(
755    sym: &Symbols,
756    arg: &MetaItemInner,
757    name: &Ident,
758    execution_model: ExecutionModel,
759) -> Result<Entry, ParseAttrError> {
760    use ExecutionMode::*;
761    use ExecutionModel::*;
762    let mut entry = Entry::from(execution_model);
763    let mut origin_mode: Option<ExecutionMode> = None;
764    let mut local_size: Option<[u32; 3]> = None;
765    let mut local_size_hint: Option<[u32; 3]> = None;
766    // Reserved
767    //let mut max_workgroup_size_intel: Option<[u32; 3]> = None;
768    if let Some(attrs) = arg.meta_item_list() {
769        for attr in attrs {
770            if let Some(attr_name) = attr.ident() {
771                if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
772                {
773                    use crate::symbols::ExecutionModeExtraDim::*;
774                    let val = match extra_dim {
775                        None | Tuple => Option::None,
776                        _ => Some(parse_attr_int_value(attr)?),
777                    };
778                    match execution_mode {
779                        OriginUpperLeft | OriginLowerLeft => {
780                            origin_mode.replace(*execution_mode);
781                        }
782                        LocalSize => {
783                            if local_size.is_none() {
784                                local_size.replace(parse_local_size_attr(attr)?);
785                            } else {
786                                return Err((
787                                    attr_name.span,
788                                    String::from(
789                                        "`#[spirv(compute(threads))]` may only be specified once",
790                                    ),
791                                ));
792                            }
793                        }
794                        LocalSizeHint => {
795                            let val = val.unwrap();
796                            if local_size_hint.is_none() {
797                                local_size_hint.replace([1, 1, 1]);
798                            }
799                            let local_size_hint = local_size_hint.as_mut().unwrap();
800                            match extra_dim {
801                                X => {
802                                    local_size_hint[0] = val;
803                                }
804                                Y => {
805                                    local_size_hint[1] = val;
806                                }
807                                Z => {
808                                    local_size_hint[2] = val;
809                                }
810                                _ => unreachable!(),
811                            }
812                        }
813                        // Reserved
814                        /*MaxWorkgroupSizeINTEL => {
815                            let val = val.unwrap();
816                            if max_workgroup_size_intel.is_none() {
817                                max_workgroup_size_intel.replace([1, 1, 1]);
818                            }
819                            let max_workgroup_size_intel = max_workgroup_size_intel.as_mut()
820                                .unwrap();
821                            match extra_dim {
822                                X => {
823                                    max_workgroup_size_intel[0] = val;
824                                },
825                                Y => {
826                                    max_workgroup_size_intel[1] = val;
827                                },
828                                Z => {
829                                    max_workgroup_size_intel[2] = val;
830                                },
831                                _ => unreachable!(),
832                            }
833                        },*/
834                        _ => {
835                            if let Some(val) = val {
836                                entry
837                                    .execution_modes
838                                    .push((*execution_mode, ExecutionModeExtra::new([val])));
839                            } else {
840                                entry
841                                    .execution_modes
842                                    .push((*execution_mode, ExecutionModeExtra::new([])));
843                            }
844                        }
845                    }
846                } else if attr_name.name == sym.entry_point_name {
847                    match attr.value_str() {
848                        Some(sym) => {
849                            entry.name = Some(sym);
850                        }
851                        None => {
852                            return Err((
853                                attr_name.span,
854                                format!(
855                                    "#[spirv({name}(..))] unknown attribute argument {attr_name}"
856                                ),
857                            ));
858                        }
859                    }
860                } else {
861                    return Err((
862                        attr_name.span,
863                        format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
864                    ));
865                }
866            } else {
867                return Err((
868                    arg.span(),
869                    format!("#[spirv({name}(..))] attribute argument must be single identifier"),
870                ));
871            }
872        }
873    }
874    match entry.execution_model {
875        Fragment => {
876            let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
877            entry
878                .execution_modes
879                .push((origin_mode, ExecutionModeExtra::new([])));
880        }
881        GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
882            if let Some(local_size) = local_size {
883                entry
884                    .execution_modes
885                    .push((LocalSize, ExecutionModeExtra::new(local_size)));
886            } else {
887                return Err((
888                    arg.span(),
889                    String::from(
890                        "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
891                    ),
892                ));
893            }
894        }
895        //TODO: Cover more defaults
896        _ => {}
897    }
898    Ok(entry)
899}