Skip to main content

rustc_codegen_spirv/
attr.rs

1//! `#[spirv(...)]` attribute support.
2//!
3//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.
4
5use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
9use rustc_hir as hir;
10use rustc_hir::def_id::LocalModDefId;
11use rustc_hir::intravisit::{self, Visitor};
12use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
13use rustc_middle::hir::nested_filter;
14use rustc_middle::query::Providers;
15use rustc_middle::ty::TyCtxt;
16use rustc_span::{Ident, Span, Symbol};
17use smallvec::SmallVec;
18use std::rc::Rc;
19
// FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
/// Up to three extra `u32` operands attached to an execution mode.
///
/// Stored inline (a fixed-size array plus a length) so the type stays `Copy`;
/// use `as_ref()` to view only the operands that were actually provided.
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    /// Operand storage; entries at index `len..` are zero padding.
    args: [u32; 3],
    /// Number of meaningful entries in `args` (never more than 3).
    len: u8,
}

impl ExecutionModeExtra {
    /// Builds an `ExecutionModeExtra` from a slice-like list of operands.
    ///
    /// # Panics
    ///
    /// Panics if more than three operands are given.
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let operands = args.as_ref();
        let mut args = [0u32; 3];
        // Check the capacity up front so misuse fails with a descriptive
        // message instead of an out-of-range slice index panic.
        assert!(
            operands.len() <= args.len(),
            "ExecutionModeExtra holds at most {} operands, got {}",
            args.len(),
            operands.len(),
        );
        args[..operands.len()].copy_from_slice(operands);
        // Cannot truncate: the length was checked to be at most 3 above.
        let len = operands.len() as u8;
        Self { args, len }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// The operands passed to `new`, without the zero padding.
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
42
43#[derive(Clone, Debug)]
44pub struct Entry {
45    pub execution_model: ExecutionModel,
46    pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
47    pub name: Option<Symbol>,
48}
49
50impl From<ExecutionModel> for Entry {
51    fn from(execution_model: ExecutionModel) -> Self {
52        Self {
53            execution_model,
54            execution_modes: Vec::new(),
55            name: None,
56        }
57    }
58}
59
/// `struct` types that are used to represent special SPIR-V types.
///
/// Produced by `#[spirv(...)]` attribute parsing (and, for `Vector`, by
/// `#[rust_gpu::vector::v1]`); the attribute check pass only accepts these
/// on `struct` definitions.
#[derive(Debug, Clone)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
    Vector,
}
73
/// Arguments of a `spec_constant(id = ..., default = ...)` attribute.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// Value of the required `id = ...` argument.
    pub id: u32,
    /// Value of the optional `default = ...` argument, if given.
    pub default: Option<u32>,
    /// Always `None` after attribute parsing; filled in later by code outside
    /// the parser. NOTE(review): presumably set when the parameter is an array — confirm.
    pub array_count: Option<u32>,
}
80
// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
// `tests/ui/spirv-attr` should be updated (and new ones added if necessary).
/// A single parsed `#[spirv(...)]` argument (or `#[rust_gpu::vector::v1]`).
///
/// Which kinds of items each variant may be placed on is enforced by the
/// attribute check pass (`check_spirv_attributes`).
#[derive(Debug, Clone)]
pub enum SpirvAttribute {
    // `struct` attributes:
    IntrinsicType(IntrinsicType),
    Block,

    // `fn` attributes:
    Entry(Entry),

    // (entry) `fn` parameter attributes:
    StorageClass(StorageClass),
    Builtin(BuiltIn),
    /// Integer value of the argument (parsed with `parse_attr_int_value`).
    DescriptorSet(u32),
    /// Integer value of the argument (parsed with `parse_attr_int_value`).
    Binding(u32),
    /// Integer value of the argument (parsed with `parse_attr_int_value`).
    Location(u32),
    Flat,
    PerPrimitiveExt,
    Invariant,
    /// Integer value of the argument (parsed with `parse_attr_int_value`).
    InputAttachmentIndex(u32),
    /// Parsed by `parse_spec_constant_attr`.
    SpecConstant(SpecConstant),

