use crate::codegen_cx::CodegenCx;
use crate::symbols::Symbols;
use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
use rustc_ast::Attribute;
use rustc_hir as hir;
use rustc_hir::def_id::LocalModDefId;
use rustc_hir::intravisit::{self, Visitor};
use rustc_hir::{CRATE_HIR_ID, HirId, MethodKind, Target};
use rustc_middle::hir::nested_filter;
use rustc_middle::query::Providers;
use rustc_middle::ty::TyCtxt;
use rustc_span::{Span, Symbol};
use std::rc::Rc;
/// Up to three extra literal operands that accompany an execution mode.
///
/// The operands are stored inline (unused trailing slots are zero) so the type
/// stays `Copy`; `len` records how many of the slots are meaningful.
#[derive(Copy, Clone, Debug)]
pub struct ExecutionModeExtra {
    args: [u32; 3],
    len: u8,
}

impl ExecutionModeExtra {
    /// Builds an `ExecutionModeExtra` holding a copy of `args`.
    ///
    /// # Panics
    ///
    /// Panics if `args` contains more than three operands.
    pub(crate) fn new(args: impl AsRef<[u32]>) -> Self {
        let operands = args.as_ref();
        let mut buf = [0; 3];
        // State the capacity invariant explicitly instead of relying on an
        // out-of-range slice panic with an unhelpful message.
        assert!(
            operands.len() <= buf.len(),
            "ExecutionModeExtra supports at most {} operands, got {}",
            buf.len(),
            operands.len()
        );
        buf[..operands.len()].copy_from_slice(operands);
        // Cannot truncate: the length was just checked to be at most 3.
        let len = operands.len() as u8;
        Self { args: buf, len }
    }
}

