1use crate::codegen_cx::CodegenCx;
6use crate::symbols::Symbols;
7use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8use rustc_hir as hir;
9use rustc_hir::def_id::LocalModDefId;
10use rustc_hir::intravisit::{self, Visitor};
11use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
12use rustc_middle::hir::nested_filter;
13use rustc_middle::query::Providers;
14use rustc_middle::ty::TyCtxt;
15use rustc_span::{Span, Symbol};
16use std::rc::Rc;
17
/// Small inline list of up to three `u32` operands.
///
/// `args` is fixed-size backing storage and `len` records how many leading
/// entries are actually in use; the remaining entries stay zeroed.
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    args: [u32; 3],
    len: u8,
}

impl ExecutionModeExtra {
    /// Builds the operand list from any slice-like source of `u32`s.
    ///
    /// # Panics
    ///
    /// Panics if more than three operands are supplied, since they would not
    /// fit in the inline storage.
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let operands = args.as_ref();
        let count = operands.len();

        let mut storage = [0u32; 3];
        // Slicing panics when `count > 3`, which enforces the capacity limit.
        storage[..count].copy_from_slice(operands);

        Self {
            args: storage,
            // Cannot truncate: the slice above already guaranteed `count <= 3`.
            len: count as u8,
        }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were supplied, without the zero padding.
    fn as_ref(&self) -> &[u32] {
        let used = usize::from(self.len);
        &self.args[..used]
    }
}
40
/// Payload of the entry-point attribute ([`SpirvAttribute::Entry`]).
#[derive(Clone, Debug)]
pub struct Entry {
    /// Execution model the entry point was declared with.
    pub execution_model: ExecutionModel,
    /// Execution modes attached to the entry point, each paired with its
    /// extra operands (at most three, see [`ExecutionModeExtra`]).
    pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
    /// Explicit name, if one was given.
    // NOTE(review): presumably the function's own name is used when this is
    // `None` — confirm against the code that consumes `Entry`.
    pub name: Option<Symbol>,
}
47
48impl From<ExecutionModel> for Entry {
49 fn from(execution_model: ExecutionModel) -> Self {
50 Self {
51 execution_model,
52 execution_modes: Vec::new(),
53 name: None,
54 }
55 }
56}
57
/// Kinds of intrinsic type that can be requested through a `#[spirv(...)]`
/// attribute (carried by [`SpirvAttribute::IntrinsicType`]).
///
/// The attribute checker in this file only accepts these on `struct` items.
// NOTE(review): the variant names suggest the SPIR-V types they stand for, but
// the mapping itself lives outside this file — confirm before relying on it.
#[derive(Debug, Clone)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
}
70
/// Payload of [`SpirvAttribute::SpecConstant`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// Numeric id of the specialization constant.
    // NOTE(review): presumably the SPIR-V `SpecId` decoration value — confirm.
    pub id: u32,
    /// Default value, if one was given in the attribute.
    pub default: Option<u32>,
}
76
/// A single parsed `#[spirv(...)]` attribute.
///
/// The variants are grouped by the kind of item the attribute checker in this
/// file accepts them on.
#[derive(Debug, Clone)]
pub enum SpirvAttribute {
    // Accepted on `struct` items only.
    IntrinsicType(IntrinsicType),
    Block,

    // Accepted on functions and on methods that have a body.
    Entry(Entry),

    // Accepted on function parameters only; a parameter whose parent is not
    // an entry point is reported as an error.
    StorageClass(StorageClass),
    Builtin(BuiltIn),
    DescriptorSet(u32),
    Binding(u32),
    Flat,
    PerPrimitiveExt,
    Invariant,
    InputAttachmentIndex(u32),
    SpecConstant(SpecConstant),