    // `fn`/closure attributes:
    // NOTE(review): the check pass currently accepts these only on `Target::Fn`,
    // not on closures.
    BufferLoadIntrinsic,
    BufferStoreIntrinsic,
}
108
// HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the
// field name instead of `node` (which feels inadequate in this context).
/// A value together with the span of the attribute (argument) it was parsed from.
#[derive(Copy, Clone)]
pub struct Spanned<T> {
    pub value: T,
    pub span: Span,
}
116
/// Condensed version of a `SpirvAttribute` list, but only keeping one value per
/// variant of `SpirvAttribute`, and treating multiple such attributes as an error.
///
/// Each field holds the value (and span) of the corresponding `SpirvAttribute`
/// variant; unit variants are stored as `Spanned<()>`.
// FIXME(eddyb) should this and `fn try_insert_attr` below be generated by a macro?
#[derive(Default)]
pub struct AggregatedSpirvAttributes {
    // `struct` attributes:
    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
    pub block: Option<Spanned<()>>,

    // `fn` attributes:
    pub entry: Option<Spanned<Entry>>,

    // (entry) `fn` parameter attributes:
    pub storage_class: Option<Spanned<StorageClass>>,
    pub builtin: Option<Spanned<BuiltIn>>,
    pub descriptor_set: Option<Spanned<u32>>,
    pub binding: Option<Spanned<u32>>,
    pub location: Option<Spanned<u32>>,
    pub flat: Option<Spanned<()>>,
    pub invariant: Option<Spanned<()>>,
    pub per_primitive_ext: Option<Spanned<()>>,
    pub input_attachment_index: Option<Spanned<u32>>,
    pub spec_constant: Option<Spanned<SpecConstant>>,

