rustc_codegen_spirv/
attr.rs

1//! `#[spirv(...)]` attribute support.
2//!
3//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.
4
5use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_hir as hir;
9use rustc_hir::def_id::LocalModDefId;
10use rustc_hir::intravisit::{self, Visitor};
11use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
12use rustc_middle::hir::nested_filter;
13use rustc_middle::query::Providers;
14use rustc_middle::ty::TyCtxt;
15use rustc_span::{Span, Symbol};
16use std::rc::Rc;
17
// FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
/// Fixed-capacity inline storage for up to three `u32` operands that
/// accompany an `ExecutionMode` (see `Entry::execution_modes`).
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    /// Backing storage; only the first `len` entries are meaningful.
    args: [u32; 3],
    /// Number of valid entries at the start of `args`.
    len: u8,
}

impl ExecutionModeExtra {
    /// Copies the operands in `args` into a new `ExecutionModeExtra`.
    ///
    /// # Panics
    ///
    /// Panics if `args` contains more operands than the inline storage can
    /// hold (three).
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let operands = args.as_ref();
        let mut storage = [0u32; 3];
        // Make the capacity invariant explicit, instead of relying on the
        // (less descriptive) slice-index panic below.
        assert!(
            operands.len() <= storage.len(),
            "`ExecutionModeExtra` holds at most {} operands, got {}",
            storage.len(),
            operands.len(),
        );
        storage[..operands.len()].copy_from_slice(operands);
        let len = u8::try_from(operands.len()).expect("length was checked against the capacity");
        Self { args: storage, len }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were actually stored.
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
40
/// Data carried by `SpirvAttribute::Entry` (a `fn` attribute).
#[derive(Clone, Debug)]
pub struct Entry {
    /// Execution model of the entry point.
    pub execution_model: ExecutionModel,
    /// Execution modes, each paired with its extra operands.
    pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
    /// Optional name; presumably an explicit entry-point name overriding
    /// the function's own name (TODO confirm against the attribute parser).
    pub name: Option<Symbol>,
}
47
48impl From<ExecutionModel> for Entry {
49    fn from(execution_model: ExecutionModel) -> Self {
50        Self {
51            execution_model,
52            execution_modes: Vec::new(),
53            name: None,
54        }
55    }
56}
57
/// `struct` types that are used to represent special SPIR-V types.
///
/// Carried by `SpirvAttribute::IntrinsicType`, which the attribute checker
/// only accepts on a `struct`.
#[derive(Debug, Clone)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
}
70
/// Payload of `SpirvAttribute::SpecConstant` (an entry-point parameter attribute).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// Identifier of the constant; presumably the SPIR-V `SpecId` (TODO confirm).
    pub id: u32,
    /// Optional default value.
    pub default: Option<u32>,
}
76
// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
// `tests/ui/spirv-attr` should be updated (and new ones added if necessary).
/// A single parsed `#[spirv(...)]` attribute.
///
/// Variants are grouped by the kind of item they are expected on; each
/// variant has a matching slot in `AggregatedSpirvAttributes`.
#[derive(Debug, Clone)]
pub enum SpirvAttribute {
    // `struct` attributes:
    IntrinsicType(IntrinsicType),
    Block,

    // `fn` attributes:
    Entry(Entry),

    // (entry) `fn` parameter attributes:
    StorageClass(StorageClass),
    Builtin(BuiltIn),
    DescriptorSet(u32),
    Binding(u32),
    Flat,
    PerPrimitiveExt,
    Invariant,
    InputAttachmentIndex(u32),
    SpecConstant(SpecConstant),

    // `fn`/closure attributes:
    BufferLoadIntrinsic,
    BufferStoreIntrinsic,
}
103
// HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the
// field name instead of `node` (which feels inadequate in this context).
/// A value paired with the `Span` it is associated with.
#[derive(Copy, Clone)]
pub struct Spanned<T> {
    /// The wrapped value.
    pub value: T,
    /// Span associated with `value`; for aggregated attributes this is the
    /// span handed to `try_insert_attr`.
    pub span: Span,
}
111
/// Condensed version of a `SpirvAttribute` list, but only keeping one value per
/// variant of `SpirvAttribute`, and treating multiple such attributes as an error.
///
/// Every field starts out as `None` (via `Default`) and is filled in at most
/// once by `try_insert_attr`.
// FIXME(eddyb) should this and `fn try_insert_attr` below be generated by a macro?
#[derive(Default)]
pub struct AggregatedSpirvAttributes {
    // `struct` attributes:
    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
    pub block: Option<Spanned<()>>,