    // Accepted on free functions only.
    BufferLoadIntrinsic,
    BufferStoreIntrinsic,
}
103
/// A value together with the source span it was parsed from, so diagnostics
/// can point back at the originating attribute.
#[derive(Copy, Clone)]
pub struct Spanned<T> {
    /// The parsed value.
    pub value: T,
    /// Where in the source the value came from.
    pub span: Span,
}
111
/// All `#[spirv(...)]` attributes found on one item, with one optional slot
/// per attribute kind.
///
/// Each slot holds at most one attribute; inserting a second attribute of the
/// same kind is rejected by `try_insert_attr`. Flag-like attributes that carry
/// no data are stored as `Spanned<()>`, so only their span is kept.
#[derive(Default)]
pub struct AggregatedSpirvAttributes {
    // Slots for attributes accepted on `struct` items.
    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
    pub block: Option<Spanned<()>>,

    // Slot for the entry-point attribute.
    pub entry: Option<Spanned<Entry>>,

    // Slots for attributes accepted on function parameters.
    pub storage_class: Option<Spanned<StorageClass>>,
    pub builtin: Option<Spanned<BuiltIn>>,
    pub descriptor_set: Option<Spanned<u32>>,
    pub binding: Option<Spanned<u32>>,
    pub flat: Option<Spanned<()>>,
    pub invariant: Option<Spanned<()>>,
    pub per_primitive_ext: Option<Spanned<()>>,
    pub input_attachment_index: Option<Spanned<u32>>,
    pub spec_constant: Option<Spanned<SpecConstant>>,