    // `fn`/closure attributes:
    pub buffer_load_intrinsic: Option<Spanned<()>>,
    pub buffer_store_intrinsic: Option<Spanned<()>>,
}
145
/// Error from `AggregatedSpirvAttributes::try_insert_attr`: an attribute of the
/// same kind was already recorded.
struct MultipleAttrs {
    /// Span of the previously recorded attribute.
    prev_span: Span,
    /// Human-readable attribute kind, interpolated into diagnostics.
    category: &'static str,
}
150
151impl AggregatedSpirvAttributes {
152    /// Compute `AggregatedSpirvAttributes` for use during codegen.
153    ///
154    /// Any errors for malformed/duplicate attributes will have been reported
155    /// prior to codegen, by the `attr` check pass.
156    pub fn parse<'tcx>(
157        cx: &CodegenCx<'tcx>,
158        attrs: impl IntoIterator<Item = &'tcx Attribute>,
159    ) -> Self {
160        let mut aggregated_attrs = Self::default();
161
162        // NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
163        // to see an attribute error, it will cause an ICE instead.
164        for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
165            let (span, parsed_attr) = match parse_attr_result {
166                Ok(span_and_parsed_attr) => span_and_parsed_attr,
167                Err((span, msg)) => {
168                    cx.tcx.dcx().span_delayed_bug(span, msg);
169                    continue;
170                }
171            };
172            match aggregated_attrs.try_insert_attr(parsed_attr, span) {
173                Ok(()) => {}
174                Err(MultipleAttrs {
175                    prev_span: _,
176                    category,
177                }) => {
178                    cx.tcx
179                        .dcx()
180                        .span_delayed_bug(span, format!("multiple {category} attributes"));
181                }
182            }
183        }
184
185        aggregated_attrs
186    }
187
188    fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
189        fn try_insert<T>(
190            slot: &mut Option<Spanned<T>>,
191            value: T,
192            span: Span,
193            category: &'static str,
194        ) -> Result<(), MultipleAttrs> {
195            if let Some(prev) = slot {
196                Err(MultipleAttrs {
197                    prev_span: prev.span,
198                    category,
199                })
200            } else {
201                *slot = Some(Spanned { value, span });
202                Ok(())
203            }
204        }
205
206        use SpirvAttribute::*;
207        match attr {
208            IntrinsicType(value) => {
209                try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
210            }
211            Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
212            Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
213            StorageClass(value) => {
214                try_insert(&mut self.storage_class, value, span, "storage class")
215            }
216            Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
217            DescriptorSet(value) => try_insert(
218                &mut self.descriptor_set,
219                value,
220                span,
221                "#[spirv(descriptor_set)]",
222            ),
223            Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
224            Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
225            Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
226            Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
227            PerPrimitiveExt => try_insert(
228                &mut self.per_primitive_ext,
229                (),
230                span,
231                "#[spirv(per_primitive_ext)]",
232            ),
233            InputAttachmentIndex(value) => try_insert(
234                &mut self.input_attachment_index,
235                value,
236                span,
237                "#[spirv(attachment_index)]",
238            ),
239            SpecConstant(value) => try_insert(
240                &mut self.spec_constant,
241                value,
242                span,
243                "#[spirv(spec_constant)]",
244            ),
245            BufferLoadIntrinsic => try_insert(
246                &mut self.buffer_load_intrinsic,
247                (),
248                span,
249                "#[spirv(buffer_load_intrinsic)]",
250            ),
251            BufferStoreIntrinsic => try_insert(
252                &mut self.buffer_store_intrinsic,
253                (),
254                span,
255                "#[spirv(buffer_store_intrinsic)]",
256            ),
257        }
258    }
259}
260
261// FIXME(eddyb) make this reusable from somewhere in `rustc`.
262fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
263    match impl_item.kind {
264        hir::ImplItemKind::Const(..) => Target::AssocConst,
265        hir::ImplItemKind::Fn(..) => {
266            let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
267            let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
268            let containing_impl_is_for_trait = match &containing_item.kind {
269                hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
270                _ => unreachable!("parent of an ImplItem must be an Impl"),
271            };
272            if containing_impl_is_for_trait {
273                Target::Method(MethodKind::Trait { body: true })
274            } else {
275                Target::Method(MethodKind::Inherent)
276            }
277        }
278        hir::ImplItemKind::Type(..) => Target::AssocTy,
279    }
280}
281
/// HIR visitor that runs `check_spirv_attributes` on every attribute-bearing
/// node of a module (driven by `check_mod_attrs`).
struct CheckSpirvAttrVisitor<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Symbol table used to recognize `rust_gpu` / `#[spirv(...)]` attribute names.
    sym: Rc<Symbols>,
}
286
impl CheckSpirvAttrVisitor<'_> {
    /// Parse and validate all `rust_gpu` attributes on `hir_id`, which the
    /// visitor has classified as `target`.
    ///
    /// Emits errors for:
    /// - attributes that fail to parse,
    /// - attributes placed on the wrong kind of `target`,
    /// - entry-point parameter attributes on a parameter whose HIR parent is
    ///   not marked as an entry point,
    /// - explicit `Input`/`Output`/`Private`/`Function`/`Generic` storage
    ///   classes on entry-point parameters,
    /// - more than one attribute of the same category.
    ///
    /// Also warns that `#[spirv(block)]` is no longer needed.
    fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
        // Only used for duplicate detection and the post-loop checks below;
        // attributes rejected for their target are never inserted.
        let mut aggregated_attrs = AggregatedSpirvAttributes::default();

        // Shared by this node's attributes and (for parameters) the parent's.
        let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);

        let attrs = self.tcx.hir_attrs(hir_id);
        for parse_attr_result in parse_attrs(attrs) {
            let (span, parsed_attr) = match parse_attr_result {
                Ok(span_and_parsed_attr) => span_and_parsed_attr,
                Err((span, msg)) => {
                    self.tcx.dcx().span_err(span, msg);
                    continue;
                }
            };

            /// Error newtype marker used below for readability.
            struct Expected<T>(T);

            // `Ok(())` if `parsed_attr` may appear on `target`, otherwise the
            // kind of target it was expected on.
            let valid_target = match parsed_attr {
                SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
                    Target::Struct => {
                        // FIXME(eddyb) further check type attribute validity,
                        // e.g. layout, generics, other attributes, etc.
                        Ok(())
                    }

                    _ => Err(Expected("struct")),
                },

                SpirvAttribute::Entry(_) => match target {
                    Target::Fn
                    | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
                        // FIXME(eddyb) further check entry-point attribute validity,
                        // e.g. signature, shouldn't have `#[inline]` or generics, etc.
                        Ok(())
                    }

                    _ => Err(Expected("function")),
                },

                SpirvAttribute::StorageClass(_)
                | SpirvAttribute::Builtin(_)
                | SpirvAttribute::DescriptorSet(_)
                | SpirvAttribute::Binding(_)
                | SpirvAttribute::Location(_)
                | SpirvAttribute::Flat
                | SpirvAttribute::Invariant
                | SpirvAttribute::PerPrimitiveExt
                | SpirvAttribute::InputAttachmentIndex(_)
                | SpirvAttribute::SpecConstant(_) => match target {
                    Target::Param => {
                        // These only make sense when the parameter's HIR parent carries
                        // an entry-point attribute; parse errors in the parent's
                        // attributes are ignored here (`r.ok()`).
                        let parent_hir_id = self.tcx.parent_hir_id(hir_id);
                        let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
                            .filter_map(|r| r.ok())
                            .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
                        if !parent_is_entry_point {
                            self.tcx.dcx().span_err(
                                span,
                                "attribute is only valid on a parameter of an entry-point function",
                            );
                        } else {
                            // FIXME(eddyb) should we just remove all 5 of these storage class
                            // attributes, instead of disallowing them here?
                            if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
                                let valid = match storage_class {
                                    StorageClass::Input | StorageClass::Output => {
                                        Err("is the default and should not be explicitly specified")
                                    }

                                    StorageClass::Private
                                    | StorageClass::Function
                                    | StorageClass::Generic => {
                                        Err("can not be used as part of an entry's interface")
                                    }

                                    _ => Ok(()),
                                };

                                if let Err(msg) = valid {
                                    self.tcx.dcx().span_err(
                                        span,
                                        format!("`{storage_class:?}` storage class {msg}"),
                                    );
                                }
                            }
                        }
                        // Even if an error was emitted above, the target kind itself is
                        // accepted, so the attribute still takes part in duplicate checks.
                        Ok(())
                    }

                    _ => Err(Expected("function parameter")),
                },
                SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
                    match target {
                        Target::Fn => Ok(()),
                        _ => Err(Expected("function")),
                    }
                }
            };
            // Only attributes on a valid target are recorded (and checked for duplicates).
            match valid_target {
                Err(Expected(expected_target)) => {
                    self.tcx.dcx().span_err(
                        span,
                        format!(
                            "attribute is only valid on a {expected_target}, not on a {target}"
                        ),
                    );
                }
                Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
                    Ok(()) => {}
                    Err(MultipleAttrs {
                        prev_span,
                        category,
                    }) => {
                        self.tcx
                            .dcx()
                            .struct_span_err(
                                span,
                                format!("only one {category} attribute is allowed on a {target}"),
                            )
                            .with_span_note(prev_span, format!("previous {category} attribute"))
                            .emit();
                    }
                },
            }
        }

        // At this point we have all of the attributes (valid for this target),
        // so we can perform further checks, emit warnings, etc.

        if let Some(block_attr) = aggregated_attrs.block {
            self.tcx.dcx().span_warn(
                block_attr.span,
                "#[spirv(block)] is no longer needed and should be removed",
            );
        }
    }
}
425
426// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
427impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
428    type NestedFilter = nested_filter::OnlyBodies;
429
430    fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
431        self.tcx
432    }
433
434    fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
435        let target = Target::from_item(item);
436        self.check_spirv_attributes(item.hir_id(), target);
437        intravisit::walk_item(self, item);
438    }
439
440    fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
441        let target = Target::from_generic_param(generic_param);
442        self.check_spirv_attributes(generic_param.hir_id, target);
443        intravisit::walk_generic_param(self, generic_param);
444    }
445
446    fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
447        let target = Target::from_trait_item(trait_item);
448        self.check_spirv_attributes(trait_item.hir_id(), target);
449        intravisit::walk_trait_item(self, trait_item);
450    }
451
452    fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
453        self.check_spirv_attributes(field.hir_id, Target::Field);
454        intravisit::walk_field_def(self, field);
455    }
456
457    fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
458        self.check_spirv_attributes(arm.hir_id, Target::Arm);
459        intravisit::walk_arm(self, arm);
460    }
461
462    fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
463        let target = Target::from_foreign_item(f_item);
464        self.check_spirv_attributes(f_item.hir_id(), target);
465        intravisit::walk_foreign_item(self, f_item);
466    }
467
468    fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
469        let target = target_from_impl_item(self.tcx, impl_item);
470        self.check_spirv_attributes(impl_item.hir_id(), target);
471        intravisit::walk_impl_item(self, impl_item);
472    }
473
474    fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
475        // When checking statements ignore expressions, they will be checked later.
476        if let hir::StmtKind::Let(l) = stmt.kind {
477            self.check_spirv_attributes(l.hir_id, Target::Statement);
478        }
479        intravisit::walk_stmt(self, stmt);
480    }
481
482    fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
483        let target = match expr.kind {
484            hir::ExprKind::Closure { .. } => Target::Closure,
485            _ => Target::Expression,
486        };
487
488        self.check_spirv_attributes(expr.hir_id, target);
489        intravisit::walk_expr(self, expr);
490    }
491
492    fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
493        self.check_spirv_attributes(variant.hir_id, Target::Variant);
494        intravisit::walk_variant(self, variant);
495    }
496
497    fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
498        self.check_spirv_attributes(param.hir_id, Target::Param);
499
500        intravisit::walk_param(self, param);
501    }
502}
503
504// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
505fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
506    let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
507        tcx,
508        sym: Symbols::get(),
509    };
510    tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
511    if module_def_id.is_top_level_module() {
512        check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
513    }
514}
515
516pub(crate) fn provide(providers: &mut Providers) {
517    *providers = Providers {
518        check_mod_attrs: |tcx, module_def_id| {
519            // Run both the default checks, and our `#[spirv(...)]` ones.
520            (rustc_interface::DEFAULT_QUERY_PROVIDERS
521                .queries
522                .check_mod_attrs)(tcx, module_def_id);
523            check_mod_attrs(tcx, module_def_id);
524        },
525        ..*providers
526    };
527}
528
// FIXME(eddyb) find something nicer for the error type.
/// An attribute parsing error: the span to report it at, and the message.
type ParseAttrError = (Span, String);
531
532#[allow(clippy::get_first)]
533fn parse_attrs_for_checking<'sym, 'attr, I>(
534    sym: &'sym Symbols,
535    attrs: I,
536) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'sym
537where
538    I: IntoIterator<Item = &'attr Attribute> + 'sym,
539    I::IntoIter: 'sym,
540    'attr: 'sym,
541{
542    attrs
543        .into_iter()
544        .map(move |attr| {
545            // parse the #[rust_gpu::spirv(...)] attr and return the inner list
546            match attr {
547                Attribute::Unparsed(item) => {
548                    // #[...]
549                    let s = &item.path.segments;
550                    if let Some(rust_gpu) = s.get(0) && *rust_gpu == sym.rust_gpu {
551                        // #[rust_gpu ...]
552                        match s.get(1) {
553                            Some(command) if *command == sym.spirv_attr_with_version => {
554                                // #[rust_gpu::spirv ...]
555                                if let Some(args) = attr.meta_item_list() {
556                                    // #[rust_gpu::spirv(...)]
557                                    Ok(parse_spirv_attr(sym, args.iter()))
558                                } else {
559                                    // #[rust_gpu::spirv]
560                                    Err((
561                                        attr.span(),
562                                        "#[spirv(..)] attribute must have at least one argument"
563                                            .to_string(),
564                                    ))
565                                }
566                            }
567                            Some(command) if *command == sym.vector => {
568                                // #[rust_gpu::vector ...]
569                                match s.get(2) {
570                                    // #[rust_gpu::vector::v1]
571                                    Some(version) if *version == sym.v1 => {
572                                        Ok(SmallVec::from_iter([
573                                            Ok((attr.span(), SpirvAttribute::IntrinsicType(IntrinsicType::Vector)))
574                                        ]))
575                                    },
576                                    _ => Err((
577                                        attr.span(),
578                                        "unknown `rust_gpu::vector` version, expected `rust_gpu::vector::v1`"
579                                            .to_string(),
580                                    )),
581                                }
582                            }
583                            _ => {
584                                // #[rust_gpu::...] but not a know version
585                                let spirv = sym.spirv_attr_with_version.as_str();
586                                Err((
587                                    attr.span(),
588                                    format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
589                                Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
590                                ))
591                            }
592                        }
593                    } else {
594                        // #[...] but not #[rust_gpu ...]
595                        Ok(Default::default())
596                    }
597                }
598                Attribute::Parsed(_) => Ok(Default::default()),
599            }
600        })
601        .flat_map(|result| {
602            result
603                .unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
604                .into_iter()
605        })
606}
607
608fn parse_spirv_attr<'a>(
609    sym: &Symbols,
610    iter: impl Iterator<Item = &'a MetaItemInner>,
611) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
612    iter.map(|arg| {
613        let span = arg.span();
614        let parsed_attr =
615            if arg.has_name(sym.descriptor_set) {
616                SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
617            } else if arg.has_name(sym.binding) {
618                SpirvAttribute::Binding(parse_attr_int_value(arg)?)
619            } else if arg.has_name(sym.location) {
620                SpirvAttribute::Location(parse_attr_int_value(arg)?)
621            } else if arg.has_name(sym.input_attachment_index) {
622                SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
623            } else if arg.has_name(sym.spec_constant) {
624                SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
625            } else {
626                let name = match arg.ident() {
627                    Some(i) => i,
628                    None => {
629                        return Err((
630                            span,
631                            "#[spirv(..)] attribute argument must be single identifier".to_string(),
632                        ));
633                    }
634                };
635                sym.attributes.get(&name.name).map_or_else(
636                    || Err((name.span, "unknown argument to spirv attribute".to_string())),
637                    |a| {
638                        Ok(match a {
639                            SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
640                                parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
641                            ),
642                            _ => a.clone(),
643                        })
644                    },
645                )?
646            };
647        Ok((span, parsed_attr))
648    })
649    .collect()
650}
651
652fn parse_spec_constant_attr(
653    sym: &Symbols,
654    arg: &MetaItemInner,
655) -> Result<SpecConstant, ParseAttrError> {
656    let mut id = None;
657    let mut default = None;
658
659    if let Some(attrs) = arg.meta_item_list() {
660        for attr in attrs {
661            if attr.has_name(sym.id) {
662                if id.is_none() {
663                    id = Some(parse_attr_int_value(attr)?);
664                } else {
665                    return Err((attr.span(), "`id` may only be specified once".into()));
666                }
667            } else if attr.has_name(sym.default) {
668                if default.is_none() {
669                    default = Some(parse_attr_int_value(attr)?);
670                } else {
671                    return Err((attr.span(), "`default` may only be specified once".into()));
672                }
673            } else {
674                return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
675            }
676        }
677    }
678    Ok(SpecConstant {
679        id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
680        default,
681        // to be set later
682        array_count: None,
683    })
684}
685
686fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
687    let arg = match arg.meta_item() {
688        Some(arg) => arg,
689        None => return Err((arg.span(), "attribute must have value".to_string())),
690    };
691    match arg.name_value_literal() {
692        Some(&MetaItemLit {
693            kind: LitKind::Int(x, ..),
694            ..
695        }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
696        _ => Err((arg.span, "attribute value must be integer".to_string())),
697    }
698}
699
700fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
701    let arg = match arg.meta_item() {
702        Some(arg) => arg,
703        None => return Err((arg.span(), "attribute must have value".to_string())),
704    };
705    match arg.meta_item_list() {
706        Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
707            let mut local_size = [1; 3];
708            for (idx, lit) in tuple.iter().enumerate() {
709                match lit {
710                    MetaItemInner::Lit(MetaItemLit {
711                                           kind: LitKind::Int(x, ..),
712                                           ..
713                                       }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
714                    _ => return Err((lit.span(), "must be a u32 literal".to_string())),
715                }
716            }
717            Ok(local_size)
718        }
719        Some([]) => Err((
720            arg.span,
721            "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
722        )),
723        Some(tuple) if tuple.len() > 3 => Err((
724            arg.span,
725            "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
726        )),
727        _ => Err((
728            arg.span,
729            "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
730        )),
731    }
732}
733
734// for a given entry, gather up the additional attributes
735// in this case ExecutionMode's, some have extra arguments
736// others are specified with x, y, or z components
737// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))]
738fn parse_entry_attrs(
739    sym: &Symbols,
740    arg: &MetaItemInner,
741    name: &Ident,
742    execution_model: ExecutionModel,
743) -> Result<Entry, ParseAttrError> {
744    use ExecutionMode::*;
745    use ExecutionModel::*;
746    let mut entry = Entry::from(execution_model);
747    let mut origin_mode: Option<ExecutionMode> = None;
748    let mut local_size: Option<[u32; 3]> = None;
749    let mut local_size_hint: Option<[u32; 3]> = None;
750    // Reserved
751    //let mut max_workgroup_size_intel: Option<[u32; 3]> = None;
752    if let Some(attrs) = arg.meta_item_list() {
753        for attr in attrs {
754            if let Some(attr_name) = attr.ident() {
755                if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
756                {
757                    use crate::symbols::ExecutionModeExtraDim::*;
758                    let val = match extra_dim {
759                        None | Tuple => Option::None,
760                        _ => Some(parse_attr_int_value(attr)?),
761                    };
762                    match execution_mode {
763                        OriginUpperLeft | OriginLowerLeft => {
764                            origin_mode.replace(*execution_mode);
765                        }
766                        LocalSize => {
767                            if local_size.is_none() {
768                                local_size.replace(parse_local_size_attr(attr)?);
769                            } else {
770                                return Err((
771                                    attr_name.span,
772                                    String::from(
773                                        "`#[spirv(compute(threads))]` may only be specified once",
774                                    ),
775                                ));
776                            }
777                        }
778                        LocalSizeHint => {
779                            let val = val.unwrap();
780                            if local_size_hint.is_none() {
781                                local_size_hint.replace([1, 1, 1]);
782                            }
783                            let local_size_hint = local_size_hint.as_mut().unwrap();
784                            match extra_dim {
785                                X => {
786                                    local_size_hint[0] = val;
787                                }
788                                Y => {
789                                    local_size_hint[1] = val;
790                                }
791                                Z => {
792                                    local_size_hint[2] = val;
793                                }
794                                _ => unreachable!(),
795                            }
796                        }
797                        // Reserved
798                        /*MaxWorkgroupSizeINTEL => {
799                            let val = val.unwrap();
800                            if max_workgroup_size_intel.is_none() {
801                                max_workgroup_size_intel.replace([1, 1, 1]);
802                            }
803                            let max_workgroup_size_intel = max_workgroup_size_intel.as_mut()
804                                .unwrap();
805                            match extra_dim {
806                                X => {
807                                    max_workgroup_size_intel[0] = val;
808                                },
809                                Y => {
810                                    max_workgroup_size_intel[1] = val;
811                                },
812                                Z => {
813                                    max_workgroup_size_intel[2] = val;
814                                },
815                                _ => unreachable!(),
816                            }
817                        },*/
818                        _ => {
819                            if let Some(val) = val {
820                                entry
821                                    .execution_modes
822                                    .push((*execution_mode, ExecutionModeExtra::new([val])));
823                            } else {
824                                entry
825                                    .execution_modes
826                                    .push((*execution_mode, ExecutionModeExtra::new([])));
827                            }
828                        }
829                    }
830                } else if attr_name.name == sym.entry_point_name {
831                    match attr.value_str() {
832                        Some(sym) => {
833                            entry.name = Some(sym);
834                        }
835                        None => {
836                            return Err((
837                                attr_name.span,
838                                format!(
839                                    "#[spirv({name}(..))] unknown attribute argument {attr_name}"
840                                ),
841                            ));
842                        }
843                    }
844                } else {
845                    return Err((
846                        attr_name.span,
847                        format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
848                    ));
849                }
850            } else {
851                return Err((
852                    arg.span(),
853                    format!("#[spirv({name}(..))] attribute argument must be single identifier"),
854                ));
855            }
856        }
857    }
858    match entry.execution_model {
859        Fragment => {
860            let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
861            entry
862                .execution_modes
863                .push((origin_mode, ExecutionModeExtra::new([])));
864        }
865        GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
866            if let Some(local_size) = local_size {
867                entry
868                    .execution_modes
869                    .push((LocalSize, ExecutionModeExtra::new(local_size)));
870            } else {
871                return Err((
872                    arg.span(),
873                    String::from(
874                        "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
875                    ),
876                ));
877            }
878        }
879        //TODO: Cover more defaults
880        _ => {}
881    }
882    Ok(entry)
883}