impl AsRef<[u32]> for ExecutionModeExtra {
    /// Returns only the operands that were actually supplied to `new`.
    fn as_ref(&self) -> &[u32] {
        &self.args[..usize::from(self.len)]
    }
}
/// An entry-point declaration carried by a `#[spirv(...)]` attribute on a function.
///
/// Produced as `SpirvAttribute::Entry`, which is only accepted on free functions and
/// on methods that have a body.
#[derive(Clone, Debug)]
pub struct Entry {
    /// The SPIR-V execution model of the entry point.
    pub execution_model: ExecutionModel,
    /// Execution modes to declare for this entry point, each with its extra literal operands.
    pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>,
    /// Optional explicit entry-point name.
    /// NOTE(review): presumably the function's own name is used when this is `None` — confirm.
    pub name: Option<Symbol>,
}
impl From<ExecutionModel> for Entry {
fn from(execution_model: ExecutionModel) -> Self {
Self {
execution_model,
execution_modes: Vec::new(),
name: None,
}
}
}
/// The kind of SPIR-V intrinsic type that a struct is declared to be through a
/// `#[spirv(...)]` attribute parsed as `SpirvAttribute::IntrinsicType`.
///
/// Such attributes are only accepted on structs; other targets are rejected by
/// `check_spirv_attributes`.
#[derive(Debug, Clone)]
pub enum IntrinsicType {
    GenericImageType,
    Sampler,
    AccelerationStructureKhr,
    SampledImage,
    RayQueryKhr,
    RuntimeArray,
    TypedBuffer,
    Matrix,
}
/// Arguments of a `#[spirv(spec_constant ...)]` attribute on an entry-point parameter.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SpecConstant {
    /// Identifier of the specialization constant.
    /// NOTE(review): presumably emitted as the SPIR-V `SpecId` decoration — confirm.
    pub id: u32,
    /// Optional default value; `None` when the attribute does not specify one.
    pub default: Option<u32>,
}
/// A single parsed `#[spirv(...)]` attribute.
///
/// Where each variant may appear is enforced by `check_spirv_attributes`.
#[derive(Debug, Clone)]
pub enum SpirvAttribute {
    // Only valid on structs.
    IntrinsicType(IntrinsicType),
    // Only valid on structs; deprecated, and a warning asks for its removal.
    Block,
    // Only valid on free functions and on methods with a body.
    Entry(Entry),
    // The following interface attributes are only valid on a parameter of an
    // entry-point function.
    StorageClass(StorageClass),
    Builtin(BuiltIn),
    DescriptorSet(u32),
    Binding(u32),
    Flat,
    PerPrimitiveExt,
    Invariant,
    InputAttachmentIndex(u32),
    SpecConstant(SpecConstant),
    // Only valid on free functions (`Target::Fn`), not on methods.
    BufferLoadIntrinsic,
    BufferStoreIntrinsic,
}
/// A value paired with the source span of the attribute it came from, so that
/// diagnostics can point back at that attribute.
#[derive(Copy, Clone)]
pub struct Spanned<T> {
    pub value: T,
    pub span: Span,
}
/// All `#[spirv(...)]` attributes found on a single HIR node, at most one per kind.
///
/// Each field is `None` when the attribute is absent. Otherwise it holds the attribute's
/// value together with its span. Unit-valued fields only record presence and span.
/// A second attribute of the same kind is rejected by `try_insert_attr`.
#[derive(Default)]
pub struct AggregatedSpirvAttributes {
    pub intrinsic_type: Option<Spanned<IntrinsicType>>,
    pub block: Option<Spanned<()>>,
    pub entry: Option<Spanned<Entry>>,
    pub storage_class: Option<Spanned<StorageClass>>,
    pub builtin: Option<Spanned<BuiltIn>>,
    pub descriptor_set: Option<Spanned<u32>>,
    pub binding: Option<Spanned<u32>>,
    pub flat: Option<Spanned<()>>,
    pub invariant: Option<Spanned<()>>,
    pub per_primitive_ext: Option<Spanned<()>>,
    pub input_attachment_index: Option<Spanned<u32>>,
    pub spec_constant: Option<Spanned<SpecConstant>>,
    pub buffer_load_intrinsic: Option<Spanned<()>>,
    pub buffer_store_intrinsic: Option<Spanned<()>>,
}
/// Error returned by `AggregatedSpirvAttributes::try_insert_attr` when an attribute
/// of the same kind was already recorded for the node.
struct MultipleAttrs {
    // Span of the attribute that was recorded first (and is kept).
    prev_span: Span,
    // Human-readable name of the attribute kind, used in diagnostics.
    category: &'static str,
}
impl AggregatedSpirvAttributes {
    /// Parses `attrs` and aggregates every `#[spirv(...)]` attribute among them.
    ///
    /// Malformed and duplicated attributes are not turned into user-facing errors here.
    /// They are recorded with `span_delayed_bug`, and the offending attribute is skipped.
    /// For duplicates, the first occurrence is kept.
    /// NOTE(review): this relies on the attribute-checking pass (`check_spirv_attributes`)
    /// having already reported these problems as real errors — confirm the pass ordering.
    pub fn parse<'tcx>(cx: &CodegenCx<'tcx>, attrs: &'tcx [Attribute]) -> Self {
        let mut aggregated_attrs = Self::default();
        for parse_attr_result in crate::symbols::parse_attrs_for_checking(&cx.sym, attrs) {
            let (span, parsed_attr) = match parse_attr_result {
                Ok(span_and_parsed_attr) => span_and_parsed_attr,
                Err((span, msg)) => {
                    cx.tcx.dcx().span_delayed_bug(span, msg);
                    continue;
                }
            };
            if let Err(MultipleAttrs { category, .. }) =
                aggregated_attrs.try_insert_attr(parsed_attr, span)
            {
                cx.tcx
                    .dcx()
                    .span_delayed_bug(span, format!("multiple {category} attributes"));
            }
        }
        aggregated_attrs
    }

    /// Records `attr` (written at `span`) in the slot for its kind.
    ///
    /// Returns `Err(MultipleAttrs)` if that slot is already occupied. In that case the
    /// previously stored value is left untouched. The error carries the earlier
    /// attribute's span and a human-readable category name.
    fn try_insert_attr(&mut self, attr: SpirvAttribute, span: Span) -> Result<(), MultipleAttrs> {
        fn try_insert<T>(
            slot: &mut Option<Spanned<T>>,
            value: T,
            span: Span,
            category: &'static str,
        ) -> Result<(), MultipleAttrs> {
            if let Some(prev) = slot {
                Err(MultipleAttrs {
                    prev_span: prev.span,
                    category,
                })
            } else {
                *slot = Some(Spanned { value, span });
                Ok(())
            }
        }
        // Brings the variant names into scope. Inside this function, the patterns below
        // therefore refer to `SpirvAttribute` variants, not the like-named types.
        use SpirvAttribute::*;
        match attr {
            IntrinsicType(value) => {
                try_insert(&mut self.intrinsic_type, value, span, "intrinsic type")
            }
            Block => try_insert(&mut self.block, (), span, "#[spirv(block)]"),
            Entry(value) => try_insert(&mut self.entry, value, span, "entry-point"),
            StorageClass(value) => {
                try_insert(&mut self.storage_class, value, span, "storage class")
            }
            Builtin(value) => try_insert(&mut self.builtin, value, span, "builtin"),
            DescriptorSet(value) => try_insert(
                &mut self.descriptor_set,
                value,
                span,
                "#[spirv(descriptor_set)]",
            ),
            Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
            Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
            Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
            PerPrimitiveExt => try_insert(
                &mut self.per_primitive_ext,
                (),
                span,
                "#[spirv(per_primitive_ext)]",
            ),
            // The diagnostic must name the attribute as users write it
            // (`input_attachment_index`), matching the other categories.
            InputAttachmentIndex(value) => try_insert(
                &mut self.input_attachment_index,
                value,
                span,
                "#[spirv(input_attachment_index)]",
            ),
            SpecConstant(value) => try_insert(
                &mut self.spec_constant,
                value,
                span,
                "#[spirv(spec_constant)]",
            ),
            BufferLoadIntrinsic => try_insert(
                &mut self.buffer_load_intrinsic,
                (),
                span,
                "#[spirv(buffer_load_intrinsic)]",
            ),
            BufferStoreIntrinsic => try_insert(
                &mut self.buffer_store_intrinsic,
                (),
                span,
                "#[spirv(buffer_store_intrinsic)]",
            ),
        }
    }
}
/// Classifies an impl item as the `Target` used when validating its attributes.
///
/// Associated functions are treated as methods. They are distinguished by whether the
/// enclosing `impl` block is a trait impl (`impl Trait for T`) or an inherent impl (`impl T`).
fn target_from_impl_item(tcx: TyCtxt<'_>, impl_item: &hir::ImplItem<'_>) -> Target {
    match impl_item.kind {
        hir::ImplItemKind::Const(..) => Target::AssocConst,
        hir::ImplItemKind::Type(..) => Target::AssocTy,
        hir::ImplItemKind::Fn(..) => {
            let hir_map = tcx.hir();
            let parent = hir_map.get_parent_item(impl_item.hir_id());
            let parent_item = hir_map.expect_item(parent.def_id);
            let hir::ItemKind::Impl(hir::Impl { of_trait, .. }) = &parent_item.kind else {
                unreachable!("parent of an ImplItem must be an Impl");
            };
            let method_kind = if of_trait.is_some() {
                MethodKind::Trait { body: true }
            } else {
                MethodKind::Inherent
            };
            Target::Method(method_kind)
        }
    }
}
/// HIR visitor that validates the placement and uniqueness of `#[spirv(...)]`
/// attributes on every node it visits.
struct CheckSpirvAttrVisitor<'tcx> {
    tcx: TyCtxt<'tcx>,
    // Symbol table passed to `crate::symbols::parse_attrs_for_checking`.
    sym: Rc<Symbols>,
}
impl CheckSpirvAttrVisitor<'_> {
    /// Validates every `#[spirv(...)]` attribute on the HIR node `hir_id`, which the
    /// caller has classified as `target`.
    ///
    /// Emits an error for each of these cases:
    /// - an attribute that fails to parse;
    /// - an attribute placed on the wrong kind of node;
    /// - an attribute of a kind already seen on the same node (with a note pointing at
    ///   the earlier one);
    /// - an interface attribute on a parameter whose parent is not an entry point;
    /// - a storage class that is disallowed on an entry's interface.
    ///
    /// A warning is emitted for the no-longer-needed `#[spirv(block)]`.
    fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
        // Used only to detect duplicate attribute kinds and, at the end, to find a
        // `block` attribute to warn about.
        let mut aggregated_attrs = AggregatedSpirvAttributes::default();
        let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);
        let attrs = self.tcx.hir().attrs(hir_id);
        for parse_attr_result in parse_attrs(attrs) {
            let (span, parsed_attr) = match parse_attr_result {
                Ok(span_and_parsed_attr) => span_and_parsed_attr,
                Err((span, msg)) => {
                    // Unparseable attribute: report it and continue with the next one.
                    self.tcx.dcx().span_err(span, msg);
                    continue;
                }
            };
            // Names the kind of node the attribute was expected on, for the
            // "attribute is only valid on a ..." error below.
            struct Expected<T>(T);
            let valid_target = match parsed_attr {
                SpirvAttribute::IntrinsicType(_) | SpirvAttribute::Block => match target {
                    Target::Struct => {
                        Ok(())
                    }
                    _ => Err(Expected("struct")),
                },
                // Entry points may be free functions or methods that have a body.
                SpirvAttribute::Entry(_) => match target {
                    Target::Fn
                    | Target::Method(MethodKind::Trait { body: true } | MethodKind::Inherent) => {
                        Ok(())
                    }
                    _ => Err(Expected("function")),
                },
                // Interface attributes: only meaningful on entry-point parameters.
                SpirvAttribute::StorageClass(_)
                | SpirvAttribute::Builtin(_)
                | SpirvAttribute::DescriptorSet(_)
                | SpirvAttribute::Binding(_)
                | SpirvAttribute::Flat
                | SpirvAttribute::Invariant
                | SpirvAttribute::PerPrimitiveExt
                | SpirvAttribute::InputAttachmentIndex(_)
                | SpirvAttribute::SpecConstant(_) => match target {
                    Target::Param => {
                        // The parameter's parent is treated as the entry point if any of its
                        // attributes parses as `SpirvAttribute::Entry`. Parse failures on
                        // the parent are ignored here.
                        // NOTE(review): assumes the parent HirId of a parameter is the
                        // owning function — confirm.
                        let parent_hir_id = self.tcx.parent_hir_id(hir_id);
                        let parent_is_entry_point =
                            parse_attrs(self.tcx.hir().attrs(parent_hir_id))
                                .filter_map(|r| r.ok())
                                .any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
                        if !parent_is_entry_point {
                            self.tcx.dcx().span_err(
                                span,
                                "attribute is only valid on a parameter of an entry-point function",
                            );
                        } else {
                            if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
                                // Input/Output need not be written explicitly, and
                                // Private/Function/Generic are rejected outright.
                                let valid = match storage_class {
                                    StorageClass::Input | StorageClass::Output => {
                                        Err("is the default and should not be explicitly specified")
                                    }
                                    StorageClass::Private
                                    | StorageClass::Function
                                    | StorageClass::Generic => {
                                        Err("can not be used as part of an entry's interface")
                                    }
                                    _ => Ok(()),
                                };
                                if let Err(msg) = valid {
                                    self.tcx.dcx().span_err(
                                        span,
                                        format!("`{storage_class:?}` storage class {msg}"),
                                    );
                                }
                            }
                        }
                        // The target itself (a parameter) is right even if the checks above
                        // failed, so the attribute still takes part in duplicate detection.
                        Ok(())
                    }
                    _ => Err(Expected("function parameter")),
                },
                // Only free functions (`Target::Fn`) are accepted, not methods.
                SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
                    match target {
                        Target::Fn => Ok(()),
                        _ => Err(Expected("function")),
                    }
                }
            };
            match valid_target {
                Err(Expected(expected_target)) => {
                    self.tcx.dcx().span_err(
                        span,
                        format!(
                            "attribute is only valid on a {expected_target}, not on a {target}"
                        ),
                    );
                }
                // Correctly placed: now reject a second attribute of the same kind.
                Ok(()) => match aggregated_attrs.try_insert_attr(parsed_attr, span) {
                    Ok(()) => {}
                    Err(MultipleAttrs {
                        prev_span,
                        category,
                    }) => {
                        self.tcx
                            .dcx()
                            .struct_span_err(
                                span,
                                format!("only one {category} attribute is allowed on a {target}"),
                            )
                            .with_span_note(prev_span, format!("previous {category} attribute"))
                            .emit();
                    }
                },
            }
        }
        // `#[spirv(block)]` is accepted but deprecated.
        if let Some(block_attr) = aggregated_attrs.block {
            self.tcx.dcx().span_warn(
                block_attr.span,
                "#[spirv(block)] is no longer needed and should be removed",
            );
        }
    }
}
impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
type NestedFilter = nested_filter::OnlyBodies;
fn nested_visit_map(&mut self) -> Self::Map {
self.tcx.hir()
}
fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
let target = Target::from_item(item);
self.check_spirv_attributes(item.hir_id(), target);
intravisit::walk_item(self, item);
}
fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
let target = Target::from_generic_param(generic_param);
self.check_spirv_attributes(generic_param.hir_id, target);
intravisit::walk_generic_param(self, generic_param);
}
fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
let target = Target::from_trait_item(trait_item);
self.check_spirv_attributes(trait_item.hir_id(), target);
intravisit::walk_trait_item(self, trait_item);
}
fn visit_field_def(&mut self, field: &'tcx hir::FieldDef<'tcx>) {
self.check_spirv_attributes(field.hir_id, Target::Field);
intravisit::walk_field_def(self, field);
}
fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
self.check_spirv_attributes(arm.hir_id, Target::Arm);
intravisit::walk_arm(self, arm);
}
fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
let target = Target::from_foreign_item(f_item);
self.check_spirv_attributes(f_item.hir_id(), target);
intravisit::walk_foreign_item(self, f_item);
}
fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
let target = target_from_impl_item(self.tcx, impl_item);
self.check_spirv_attributes(impl_item.hir_id(), target);
intravisit::walk_impl_item(self, impl_item);
}
fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
if let hir::StmtKind::Let(l) = stmt.kind {
self.check_spirv_attributes(l.hir_id, Target::Statement);
}
intravisit::walk_stmt(self, stmt);
}
fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
let target = match expr.kind {
hir::ExprKind::Closure { .. } => Target::Closure,
_ => Target::Expression,
};
self.check_spirv_attributes(expr.hir_id, target);
intravisit::walk_expr(self, expr);
}
fn visit_variant(&mut self, variant: &'tcx hir::Variant<'tcx>) {
self.check_spirv_attributes(variant.hir_id, Target::Variant);
intravisit::walk_variant(self, variant);
}
fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
self.check_spirv_attributes(param.hir_id, Target::Param);
intravisit::walk_param(self, param);
}
}
/// Runs the `#[spirv(...)]` attribute checks over every item-like in `module_def_id`.
///
/// For the crate's top-level module, the crate root's own attributes (`CRATE_HIR_ID`)
/// are also checked, as a `Target::Mod`. This happens after the module walk.
fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalModDefId) {
    let mut visitor = CheckSpirvAttrVisitor {
        tcx,
        sym: Symbols::get(),
    };
    tcx.hir()
        .visit_item_likes_in_module(module_def_id, &mut visitor);

    if !module_def_id.is_top_level_module() {
        return;
    }
    visitor.check_spirv_attributes(CRATE_HIR_ID, Target::Mod);
}
/// Overrides the `check_mod_attrs` query provider.
///
/// The replacement first runs rustc's default attribute checks, taken from
/// `rustc_interface::DEFAULT_QUERY_PROVIDERS`. It then runs the `#[spirv(...)]`
/// checks for the same module.
pub(crate) fn provide(providers: &mut Providers) {
    providers.check_mod_attrs = |tcx, module_def_id| {
        let default_check_mod_attrs = rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs;
        default_check_mod_attrs(tcx, module_def_id);
        check_mod_attrs(tcx, module_def_id);
    };
}