    // Slots for attributes accepted on free functions.
    pub buffer_load_intrinsic: Option<Spanned<()>>,
    pub buffer_store_intrinsic: Option<Spanned<()>>,
}
139
/// Error produced when an attribute is inserted into a slot of
/// `AggregatedSpirvAttributes` that is already occupied.
struct MultipleAttrs {
    /// Span of the attribute that filled the slot first.
    prev_span: Span,
    /// Description of the attribute kind, interpolated into diagnostics.
    category: &'static str,
}
144
145impl AggregatedSpirvAttributes {
146 pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
151 let mut aggregated_attrs = Self::default();
152
153 for parse_attr_result in crate::symbols::parse_attrs_for_checking(&cx.sym, attrs) {
156 let (span, parsed_attr) = match parse_attr_result {
157 Ok(span_and_parsed_attr) => span_and_parsed_attr,
158 Err((span, msg)) => {
159 cx.tcx.dcx().span_delayed_bug(span, msg);
160 continue;
161 }
162 };
163 match aggregated_attrs.try_insert_attr(parsed_attr, span) {
164 Ok(()) => {}
165 Err(MultipleAttrs {
166 prev_span: _,
167 category,
168 }) => {
169 cx.tcx
170 .dcx()
171 .span_delayed_bug(span, format!("multiple {category} attributes"));
172 }
173 }
174 }
175
176 aggregated_attrs
177 }
178
179 fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
180 fn try_insert<T>(
181 slot: &mut Option<Spanned<T>>,
182 value: T,
183 span: Span,
184 category: &'static str,
185 ) -> Result<(), MultipleAttrs> {
186 if let Some(prev) = slot {
187 Err(MultipleAttrs {
188 prev_span: prev.span,
189 category,
190 })
191 } else {
192 *slot = Some(Spanned { value, span });
193 Ok(())
194 }
195 }
196
197 use SpirvAttribute::*;
198 match attr {
199 IntrinsicType(value) => {
200 try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
201 }
202 Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
203 Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
204 StorageClass(value) => {
205 try_insert(&mut self.storage_class, value, span, "storage class")
206 }
207 Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
208 DescriptorSet(value) => try_insert(
209 &mut self.descriptor_set,
210 value,
211 span,
212 "#[spirv(descriptor_set)]",
213 ),
214 Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
215 Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
216 Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
217 PerPrimitiveExt => try_insert(
218 &mut self.per_primitive_ext,
219 (),
220 span,
221 "#[spirv(per_primitive_ext)]",
222 ),
223 InputAttachmentIndex(value) => try_insert(
224 &mut self.input_attachment_index,
225 value,
226 span,
227 "#[spirv(attachment_index)]",
228 ),
229 SpecConstant(value) => try_insert(
230 &mut self.spec_constant,
231 value,
232 span,
233 "#[spirv(spec_constant)]",
234 ),
235 BufferLoadIntrinsic => try_insert(
236 &mut self.buffer_load_intrinsic,
237 (),
238 span,
239 "#[spirv(buffer_load_intrinsic)]",
240 ),
241 BufferStoreIntrinsic => try_insert(
242 &mut self.buffer_store_intrinsic,
243 (),
244 span,
245 "#[spirv(buffer_store_intrinsic)]",
246 ),
247 }
248 }
249}
250
251fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
253 match impl_item.kind {
254 hir::ImplItemKind::Const(..) => Target::AssocConst,
255 hir::ImplItemKind::Fn(..) => {
256 let parent_owner_id = tcx.hir_get_parent_item(impl_item.hir_id());
257 let containing_item = tcx.hir_expect_item(parent_owner_id.def_id);
258 let containing_impl_is_for_trait = match &containing_item.kind {
259 hir::ItemKind::Impl(hir::Impl { of_trait, .. }) => of_trait.is_some(),
260 _ => unreachable!("parent of an ImplItem must be an Impl"),
261 };
262 if containing_impl_is_for_trait {
263 Target::Method(MethodKind::Trait { body: true })
264 } else {
265 Target::Method(MethodKind::Inherent)
266 }
267 }
268 hir::ImplItemKind::Type(..) => Target::AssocTy,
269 }
270}
271
/// HIR visitor that validates `#[spirv(...)]` attributes on everything it
/// visits and reports misuse as compiler diagnostics.
struct CheckSpirvAttrVisitor<'tcx> {
    /// Compiler context, used to look up HIR attributes and emit diagnostics.
    tcx: TyCtxt<'tcx>,
    /// Symbol table handed to the attribute parser.
    sym: Rc<Symbols>,
}
276
277impl CheckSpirvAttrVisitor<'_> {
278 fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
279 let mut aggregated_attrs = AggregatedSpirvAttributes::default();
280
281 let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);
282
283 let attrs = self.tcx.hir_attrs(hir_id);
284 for parse_attr_result in parse_attrs(attrs) {
285 let (span, parsed_attr) = match parse_attr_result {
286 Ok(span_and_parsed_attr) => span_and_parsed_attr,
287 Err((span, msg)) => {
288 self.tcx.dcx().span_err(span, msg);
289 continue;
290 }
291 };
292
293 struct Expected<T>(T);
295
296 let valid_target = match parsed_attr {
297 SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
298 Target::Struct => {
299 Ok(())
302 }
303
304 _ => Err(Expected("struct")),
305 },
306
307 SpirvAttribute::Entry(_) => match target {
308 Target::Fn
309 | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
310 Ok(())
313 }
314
315 _ => Err(Expected("function")),
316 },
317
318 SpirvAttribute::StorageClass(_)
319 | SpirvAttribute::Builtin(_)
320 | SpirvAttribute::DescriptorSet(_)
321 | SpirvAttribute::Binding(_)
322 | SpirvAttribute::Flat
323 | SpirvAttribute::Invariant
324 | SpirvAttribute::PerPrimitiveExt
325 | SpirvAttribute::InputAttachmentIndex(_)
326 | SpirvAttribute::SpecConstant(_) => match target {
327 Target::Param => {
328 let parent_hir_id = self.tcx.parent_hir_id(hir_id);
329 let parent_is_entry_point = parse_attrs(self.tcx.hir_attrs(parent_hir_id))
330 .filter_map(|r| r.ok())
331 .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
332 if !parent_is_entry_point {
333 self.tcx.dcx().span_err(
334 span,
335 "attribute is only valid on a parameter of an entry-point function",
336 );
337 } else {
338 if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
341 let valid = match storage_class {
342 StorageClass::Input | StorageClass::Output => {
343 Err("is the default and should not be explicitly specified")
344 }
345
346 StorageClass::Private
347 | StorageClass::Function
348 | StorageClass::Generic => {
349 Err("can not be used as part of an entry's interface")
350 }
351
352 _ => Ok(()),
353 };
354
355 if let Err(msg) = valid {
356 self.tcx.dcx().span_err(
357 span,
358 format!("`{storage_class:?}` storage class {msg}"),
359 );
360 }
361 }
362 }
363 Ok(())
364 }
365
366 _ => Err(Expected("function parameter")),
367 },
368 SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
369 match target {
370 Target::Fn => Ok(()),
371 _ => Err(Expected("function")),
372 }
373 }
374 };
375 match valid_target {
376 Err(Expected(expected_target)) => {
377 self.tcx.dcx().span_err(
378 span,
379 format!(
380 "attribute is only valid on a {expected_target}, not on a {target}"
381 ),
382 );
383 }
384 Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
385 Ok(()) => {}
386 Err(MultipleAttrs {
387 prev_span,
388 category,
389 }) => {
390 self.tcx
391 .dcx()
392 .struct_span_err(
393 span,
394 format!("only one {category} attribute is allowed on a {target}"),
395 )
396 .with_span_note(prev_span, format!("previous {category} attribute"))
397 .emit();
398 }
399 },
400 }
401 }
402
403 if let Some(block_attr) = aggregated_attrs.block {
407 self.tcx.dcx().span_warn(
408 block_attr.span,
409 "#[spirv(block)] is no longer needed and should be removed",
410 );
411 }
412 }
413}
414
415impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
417 type NestedFilter = nested_filter::OnlyBodies;
418
419 fn maybe_tcx(&mut self) -> Self::MaybeTyCtxt {
420 self.tcx
421 }
422
423 fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
424 let target = Target::from_item(item);
425 self.check_spirv_attributes(item.hir_id(), target);
426 intravisit::walk_item(self, item);
427 }
428
429 fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
430 let target = Target::from_generic_param(generic_param);
431 self.check_spirv_attributes(generic_param.hir_id, target);
432 intravisit::walk_generic_param(self, generic_param);
433 }
434
435 fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
436 let target = Target::from_trait_item(trait_item);
437 self.check_spirv_attributes(trait_item.hir_id(), target);
438 intravisit::walk_trait_item(self, trait_item);
439 }
440
441 fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
442 self.check_spirv_attributes(field.hir_id, Target::Field);
443 intravisit::walk_field_def(self, field);
444 }
445
446 fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
447 self.check_spirv_attributes(arm.hir_id, Target::Arm);
448 intravisit::walk_arm(self, arm);
449 }
450
451 fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
452 let target = Target::from_foreign_item(f_item);
453 self.check_spirv_attributes(f_item.hir_id(), target);
454 intravisit::walk_foreign_item(self, f_item);
455 }
456
457 fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
458 let target = target_from_impl_item(self.tcx, impl_item);
459 self.check_spirv_attributes(impl_item.hir_id(), target);
460 intravisit::walk_impl_item(self, impl_item);
461 }
462
463 fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
464 if let hir::StmtKind::Let(l) = stmt.kind {
466 self.check_spirv_attributes(l.hir_id, Target::Statement);
467 }
468 intravisit::walk_stmt(self, stmt);
469 }
470
471 fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
472 let target = match expr.kind {
473 hir::ExprKind::Closure { .. } => Target::Closure,
474 _ => Target::Expression,
475 };
476
477 self.check_spirv_attributes(expr.hir_id, target);
478 intravisit::walk_expr(self, expr);
479 }
480
481 fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
482 self.check_spirv_attributes(variant.hir_id, Target::Variant);
483 intravisit::walk_variant(self, variant);
484 }
485
486 fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
487 self.check_spirv_attributes(param.hir_id, Target::Param);
488
489 intravisit::walk_param(self, param);
490 }
491}
492
493fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
495 let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
496 tcx,
497 sym: Symbols::get(),
498 };
499 tcx.hir_visit_item_likes_in_module(module_def_id, check_spirv_attr_visitor);
500 if module_def_id.is_top_level_module() {
501 check_spirv_attr_visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
502 }
503}
504
505pub(crate) fn provide(providers: &mut Providers) {
506 *providers = Providers {
507 check_mod_attrs: |tcx, module_def_id| {
508 (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, module_def_id);
510 check_mod_attrs(tcx, module_def_id);
511 },
512 ..*providers
513 };
514}