rustc_codegen_spirv/
attr.rs

1//! `#[spirv(...)]` attribute support.
2//!
3//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.
4
5use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
9use rustc_hir as hir;
10use rustc_hir::def_id::LocalModDefId;
11use rustc_hir::intravisit::{self, Visitor};
12use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
13use rustc_middle::hir::nested_filter;
14use rustc_middle::query::Providers;
15use rustc_middle::ty::TyCtxt;
16use rustc_span::{Ident, Span, Symbol};
17use smallvec::SmallVec;
18use std::rc::Rc;
19
// FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
/// Inline storage for the (at most three) extra `u32` operands that accompany
/// an execution mode.
///
/// Only the first `len` entries of `args` are meaningful; the rest are zero
/// padding and are never exposed through [`AsRef`].
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    args: [u32; 3],
    len: u8,
}

impl ExecutionModeExtra {
    /// Builds the operand list from `args`.
    ///
    /// # Panics
    ///
    /// Panics if `args` holds more than three words, since the inline storage
    /// cannot represent them.
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let words = args.as_ref();
        // State the invariant explicitly instead of relying on an opaque
        // out-of-range slice panic further down.
        assert!(
            words.len() <= 3,
            "execution modes take at most 3 extra operands, got {}",
            words.len()
        );
        let mut padded = [0; 3];
        padded[..words.len()].copy_from_slice(words);
        Self {
            args: padded,
            // Cannot truncate: `words.len() <= 3` was asserted above.
            len: words.len() as u8,
        }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were actually provided to `new`.
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
42
43#[derive(Clone, Debug)]
44pub struct Entry {
45    pub execution_model: ExecutionModel,
46    pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
47    pub name: Option<Symbol>,
48}
49
50impl From<ExecutionModel> for Entry {
51    fn from(execution_model: ExecutionModel) -> Self {
52        Self {
53            execution_model,
54            execution_modes: Vec::new(),
55            name: None,
56        }
57    }
58}
59
/// `struct` types that are used to represent special SPIR-V types.
///
/// Carried by [`SpirvAttribute::IntrinsicType`], which is only accepted on
/// `struct` items by the attribute checker.
#[derive(Clone, Debug)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
}
72
/// Payload of `#[spirv(spec_constant(id = ..., default = ...))]`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SpecConstant {
    /// Value of the mandatory `id = ...` argument.
    pub id: u32,
    /// Value of the optional `default = ...` argument, if it was given.
    pub default: Option<u32>,
}
78
79// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
80// `tests/ui/spirv-attr` should be updated (and new ones added if necessary).
81#[derive(Debug, Clone)]
82pub enum SpirvAttribute {
83    // `struct` attributes:
84    IntrinsicType(IntrinsicType),
85    Block,
86
87    // `fn` attributes:
88    Entry(Entry),
89
90    // (entry) `fn` parameter attributes:
91    StorageClass(StorageClass),
92    Builtin(BuiltIn),
93    DescriptorSet(u32),
94    Binding(u32),
95    Flat,
96    PerPrimitiveExt,
97    Invariant,
98    InputAttachmentIndex(u32),
99    SpecConstant(SpecConstant),
100
101    // `fn`/closure attributes:
102    BufferLoadIntrinsic,
103    BufferStoreIntrinsic,
104}
105
106// HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the
107// field name instead of `node` (which feels inadequate in this context).
108#[derive(Copy, Clone)]
109pub struct Spanned<T> {
110    pub value: T,
111    pub span: Span,
112}
113
114/// Condensed version of a `SpirvAttribute` list, but only keeping one value per
115/// variant of `SpirvAttribute`, and treating multiple such attributes an error.
116// FIXME(eddyb) should this and `fn try_insert_attr` below be generated by a macro?
117#[derive(Default)]
118pub struct AggregatedSpirvAttributes {
119    // `struct` attributes:
120    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
121    pub block: Option<Spanned<()>>,
122
123    // `fn` attributes:
124    pub entry: Option<Spanned<Entry>>,
125
126    // (entry) `fn` parameter attributes:
127    pub storage_class: Option<Spanned<StorageClass>>,
128    pub builtin: Option<Spanned<BuiltIn>>,
129    pub descriptor_set: Option<Spanned<u32>>,
130    pub binding: Option<Spanned<u32>>,
131    pub flat: Option<Spanned<()>>,
132    pub invariant: Option<Spanned<()>>,
133    pub per_primitive_ext: Option<Spanned<()>>,
134    pub input_attachment_index: Option<Spanned<u32>>,
135    pub spec_constant: Option<Spanned<SpecConstant>>,
136
137    // `fn`/closure attributes:
138    pub buffer_load_intrinsic: Option<Spanned<()>>,
139    pub buffer_store_intrinsic: Option<Spanned<()>>,
140}
141
142struct MultipleAttrs {
143    prev_span: Span,
144    category: &'static str,
145}
146
147impl AggregatedSpirvAttributes {
148    /// Compute `AggregatedSpirvAttributes` for use during codegen.
149    ///
150    /// Any errors for malformed/duplicate attributes will have been reported
151    /// prior to codegen, by the `attr` check pass.
152    pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
153        let mut aggregated_attrs = Self::default();
154
155        // NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
156        // to see an attribute error, it will cause an ICE instead.
157        for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
158            let (span, parsed_attr) = match parse_attr_result {
159                Ok(span_and_parsed_attr) => span_and_parsed_attr,
160                Err((span, msg)) => {
161                    cx.tcx.dcx().span_delayed_bug(span, msg);
162                    continue;
163                }
164            };
165            match aggregated_attrs.try_insert_attr(parsed_attr, span) {
166                Ok(()) => {}
167                Err(MultipleAttrs {
168                    prev_span: _,
169                    category,
170                }) => {
171                    cx.tcx
172                        .dcx()
173                        .span_delayed_bug(span, format!("multiple {category} attributes"));
174                }
175            }
176        }
177
178        aggregated_attrs
179    }
180
181    fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
182        fn try_insert<T>(
183            slot: &mut Option<Spanned<T>>,
184            value: T,
185            span: Span,
186            category: &'static str,
187        ) -> Result<(), MultipleAttrs> {
188            if let Some(prev) = slot {
189                Err(MultipleAttrs {
190                    prev_span: prev.span,
191                    category,
192                })
193            } else {
194                *slot = Some(Spanned { value, span });
195                Ok(())
196            }
197        }
198
199        use SpirvAttribute::*;
200        match attr {
201            IntrinsicType(value) => {
202                try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
203            }
204            Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
205            Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
206            StorageClass(value) => {
207                try_insert(&mut self.storage_class, value, span, "storage class")
208            }
209            Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
210            DescriptorSet(value) => try_insert(
211                &mut self.descriptor_set,
212                value,
213                span,
214                "#[spirv(descriptor_set)]",
215            ),
216            Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
217            Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
218            Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
219            PerPrimitiveExt => try_insert(
220                &mut self.per_primitive_ext,
221                (),
222                span,
223                "#[spirv(per_primitive_ext)]",
224            ),
225            InputAttachmentIndex(value) => try_insert(
226                &mut self.input_attachment_index,
227                value,
228                span,
229                "#[spirv(attachment_index)]",
230            ),
231            SpecConstant(value) => try_insert(
232                &mut self.spec_constant,
233                value,
234                span,
235                "#[spirv(spec_constant)]",
236            ),
237            BufferLoadIntrinsic => try_insert(
238                &mut self.buffer_load_intrinsic,
239                (),
240                span,
241                "#[spirv(buffer_load_intrinsic)]",
242            ),
243            BufferStoreIntrinsic => try_insert(
244                &mut self.buffer_store_intrinsic,
245                (),
246                span,
247                "#[spirv(buffer_store_intrinsic)]",
248            ),
249        }
250    }
251}
252
253// FIXME(eddyb) make this reusable from somewhere in `rustc`.
254fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
255    match impl_item.kind {
256        hir::ImplItemKind::Const(..) => Target::AssocConst,
257        hir::ImplItemKind::Fn(..) => {
258            let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
259            let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
260            let containing_impl_is_for_trait = match &containing_item.kind {
261                hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
262                _ => unreachable!("parent of an ImplItem must be an Impl"),
263            };
264            if containing_impl_is_for_trait {
265                Target::Method(MethodKind::Trait { body: true })
266            } else {
267                Target::Method(MethodKind::Inherent)
268            }
269        }
270        hir::ImplItemKind::Type(..) => Target::AssocTy,
271    }
272}
273
274struct CheckSpirvAttrVisitor<'tcx> {
275    tcx: TyCtxt<'tcx>,
276    sym: Rc<Symbols>,
277}
278
279impl CheckSpirvAttrVisitor<'_> {
280    fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
281        let mut aggregated_attrs = AggregatedSpirvAttributes::default();
282
283        let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);
284
285        let attrs = self.tcx.hir_attrs(hir_id);
286        for parse_attr_result in parse_attrs(attrs) {
287            let (span, parsed_attr) = match parse_attr_result {
288                Ok(span_and_parsed_attr) => span_and_parsed_attr,
289                Err((span, msg)) => {
290                    self.tcx.dcx().span_err(span, msg);
291                    continue;
292                }
293            };
294
295            /// Error newtype marker used below for readability.
296            struct Expected<T>(T);
297
298            let valid_target = match parsed_attr {
299                SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
300                    Target::Struct => {
301                        // FIXME(eddyb) further check type attribute validity,
302                        // e.g. layout, generics, other attributes, etc.
303                        Ok(())
304                    }
305
306                    _ => Err(Expected("struct")),
307                },
308
309                SpirvAttribute::Entry(_) => match target {
310                    Target::Fn
311                    | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
312                        // FIXME(eddyb) further check entry-point attribute validity,
313                        // e.g. signature, shouldn't have `#[inline]` or generics, etc.
314                        Ok(())
315                    }
316
317                    _ => Err(Expected("function")),
318                },
319
320                SpirvAttribute::StorageClass(_)
321                | SpirvAttribute::Builtin(_)
322                | SpirvAttribute::DescriptorSet(_)
323                | SpirvAttribute::Binding(_)
324                | SpirvAttribute::Flat
325                | SpirvAttribute::Invariant
326                | SpirvAttribute::PerPrimitiveExt
327                | SpirvAttribute::InputAttachmentIndex(_)
328                | SpirvAttribute::SpecConstant(_) => match target {
329                    Target::Param => {
330                        let parent_hir_id = self.tcx.parent_hir_id(hir_id);
331                        let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
332                            .filter_map(|r| r.ok())
333                            .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
334                        if !parent_is_entry_point {
335                            self.tcx.dcx().span_err(
336                                span,
337                                "attribute is only valid on a parameter of an entry-point function",
338                            );
339                        } else {
340                            // FIXME(eddyb) should we just remove all 5 of these storage class
341                            // attributes, instead of disallowing them here?
342                            if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
343                                let valid = match storage_class {
344                                    StorageClass::Input | StorageClass::Output => {
345                                        Err("is the default and should not be explicitly specified")
346                                    }
347
348                                    StorageClass::Private
349                                    | StorageClass::Function
350                                    | StorageClass::Generic => {
351                                        Err("can not be used as part of an entry's interface")
352                                    }
353
354                                    _ => Ok(()),
355                                };
356
357                                if let Err(msg) = valid {
358                                    self.tcx.dcx().span_err(
359                                        span,
360                                        format!("`{storage_class:?}` storage class {msg}"),
361                                    );
362                                }
363                            }
364                        }
365                        Ok(())
366                    }
367
368                    _ => Err(Expected("function parameter")),
369                },
370                SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
371                    match target {
372                        Target::Fn => Ok(()),
373                        _ => Err(Expected("function")),
374                    }
375                }
376            };
377            match valid_target {
378                Err(Expected(expected_target)) => {
379                    self.tcx.dcx().span_err(
380                        span,
381                        format!(
382                            "attribute is only valid on a {expected_target}, not on a {target}"
383                        ),
384                    );
385                }
386                Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
387                    Ok(()) => {}
388                    Err(MultipleAttrs {
389                        prev_span,
390                        category,
391                    }) => {
392                        self.tcx
393                            .dcx()
394                            .struct_span_err(
395                                span,
396                                format!("only one {category} attribute is allowed on a {target}"),
397                            )
398                            .with_span_note(prev_span, format!("previous {category} attribute"))
399                            .emit();
400                    }
401                },
402            }
403        }
404
405        // At this point we have all of the attributes (valid for this target),
406        // so we can perform further checks, emit warnings, etc.
407
408        if let Some(block_attr) = aggregated_attrs.block {
409            self.tcx.dcx().span_warn(
410                block_attr.span,
411                "#[spirv(block)] is no longer needed and should be removed",
412            );
413        }
414    }
415}
416
417// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
418impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
419    type NestedFilter = nested_filter::OnlyBodies;
420
421    fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
422        self.tcx
423    }
424
425    fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
426        let target = Target::from_item(item);
427        self.check_spirv_attributes(item.hir_id(), target);
428        intravisit::walk_item(self, item);
429    }
430
431    fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
432        let target = Target::from_generic_param(generic_param);
433        self.check_spirv_attributes(generic_param.hir_id, target);
434        intravisit::walk_generic_param(self, generic_param);
435    }
436
437    fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
438        let target = Target::from_trait_item(trait_item);
439        self.check_spirv_attributes(trait_item.hir_id(), target);
440        intravisit::walk_trait_item(self, trait_item);
441    }
442
443    fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
444        self.check_spirv_attributes(field.hir_id, Target::Field);
445        intravisit::walk_field_def(self, field);
446    }
447
448    fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
449        self.check_spirv_attributes(arm.hir_id, Target::Arm);
450        intravisit::walk_arm(self, arm);
451    }
452
453    fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
454        let target = Target::from_foreign_item(f_item);
455        self.check_spirv_attributes(f_item.hir_id(), target);
456        intravisit::walk_foreign_item(self, f_item);
457    }
458
459    fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
460        let target = target_from_impl_item(self.tcx, impl_item);
461        self.check_spirv_attributes(impl_item.hir_id(), target);
462        intravisit::walk_impl_item(self, impl_item);
463    }
464
465    fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
466        // When checking statements ignore expressions, they will be checked later.
467        if let hir::StmtKind::Let(l) = stmt.kind {
468            self.check_spirv_attributes(l.hir_id, Target::Statement);
469        }
470        intravisit::walk_stmt(self, stmt);
471    }
472
473    fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
474        let target = match expr.kind {
475            hir::ExprKind::Closure { .. } => Target::Closure,
476            _ => Target::Expression,
477        };
478
479        self.check_spirv_attributes(expr.hir_id, target);
480        intravisit::walk_expr(self, expr);
481    }
482
483    fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
484        self.check_spirv_attributes(variant.hir_id, Target::Variant);
485        intravisit::walk_variant(self, variant);
486    }
487
488    fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
489        self.check_spirv_attributes(param.hir_id, Target::Param);
490
491        intravisit::walk_param(self, param);
492    }
493}
494
495// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
496fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
497    let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
498        tcx,
499        sym: Symbols::get(),
500    };
501    tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
502    if module_def_id.is_top_level_module() {
503        check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
504    }
505}
506
507pub(crate) fn provide(providers: &mut Providers) {
508    *providers = Providers {
509        check_mod_attrs: |tcx, module_def_id| {
510            // Run both the default checks, and our `#[spirv(...)]` ones.
511            (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, module_def_id);
512            check_mod_attrs(tcx, module_def_id);
513        },
514        ..*providers
515    };
516}
517
518// FIXME(eddyb) find something nicer for the error type.
519type ParseAttrError = (Span, String);
520
521#[allow(clippy::get_first)]
522fn parse_attrs_for_checking<'a>(
523    sym: &'a Symbols,
524    attrs: &'a [Attribute],
525) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
526    attrs
527        .iter()
528        .map(move |attr| {
529            // parse the #[rust_gpu::spirv(...)] attr and return the inner list
530            match attr {
531                Attribute::Unparsed(item) => {
532                    // #[...]
533                    let s = &item.path.segments;
534                    if let Some(rust_gpu) = s.get(0) && rust_gpu.name == sym.rust_gpu {
535                        // #[rust_gpu ...]
536                        match s.get(1) {
537                            Some(command) if command.name == sym.spirv_attr_with_version => {
538                                // #[rust_gpu::spirv ...]
539                                if let Some(args) = attr.meta_item_list() {
540                                    // #[rust_gpu::spirv(...)]
541                                    Ok(parse_spirv_attr(sym, args.iter()))
542                                } else {
543                                    // #[rust_gpu::spirv]
544                                    Err((
545                                        attr.span(),
546                                        "#[spirv(..)] attribute must have at least one argument"
547                                            .to_string(),
548                                    ))
549                                }
550                            }
551                            _ => {
552                                // #[rust_gpu::...] but not a know version
553                                let spirv = sym.spirv_attr_with_version.as_str();
554                                Err((
555                                    attr.span(),
556                                    format!("unknown `rust_gpu` attribute, expected `rust_gpu::{spirv}`. \
557                                Do the versions of `spirv-std` and `rustc_codegen_spirv` match?"),
558                                ))
559                            }
560                        }
561                    } else {
562                        // #[...] but not #[rust_gpu ...]
563                        Ok(Default::default())
564                    }
565                }
566                Attribute::Parsed(_) => Ok(Default::default()),
567            }
568        })
569        .flat_map(|result| {
570            result
571                .unwrap_or_else(|err| SmallVec::from_iter([Err(err)]))
572                .into_iter()
573        })
574}
575
576fn parse_spirv_attr<'a>(
577    sym: &Symbols,
578    iter: impl Iterator<Item = &'a MetaItemInner>,
579) -> SmallVec<[Result<(Span, SpirvAttribute), ParseAttrError>; 4]> {
580    iter.map(|arg| {
581        let span = arg.span();
582        let parsed_attr =
583            if arg.has_name(sym.descriptor_set) {
584                SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
585            } else if arg.has_name(sym.binding) {
586                SpirvAttribute::Binding(parse_attr_int_value(arg)?)
587            } else if arg.has_name(sym.input_attachment_index) {
588                SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
589            } else if arg.has_name(sym.spec_constant) {
590                SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
591            } else {
592                let name = match arg.ident() {
593                    Some(i) => i,
594                    None => {
595                        return Err((
596                            span,
597                            "#[spirv(..)] attribute argument must be single identifier".to_string(),
598                        ));
599                    }
600                };
601                sym.attributes.get(&name.name).map_or_else(
602                    || Err((name.span, "unknown argument to spirv attribute".to_string())),
603                    |a| {
604                        Ok(match a {
605                            SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
606                                parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
607                            ),
608                            _ => a.clone(),
609                        })
610                    },
611                )?
612            };
613        Ok((span, parsed_attr))
614    })
615    .collect()
616}
617
618fn parse_spec_constant_attr(
619    sym: &Symbols,
620    arg: &MetaItemInner,
621) -> Result<SpecConstant, ParseAttrError> {
622    let mut id = None;
623    let mut default = None;
624
625    if let Some(attrs) = arg.meta_item_list() {
626        for attr in attrs {
627            if attr.has_name(sym.id) {
628                if id.is_none() {
629                    id = Some(parse_attr_int_value(attr)?);
630                } else {
631                    return Err((attr.span(), "`id` may only be specified once".into()));
632                }
633            } else if attr.has_name(sym.default) {
634                if default.is_none() {
635                    default = Some(parse_attr_int_value(attr)?);
636                } else {
637                    return Err((attr.span(), "`default` may only be specified once".into()));
638                }
639            } else {
640                return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
641            }
642        }
643    }
644    Ok(SpecConstant {
645        id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
646        default,
647    })
648}
649
650fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
651    let arg = match arg.meta_item() {
652        Some(arg) => arg,
653        None => return Err((arg.span(), "attribute must have value".to_string())),
654    };
655    match arg.name_value_literal() {
656        Some(&MetaItemLit {
657            kind: LitKind::Int(x, ..),
658            ..
659        }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
660        _ => Err((arg.span, "attribute value must be integer".to_string())),
661    }
662}
663
664fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
665    let arg = match arg.meta_item() {
666        Some(arg) => arg,
667        None => return Err((arg.span(), "attribute must have value".to_string())),
668    };
669    match arg.meta_item_list() {
670        Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
671            let mut local_size = [1; 3];
672            for (idx, lit) in tuple.iter().enumerate() {
673                match lit {
674                    MetaItemInner::Lit(MetaItemLit {
675                                           kind: LitKind::Int(x, ..),
676                                           ..
677                                       }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
678                    _ => return Err((lit.span(), "must be a u32 literal".to_string())),
679                }
680            }
681            Ok(local_size)
682        }
683        Some([]) => Err((
684            arg.span,
685            "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
686        )),
687        Some(tuple) if tuple.len() > 3 => Err((
688            arg.span,
689            "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
690        )),
691        _ => Err((
692            arg.span,
693            "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
694        )),
695    }
696}
697
698// for a given entry, gather up the additional attributes
699// in this case ExecutionMode's, some have extra arguments
700// others are specified with x, y, or z components
701// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))]
702fn parse_entry_attrs(
703    sym: &Symbols,
704    arg: &MetaItemInner,
705    name: &Ident,
706    execution_model: ExecutionModel,
707) -> Result<Entry, ParseAttrError> {
708    use ExecutionMode::*;
709    use ExecutionModel::*;
710    let mut entry = Entry::from(execution_model);
711    let mut origin_mode: Option<ExecutionMode> = None;
712    let mut local_size: Option<[u32; 3]> = None;
713    let mut local_size_hint: Option<[u32; 3]> = None;
714    // Reserved
715    //let mut max_workgroup_size_intel: Option<[u32; 3]> = None;
716    if let Some(attrs) = arg.meta_item_list() {
717        for attr in attrs {
718            if let Some(attr_name) = attr.ident() {
719                if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
720                {
721                    use crate::symbols::ExecutionModeExtraDim::*;
722                    let val = match extra_dim {
723                        None | Tuple => Option::None,
724                        _ => Some(parse_attr_int_value(attr)?),
725                    };
726                    match execution_mode {
727                        OriginUpperLeft | OriginLowerLeft => {
728                            origin_mode.replace(*execution_mode);
729                        }
730                        LocalSize => {
731                            if local_size.is_none() {
732                                local_size.replace(parse_local_size_attr(attr)?);
733                            } else {
734                                return Err((
735                                    attr_name.span,
736                                    String::from(
737                                        "`#[spirv(compute(threads))]` may only be specified once",
738                                    ),
739                                ));
740                            }
741                        }
742                        LocalSizeHint => {
743                            let val = val.unwrap();
744                            if local_size_hint.is_none() {
745                                local_size_hint.replace([1, 1, 1]);
746                            }
747                            let local_size_hint = local_size_hint.as_mut().unwrap();
748                            match extra_dim {
749                                X => {
750                                    local_size_hint[0] = val;
751                                }
752                                Y => {
753                                    local_size_hint[1] = val;
754                                }
755                                Z => {
756                                    local_size_hint[2] = val;
757                                }
758                                _ => unreachable!(),
759                            }
760                        }
761                        // Reserved
762                        /*MaxWorkgroupSizeINTEL => {
763                            let val = val.unwrap();
764                            if max_workgroup_size_intel.is_none() {
765                                max_workgroup_size_intel.replace([1, 1, 1]);
766                            }
767                            let max_workgroup_size_intel = max_workgroup_size_intel.as_mut()
768                                .unwrap();
769                            match extra_dim {
770                                X => {
771                                    max_workgroup_size_intel[0] = val;
772                                },
773                                Y => {
774                                    max_workgroup_size_intel[1] = val;
775                                },
776                                Z => {
777                                    max_workgroup_size_intel[2] = val;
778                                },
779                                _ => unreachable!(),
780                            }
781                        },*/
782                        _ => {
783                            if let Some(val) = val {
784                                entry
785                                    .execution_modes
786                                    .push((*execution_mode, ExecutionModeExtra::new([val])));
787                            } else {
788                                entry
789                                    .execution_modes
790                                    .push((*execution_mode, ExecutionModeExtra::new([])));
791                            }
792                        }
793                    }
794                } else if attr_name.name == sym.entry_point_name {
795                    match attr.value_str() {
796                        Some(sym) => {
797                            entry.name = Some(sym);
798                        }
799                        None => {
800                            return Err((
801                                attr_name.span,
802                                format!(
803                                    "#[spirv({name}(..))] unknown attribute argument {attr_name}"
804                                ),
805                            ));
806                        }
807                    }
808                } else {
809                    return Err((
810                        attr_name.span,
811                        format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
812                    ));
813                }
814            } else {
815                return Err((
816                    arg.span(),
817                    format!("#[spirv({name}(..))] attribute argument must be single identifier"),
818                ));
819            }
820        }
821    }
822    match entry.execution_model {
823        Fragment => {
824            let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
825            entry
826                .execution_modes
827                .push((origin_mode, ExecutionModeExtra::new([])));
828        }
829        GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
830            if let Some(local_size) = local_size {
831                entry
832                    .execution_modes
833                    .push((LocalSize, ExecutionModeExtra::new(local_size)));
834            } else {
835                return Err((
836                    arg.span(),
837                    String::from(
838                        "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
839                    ),
840                ));
841            }
842        }
843        //TODO: Cover more defaults
844        _ => {}
845    }
846    Ok(entry)
847}