    // `fn` attributes:
    pub entry: Option<Spanned<Entry>>,

    // (entry) `fn` parameter attributes:
    pub storage_class: Option<Spanned<StorageClass>>,
    pub builtin: Option<Spanned<BuiltIn>>,
    pub descriptor_set: Option<Spanned<u32>>,
    pub binding: Option<Spanned<u32>>,
    pub flat: Option<Spanned<()>>,
    pub invariant: Option<Spanned<()>>,
    pub per_primitive_ext: Option<Spanned<()>>,
    pub input_attachment_index: Option<Spanned<u32>>,
    pub spec_constant: Option<Spanned<SpecConstant>>,

    // `fn`/closure attributes:
    pub buffer_load_intrinsic: Option<Spanned<()>>,
    pub buffer_store_intrinsic: Option<Spanned<()>>,
}
139
/// Error returned by `AggregatedSpirvAttributes::try_insert_attr` when an
/// attribute of the same category has already been recorded.
struct MultipleAttrs {
    /// Span of the attribute that was recorded first.
    prev_span: Span,
    /// Human-readable category name, interpolated into diagnostics.
    category: &'static str,
}
144
145impl AggregatedSpirvAttributes {
146    /// Compute `AggregatedSpirvAttributes` for use during codegen.
147    ///
148    /// Any errors for malformed/duplicate attributes will have been reported
149    /// prior to codegen, by the `attr` check pass.
150    pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
151        let mut aggregated_attrs = Self::default();
152
153        // NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
154        // to see an attribute error, it will cause an ICE instead.
155        for parse_attr_result in crate::symbols::parse_attrs_for_checking(&cx.sym, attrs) {
156            let (span, parsed_attr) = match parse_attr_result {
157                Ok(span_and_parsed_attr) => span_and_parsed_attr,
158                Err((span, msg)) => {
159                    cx.tcx.dcx().span_delayed_bug(span, msg);
160                    continue;
161                }
162            };
163            match aggregated_attrs.try_insert_attr(parsed_attr, span) {
164                Ok(()) => {}
165                Err(MultipleAttrs {
166                    prev_span: _,
167                    category,
168                }) => {
169                    cx.tcx
170                        .dcx()
171                        .span_delayed_bug(span, format!("multiple {category} attributes"));
172                }
173            }
174        }
175
176        aggregated_attrs
177    }
178
179    fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
180        fn try_insert<T>(
181            slot: &mut Option<Spanned<T>>,
182            value: T,
183            span: Span,
184            category: &'static str,
185        ) -> Result<(), MultipleAttrs> {
186            if let Some(prev) = slot {
187                Err(MultipleAttrs {
188                    prev_span: prev.span,
189                    category,
190                })
191            } else {
192                *slot = Some(Spanned { value, span });
193                Ok(())
194            }
195        }
196
197        use SpirvAttribute::*;
198        match attr {
199            IntrinsicType(value) => {
200                try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
201            }
202            Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
203            Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
204            StorageClass(value) => {
205                try_insert(&mut self.storage_class, value, span, "storage class")
206            }
207            Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
208            DescriptorSet(value) => try_insert(
209                &mut self.descriptor_set,
210                value,
211                span,
212                "#[spirv(descriptor_set)]",
213            ),
214            Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
215            Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
216            Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
217            PerPrimitiveExt => try_insert(
218                &mut self.per_primitive_ext,
219                (),
220                span,
221                "#[spirv(per_primitive_ext)]",
222            ),
223            InputAttachmentIndex(value) => try_insert(
224                &mut self.input_attachment_index,
225                value,
226                span,
227                "#[spirv(attachment_index)]",
228            ),
229            SpecConstant(value) => try_insert(
230                &mut self.spec_constant,
231                value,
232                span,
233                "#[spirv(spec_constant)]",
234            ),
235            BufferLoadIntrinsic => try_insert(
236                &mut self.buffer_load_intrinsic,
237                (),
238                span,
239                "#[spirv(buffer_load_intrinsic)]",
240            ),
241            BufferStoreIntrinsic => try_insert(
242                &mut self.buffer_store_intrinsic,
243                (),
244                span,
245                "#[spirv(buffer_store_intrinsic)]",
246            ),
247        }
248    }
249}
250
251// FIXME(eddyb) make this reusable from somewhere in `rustc`.
252fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
253    match impl_item.kind {
254        hir::ImplItemKind::Const(..) => Target::AssocConst,
255        hir::ImplItemKind::Fn(..) => {
256            let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
257            let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
258            let containing_impl_is_for_trait = match &containing_item.kind {
259                hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
260                _ => unreachable!("parent of an ImplItem must be an Impl"),
261            };
262            if containing_impl_is_for_trait {
263                Target::Method(MethodKind::Trait { body: true })
264            } else {
265                Target::Method(MethodKind::Inherent)
266            }
267        }
268        hir::ImplItemKind::Type(..) => Target::AssocTy,
269    }
270}
271
/// HIR visitor that validates `#[spirv(...)]` attributes on everything it visits.
struct CheckSpirvAttrVisitor<'tcx> {
    /// Type context, used for HIR lookups and for emitting diagnostics.
    tcx: TyCtxt<'tcx>,
    /// Symbol table handed to `parse_attrs_for_checking`.
    sym: Rc<Symbols>,
}
276
impl CheckSpirvAttrVisitor<'_> {
    /// Validates every `#[spirv(...)]` attribute attached to `hir_id`, which
    /// is an item of kind `target`.
    ///
    /// For each attribute this reports, in order:
    /// - parse errors (malformed attributes),
    /// - attributes placed on the wrong kind of target,
    /// - duplicate attributes of the same category.
    ///
    /// Parameter attributes are additionally required to sit on a parameter
    /// of an entry-point function, and some storage classes are rejected
    /// there. Finally, a warning is emitted for the obsolete `#[spirv(block)]`.
    fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
        let mut aggregated_attrs = AggregatedSpirvAttributes::default();

        // Shared by the main loop and by the parent (entry-point) lookup below.
        let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);

        let attrs = self.tcx.hir_attrs(hir_id);
        for parse_attr_result in parse_attrs(attrs) {
            let (span, parsed_attr) = match parse_attr_result {
                Ok(span_and_parsed_attr) => span_and_parsed_attr,
                Err((span, msg)) => {
                    // Malformed attribute: report it and move on to the next one.
                    self.tcx.dcx().span_err(span, msg);
                    continue;
                }
            };

            /// Error newtype marker used below for readability.
            struct Expected<T>(T);

            // `Ok(())` if the attribute may appear on `target`, otherwise
            // `Err(Expected(..))` naming the kind of target it requires.
            let valid_target = match parsed_attr {
                SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
                    Target::Struct => {
                        // FIXME(eddyb) further check type attribute validity,
                        // e.g. layout, generics, other attributes, etc.
                        Ok(())
                    }

                    _ => Err(Expected("struct")),
                },

                SpirvAttribute::Entry(_) => match target {
                    Target::Fn
                    | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
                        // FIXME(eddyb) further check entry-point attribute validity,
                        // e.g. signature, shouldn't have `#[inline]` or generics, etc.
                        Ok(())
                    }

                    _ => Err(Expected("function")),
                },

                SpirvAttribute::StorageClass(_)
                | SpirvAttribute::Builtin(_)
                | SpirvAttribute::DescriptorSet(_)
                | SpirvAttribute::Binding(_)
                | SpirvAttribute::Flat
                | SpirvAttribute::Invariant
                | SpirvAttribute::PerPrimitiveExt
                | SpirvAttribute::InputAttachmentIndex(_)
                | SpirvAttribute::SpecConstant(_) => match target {
                    Target::Param => {
                        // Re-parse the parent's attributes to find out whether it
                        // is an entry point; parse errors on the parent are
                        // ignored here (they are reported when it is checked itself).
                        let parent_hir_id = self.tcx.parent_hir_id(hir_id);
                        let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
                            .filter_map(|r| r.ok())
                            .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
                        if !parent_is_entry_point {
                            self.tcx.dcx().span_err(
                                span,
                                "attribute is only valid on a parameter of an entry-point function",
                            );
                        } else {
                            // FIXME(eddyb) should we just remove all 5 of these storage class
                            // attributes, instead of disallowing them here?
                            if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
                                let valid = match storage_class {
                                    StorageClass::Input | StorageClass::Output => {
                                        Err("is the default and should not be explicitly specified")
                                    }

                                    StorageClass::Private
                                    | StorageClass::Function
                                    | StorageClass::Generic => {
                                        Err("can not be used as part of an entry's interface")
                                    }

                                    _ => Ok(()),
                                };

                                if let Err(msg) = valid {
                                    self.tcx.dcx().span_err(
                                        span,
                                        format!("`{storage_class:?}` storage class {msg}"),
                                    );
                                }
                            }
                        }
                        // The target kind itself is fine even if an error was
                        // emitted above, so the attribute still takes part in
                        // the duplicate check below.
                        Ok(())
                    }

                    _ => Err(Expected("function parameter")),
                },
                SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
                    match target {
                        Target::Fn => Ok(()),
                        _ => Err(Expected("function")),
                    }
                }
            };
            match valid_target {
                Err(Expected(expected_target)) => {
                    self.tcx.dcx().span_err(
                        span,
                        format!(
                            "attribute is only valid on a {expected_target}, not on a {target}"
                        ),
                    );
                }
                // Only attributes valid for this target are aggregated, so that
                // duplicates are reported among those alone.
                Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
                    Ok(()) => {}
                    Err(MultipleAttrs {
                        prev_span,
                        category,
                    }) => {
                        self.tcx
                            .dcx()
                            .struct_span_err(
                                span,
                                format!("only one {category} attribute is allowed on a {target}"),
                            )
                            .with_span_note(prev_span, format!("previous {category} attribute"))
                            .emit();
                    }
                },
            }
        }

        // At this point we have all of the attributes (valid for this target),
        // so we can perform further checks, emit warnings, etc.

        if let Some(block_attr) = aggregated_attrs.block {
            self.tcx.dcx().span_warn(
                block_attr.span,
                "#[spirv(block)] is no longer needed and should be removed",
            );
        }
    }
}
414
415// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
416impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
417    type NestedFilter = nested_filter::OnlyBodies;
418
419    fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
420        self.tcx
421    }
422
423    fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
424        let target = Target::from_item(item);
425        self.check_spirv_attributes(item.hir_id(), target);
426        intravisit::walk_item(self, item);
427    }
428
429    fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
430        let target = Target::from_generic_param(generic_param);
431        self.check_spirv_attributes(generic_param.hir_id, target);
432        intravisit::walk_generic_param(self, generic_param);
433    }
434
435    fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
436        let target = Target::from_trait_item(trait_item);
437        self.check_spirv_attributes(trait_item.hir_id(), target);
438        intravisit::walk_trait_item(self, trait_item);
439    }
440
441    fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
442        self.check_spirv_attributes(field.hir_id, Target::Field);
443        intravisit::walk_field_def(self, field);
444    }
445
446    fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
447        self.check_spirv_attributes(arm.hir_id, Target::Arm);
448        intravisit::walk_arm(self, arm);
449    }
450
451    fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
452        let target = Target::from_foreign_item(f_item);
453        self.check_spirv_attributes(f_item.hir_id(), target);
454        intravisit::walk_foreign_item(self, f_item);
455    }
456
457    fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
458        let target = target_from_impl_item(self.tcx, impl_item);
459        self.check_spirv_attributes(impl_item.hir_id(), target);
460        intravisit::walk_impl_item(self, impl_item);
461    }
462
463    fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
464        // When checking statements ignore expressions, they will be checked later.
465        if let hir::StmtKind::Let(l) = stmt.kind {
466            self.check_spirv_attributes(l.hir_id, Target::Statement);
467        }
468        intravisit::walk_stmt(self, stmt);
469    }
470
471    fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
472        let target = match expr.kind {
473            hir::ExprKind::Closure { .. } => Target::Closure,
474            _ => Target::Expression,
475        };
476
477        self.check_spirv_attributes(expr.hir_id, target);
478        intravisit::walk_expr(self, expr);
479    }
480
481    fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
482        self.check_spirv_attributes(variant.hir_id, Target::Variant);
483        intravisit::walk_variant(self, variant);
484    }
485
486    fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
487        self.check_spirv_attributes(param.hir_id, Target::Param);
488
489        intravisit::walk_param(self, param);
490    }
491}
492
493// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
494fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
495    let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
496        tcx,
497        sym: Symbols::get(),
498    };
499    tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
500    if module_def_id.is_top_level_module() {
501        check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
502    }
503}
504
505pub(crate) fn provide(providers: &mut Providers) {
506    *providers = Providers {
507        check_mod_attrs: |tcx, module_def_id| {
508            // Run both the default checks, and our `#[spirv(...)]` ones.
509            (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, module_def_id);
510            check_mod_attrs(tcx, module_def_id);
511        },
512        ..*providers
513    };
514}