1use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
5use crate::codegen_cx::CodegenCx;
6use crate::spirv_type::SpirvType;
7use itertools::Itertools;
8use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
9use rustc_abi::ExternAbi as Abi;
10use rustc_abi::{
11 Align, BackendRepr, FieldIdx, FieldsShape, Primitive, Scalar, Size, VariantIdx, Variants,
12};
13use rustc_data_structures::fx::FxHashMap;
14use rustc_errors::ErrorGuaranteed;
15use rustc_index::Idx;
16use rustc_middle::query::Providers;
17use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
18use rustc_middle::ty::{
19 self, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty, TyCtxt,
20 TyKind, UintTy,
21};
22use rustc_middle::ty::{GenericArgsRef, ScalarInt};
23use rustc_middle::{bug, span_bug};
24use rustc_span::DUMMY_SP;
25use rustc_span::def_id::DefId;
26use rustc_span::{Span, Symbol};
27use rustc_target::callconv::{ArgAbi, ArgAttributes, FnAbi, PassMode};
28use std::cell::RefCell;
29use std::collections::hash_map::Entry;
30use std::fmt;
31
/// Overrides a few `rustc` query providers so that the function signatures and
/// `FnAbi`s that `rustc` computes are ones this SPIR-V backend can lower.
pub(crate) fn provide(providers: &mut Providers) {
    // Rewrite every `extern "C"` signature to `extern "unadjusted"`.
    // NOTE(review): presumably because C-ABI adjustments have no SPIR-V
    // implementation while libcore still declares some `extern "C"` functions - confirm.
    providers.fn_sig = |tcx, def_id| {
        // Providers are plain `fn` pointers (this closure captures nothing), so the
        // previous provider can't be captured; call the exported default instead.
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_sig)(tcx, def_id);
        result.map_bound(|outer| {
            outer.map_bound(|mut inner| {
                if let Abi::C { .. } = inner.abi {
                    inner.abi = Abi::Unadjusted;
                }
                inner
            })
        })
    };

    /// Rebuilds `fn_abi` from the layouts of its return value and arguments,
    /// discarding the pass modes `rustc` picked.
    ///
    /// Every value is made "direct", a `PassMode::Pair` on an ADT is collapsed
    /// into a single `PassMode::Direct`, and ZSTs become `PassMode::Ignore`.
    /// `FnAbi::spirv_type` (in this file) treats `PassMode::Cast`/`PassMode::Indirect`
    /// as impossible because of this rewrite.
    fn readjust_fn_abi<'tcx>(
        tcx: TyCtxt<'tcx>,
        fn_abi: &'tcx FnAbi<'tcx, Ty<'tcx>>,
    ) -> &'tcx FnAbi<'tcx, Ty<'tcx>> {
        let readjust_arg_abi = |arg: &ArgAbi<'tcx, Ty<'tcx>>| {
            // Start from a fresh `ArgAbi` for the same layout, with default attributes.
            let mut arg = ArgAbi::new(&tcx, arg.layout, |_, _, _| ArgAttributes::new());
            // Switch to a direct-style mode; as the check below shows, this may
            // still produce `PassMode::Pair`.
            // NOTE(review): the `_deprecated` suffix suggests this API may go away upstream.
            arg.make_direct_deprecated();

            // ADTs laid out as a scalar pair are passed as one direct value
            // instead of two separate scalars.
            if let PassMode::Pair(..) = arg.mode {
                if let TyKind::Adt(..) = arg.layout.ty.kind() {
                    arg.mode = PassMode::Direct(ArgAttributes::new());
                }
            }

            // Zero-sized values carry no data, so they are not passed at all.
            if arg.layout.is_zst() {
                arg.mode = PassMode::Ignore;
            }

            arg
        };
        tcx.arena.alloc(FnAbi {
            args: fn_abi.args.iter().map(readjust_arg_abi).collect(),
            ret: readjust_arg_abi(&fn_abi.ret),

            // Everything else is copied unchanged from the original `FnAbi`.
            c_variadic: fn_abi.c_variadic,
            fixed_count: fn_abi.fixed_count,
            conv: fn_abi.conv,
            can_unwind: fn_abi.can_unwind,
        })
    }
    // Both `FnAbi` queries go through the default provider, then get readjusted;
    // errors from the default provider are propagated unchanged.
    providers.fn_abi_of_fn_ptr = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_fn_ptr)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };
    providers.fn_abi_of_instance = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_instance)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };

    // Replace the mono-item check with a no-op.
    // NOTE(review): the reason for disabling it is not visible here - confirm.
    providers.check_mono_item = |_, _| {};
}
116
/// Per-pointee bookkeeping used while lowering pointer types, so that a pointee
/// which (transitively) contains a pointer to itself does not recurse forever.
///
/// Callers bracket the lowering of a pointee with `begin` and `end`.
#[derive(Default)]
pub struct RecursivePointeeCache<'tcx> {
    // Lowering state of each pointee seen so far.
    map: RefCell<FxHashMap<PointeeTy<'tcx>, PointeeDefState>>,
}
126
impl<'tcx> RecursivePointeeCache<'tcx> {
    /// Starts (or short-circuits) lowering the pointer type for `pointee`.
    ///
    /// Returns `None` the first time `pointee` is seen: the caller must lower the
    /// pointee and then call `end`. Otherwise returns the pointer type ID to use:
    /// a newly emitted forward pointer on the first recursive use, or the ID
    /// already recorded for it.
    fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
        match self.map.borrow_mut().entry(pointee) {
            // First visit: mark as in progress and let the caller do the lowering.
            Entry::Vacant(entry) => {
                entry.insert(PointeeDefState::Defining);
                None
            }
            Entry::Occupied(mut entry) => match *entry.get() {
                // Recursive use while the pointee is still being lowered: lowering
                // it again would never terminate, so emit an `OpTypeForwardPointer`
                // and hand out its ID instead.
                PointeeDefState::Defining => {
                    let new_id = cx.emit_global().id();
                    // Emitted with the `Generic` storage class.
                    // NOTE(review): presumably specialized later (e.g. by the linker) - confirm.
                    cx.emit_global()
                        .type_forward_pointer(new_id, StorageClass::Generic);
                    entry.insert(PointeeDefState::DefiningWithForward(new_id));
                    // Self-referential types are not supported; the forward pointer
                    // is zombied so that using it reports this error.
                    cx.zombie_with_span(
                        new_id,
                        span,
                        "cannot create self-referential types, even through pointers",
                    );
                    Some(new_id)
                }
                // A forward pointer already exists, or the pointer is fully defined.
                PointeeDefState::DefiningWithForward(id) | PointeeDefState::Defined(id) => Some(id),
            },
        }
    }

    /// Finishes lowering the pointer to `pointee`, whose SPIR-V type is `pointee_spv`,
    /// and returns the pointer type ID.
    ///
    /// If `begin` emitted a forward pointer, the real pointer type is defined with
    /// that same ID. Calling this without a matching `begin`, or twice, is a bug.
    fn end(
        &self,
        cx: &CodegenCx<'tcx>,
        span: Span,
        pointee: PointeeTy<'tcx>,
        pointee_spv: Word,
    ) -> Word {
        match self.map.borrow_mut().entry(pointee) {
            Entry::Vacant(_) => {
                span_bug!(span, "RecursivePointeeCache::end should always have entry")
            }
            Entry::Occupied(mut entry) => match *entry.get() {
                // No recursion happened: define the pointer type normally.
                PointeeDefState::Defining => {
                    let id = SpirvType::Pointer {
                        pointee: pointee_spv,
                    }
                    .def(span, cx);
                    entry.insert(PointeeDefState::Defined(id));
                    id
                }
                // Recursion happened: give the real pointer the forward-declared ID.
                PointeeDefState::DefiningWithForward(id) => {
                    entry.insert(PointeeDefState::Defined(id));
                    SpirvType::Pointer {
                        pointee: pointee_spv,
                    }
                    .def_with_id(cx, span, id)
                }
                PointeeDefState::Defined(_) => {
                    span_bug!(span, "RecursivePointeeCache::end defined pointer twice")
                }
            },
        }
    }
}
202
/// What a pointer scalar points to; the key of `RecursivePointeeCache`.
#[derive(Eq, PartialEq, Hash, Copy, Clone, Debug)]
enum PointeeTy<'tcx> {
    /// Pointee of a `&T` / `*T` pointer, with its layout.
    Ty(TyAndLayout<'tcx>),
    /// Signature behind an `fn` pointer.
    Fn(PolyFnSig<'tcx>),
}
208
209impl fmt::Display for PointeeTy<'_> {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 match self {
212 PointeeTy::Ty(ty) => write!(f, "{}", ty.ty),
213 PointeeTy::Fn(ty) => write!(f, "{ty}"),
214 }
215 }
216}
217
/// Lowering state of one pointee in `RecursivePointeeCache`.
enum PointeeDefState {
    /// The pointee is being lowered and no forward pointer was needed yet.
    Defining,
    /// The pointee is being lowered and a recursive use already required an
    /// `OpTypeForwardPointer`, whose ID is stored here.
    DefiningWithForward(Word),
    /// The pointer type is fully defined with this ID.
    Defined(Word),
}
223
/// Conversion of `rustc` types, layouts and function ABIs into SPIR-V type IDs.
pub trait ConvSpirvType<'tcx> {
    /// Returns the SPIR-V type ID for `self`, defining the type through `cx`
    /// (implementations here go through `SpirvType::def` and friends).
    /// `span` is used for diagnostics.
    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
}
229
230impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
231 fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
232 match *self {
233 PointeeTy::Ty(ty) => ty.spirv_type(span, cx),
234 PointeeTy::Fn(ty) => cx
235 .fn_abi_of_fn_ptr(ty, ty::List::empty())
236 .spirv_type(span, cx),
237 }
238 }
239}
240
241impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
242 fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
243 let mut argument_types = Vec::new();
245
246 let return_type = match self.ret.mode {
247 PassMode::Ignore => SpirvType::Void.def(span, cx),
248 PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
249 PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
250 span,
251 "query hooks should've made this `PassMode` impossible: {:#?}",
252 self.ret
253 ),
254 };
255
256 for arg in self.args.iter() {
257 let arg_type = match arg.mode {
258 PassMode::Ignore => continue,
259 PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
260 PassMode::Pair(_, _) => {
261 argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
262 argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
263 continue;
264 }
265 PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
266 span,
267 "query hooks should've made this `PassMode` impossible: {:#?}",
268 arg
269 ),
270 };
271 argument_types.push(arg_type);
272 }
273
274 SpirvType::Function {
275 return_type,
276 arguments: &argument_types,
277 }
278 .def(span, cx)
279 }
280}
281
impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
    /// Lowers a type (with its layout) to a SPIR-V type.
    ///
    /// ADTs carrying a `#[spirv(...)]` intrinsic-type attribute are lowered by
    /// `trans_intrinsic_type`. If that fails (an error has already been emitted),
    /// lowering falls back to the regular layout-based path below.
    fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
        if let TyKind::Adt(adt, args) = *self.ty.kind() {
            // Without a useful span, point diagnostics at the ADT's definition.
            if span == DUMMY_SP {
                span = cx.tcx.def_span(adt.did());
            }

            let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs_unchecked(adt.did()));

            if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
                && let Ok(spirv_type) =
                    trans_intrinsic_type(cx, span, *self, args, intrinsic_type_attr)
            {
                return spirv_type;
            }
        }

        match self.backend_repr {
            // Uninhabited types become an empty, zero-sized struct.
            _ if self.uninhabited => SpirvType::Adt {
                def_id: def_id_for_spirv_type_adt(*self),
                size: Some(Size::ZERO),
                align: Align::from_bytes(0).unwrap(),
                field_types: &[],
                field_offsets: &[],
                field_names: None,
            }
            .def_with_name(cx, span, TyLayoutNameKey::from(*self)),
            BackendRepr::Scalar(scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
            BackendRepr::ScalarPair(a, b) => {
                // If the pair is really a single non-ZST field (plus any ZSTs) that
                // lines up exactly with this layout, reuse that field's type instead of
                // wrapping it in a new struct.
                let mut non_zst_fields = (0..self.fields.count())
                    .map(|i| (i, self.field(cx, i)))
                    .filter(|(_, field)| !field.is_zst());
                let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) {
                    (Some(field), None) => Some(field),
                    _ => None,
                };
                if let Some((i, field)) = sole_non_zst_field {
                    // Only unwrap when offset, size, ABI alignment and backend repr
                    // (up to validity ranges) all match.
                    if self.fields.offset(i) == Size::ZERO
                        && field.size == self.size
                        && field.align.abi == self.align.abi
                        && field.backend_repr.eq_up_to_validity(&self.backend_repr)
                    {
                        return field.spirv_type(span, cx);
                    }
                }

                // Otherwise emit a two-field struct. `b` sits right after `a`, rounded up
                // to `b`'s alignment.
                // NOTE(review): offsets are computed by hand rather than via
                // `auto_struct_layout`, which looks up the field types; presumably because
                // pointer types may not be fully defined yet during recursion - confirm.
                let a_offset = Size::ZERO;
                let b_offset = a.primitive().size(cx).align_to(b.primitive().align(cx).abi);
                let a = trans_scalar(cx, span, *self, a, a_offset);
                let b = trans_scalar(cx, span, *self, b, b_offset);
                let size = if self.is_unsized() {
                    None
                } else {
                    Some(self.size)
                };
                // Field names are collected for single-variant ADTs, but only used
                // when there are exactly two of them (i.e. no extra ZST fields).
                let mut field_names = Vec::new();
                if let TyKind::Adt(adt, _) = self.ty.kind()
                    && let Variants::Single { index } = self.variants
                {
                    for i in self.fields.index_by_increasing_offset() {
                        let field = &adt.variants()[index].fields[FieldIdx::new(i)];
                        field_names.push(field.name);
                    }
                }
                SpirvType::Adt {
                    def_id: def_id_for_spirv_type_adt(*self),
                    size,
                    align: self.align.abi,
                    field_types: &[a, b],
                    field_offsets: &[a_offset, b_offset],
                    field_names: if field_names.len() == 2 {
                        Some(&field_names)
                    } else {
                        None
                    },
                }
                .def_with_name(cx, span, TyLayoutNameKey::from(*self))
            }
            BackendRepr::SimdVector { element, count } => {
                let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
                SpirvType::Vector {
                    element: elem_spirv,
                    count: count as u32,
                    size: self.size,
                    align: self.align.abi,
                }
                .def(span, cx)
            }
            // Unions, arrays and structs.
            BackendRepr::Memory { sized: _ } => trans_aggregate(cx, span, *self),
        }
    }
}
397
398pub fn scalar_pair_element_backend_type<'tcx>(
401 cx: &CodegenCx<'tcx>,
402 span: Span,
403 ty: TyAndLayout<'tcx>,
404 index: usize,
405) -> Word {
406 let [a, b] = match ty.layout.backend_repr() {
407 BackendRepr::ScalarPair(a, b) => [a, b],
408 other => span_bug!(
409 span,
410 "scalar_pair_element_backend_type invalid abi: {:?}",
411 other
412 ),
413 };
414 let offset = match index {
415 0 => Size::ZERO,
416 1 => a.primitive().size(cx).align_to(b.primitive().align(cx).abi),
417 _ => unreachable!(),
418 };
419 trans_scalar(cx, span, ty, [a, b][index], offset)
420}
421
422fn trans_scalar<'tcx>(
430 cx: &CodegenCx<'tcx>,
431 span: Span,
432 ty: TyAndLayout<'tcx>,
433 scalar: Scalar,
434 offset: Size,
435) -> Word {
436 if scalar.is_bool() {
437 return SpirvType::Bool.def(span, cx);
438 }
439
440 match scalar.primitive() {
441 Primitive::Int(int_kind, signedness) => {
442 SpirvType::Integer(int_kind.size().bits() as u32, signedness).def(span, cx)
443 }
444 Primitive::Float(float_kind) => {
445 SpirvType::Float(float_kind.size().bits() as u32).def(span, cx)
446 }
447 Primitive::Pointer(_) => {
448 let pointee_ty = dig_scalar_pointee(cx, ty, offset);
449 if let Some(predefined_result) = cx
452 .type_cache
453 .recursive_pointee_cache
454 .begin(cx, span, pointee_ty)
455 {
456 predefined_result
457 } else {
458 let pointee = pointee_ty.spirv_type(span, cx);
459 cx.type_cache
460 .recursive_pointee_cache
461 .end(cx, span, pointee_ty, pointee)
462 }
463 }
464 }
465}
466
467fn dig_scalar_pointee<'tcx>(
478 cx: &CodegenCx<'tcx>,
479 layout: TyAndLayout<'tcx>,
480 offset: Size,
481) -> PointeeTy<'tcx> {
482 if let FieldsShape::Primitive = layout.fields {
483 assert_eq!(offset, Size::ZERO);
484 let pointee = match *layout.ty.kind() {
485 TyKind::Ref(_, pointee_ty, _) | TyKind::RawPtr(pointee_ty, _) => {
486 PointeeTy::Ty(cx.layout_of(pointee_ty))
487 }
488 TyKind::FnPtr(sig_tys, hdr) => PointeeTy::Fn(sig_tys.with(hdr)),
489 _ => bug!("Pointer is not `&T`, `*T` or `fn` pointer: {:#?}", layout),
490 };
491 return pointee;
492 }
493
494 let all_fields = (match &layout.variants {
495 Variants::Empty => 0..0,
496 Variants::Multiple { variants, .. } => 0..variants.len(),
497 Variants::Single { index } => {
498 let i = index.as_usize();
499 i..i + 1
500 }
501 })
502 .flat_map(|variant_idx| {
503 let variant = layout.for_variant(cx, VariantIdx::new(variant_idx));
504 (0..variant.fields.count()).map(move |field_idx| {
505 (
506 variant.field(cx, field_idx),
507 variant.fields.offset(field_idx),
508 )
509 })
510 });
511
512 let mut pointee = None;
513 for (field, field_offset) in all_fields {
514 if field.is_zst() {
515 continue;
516 }
517 if (field_offset..field_offset + field.size).contains(&offset) {
518 let new_pointee = dig_scalar_pointee(cx, field, offset - field_offset);
519 match pointee {
520 Some(old_pointee) if old_pointee != new_pointee => {
521 cx.tcx.dcx().fatal(format!(
522 "dig_scalar_pointee: unsupported Pointer with different \
523 pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
524 ));
525 }
526 _ => pointee = Some(new_pointee),
527 }
528 }
529 }
530 pointee.unwrap_or_else(|| {
531 bug!(
532 "field containing Pointer scalar at offset {:?} not found in {:#?}",
533 offset,
534 layout
535 )
536 })
537}
538
539fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
542 fn create_zst<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
543 assert_eq!(ty.size, Size::ZERO);
544 SpirvType::Adt {
545 def_id: def_id_for_spirv_type_adt(ty),
546 size: Some(Size::ZERO),
547 align: ty.align.abi,
548 field_types: &[],
549 field_offsets: &[],
550 field_names: None,
551 }
552 .def_with_name(cx, span, TyLayoutNameKey::from(ty))
553 }
554 match ty.fields {
555 FieldsShape::Primitive => span_bug!(
556 span,
557 "trans_aggregate called for FieldsShape::Primitive layout {:#?}",
558 ty
559 ),
560 FieldsShape::Union(_) => {
561 assert!(!ty.is_unsized(), "{ty:#?}");
562
563 let largest_case = (0..ty.fields.count())
569 .map(|i| (FieldIdx::from_usize(i), ty.field(cx, i)))
570 .max_by_key(|(_, case)| case.size);
571
572 if let Some((case_idx, case)) = largest_case {
573 if ty.align != case.align {
574 trans_struct_or_union(cx, span, ty, Some(case_idx))
576 } else {
577 assert_eq!(ty.size, case.size);
578 case.spirv_type(span, cx)
579 }
580 } else {
581 create_zst(cx, span, ty)
582 }
583 }
584 FieldsShape::Array { stride, count } => {
585 let element_type = ty.field(cx, 0).spirv_type(span, cx);
586 if ty.is_unsized() {
587 assert_eq!(count, 0);
590 SpirvType::RuntimeArray {
591 element: element_type,
592 }
593 .def(span, cx)
594 } else if count == 0 {
595 create_zst(cx, span, ty)
597 } else {
598 let count_const = cx.constant_u32(span, count as u32);
599 let element_spv = cx.lookup_type(element_type);
600 let stride_spv = element_spv
601 .sizeof(cx)
602 .expect("Unexpected unsized type in sized FieldsShape::Array")
603 .align_to(element_spv.alignof(cx));
604 assert_eq!(stride_spv, stride);
605 SpirvType::Array {
606 element: element_type,
607 count: count_const,
608 }
609 .def(span, cx)
610 }
611 }
612 FieldsShape::Arbitrary {
613 offsets: _,
614 memory_index: _,
615 } => trans_struct_or_union(cx, span, ty, None),
616 }
617}
618
619#[cfg_attr(
620 not(rustc_codegen_spirv_disable_pqp_cg_ssa),
621 expect(
622 unused,
623 reason = "actually used from \
624 `<rustc_codegen_ssa::traits::ConstCodegenMethods for CodegenCx<'_>>::const_struct`, \
625 but `rustc_codegen_ssa` being `pqp_cg_ssa` makes that trait unexported"
626 )
627)]
628pub fn auto_struct_layout(
630 cx: &CodegenCx<'_>,
631 field_types: &[Word],
632) -> (Vec<Size>, Option<Size>, Align) {
633 let mut field_offsets = Vec::with_capacity(field_types.len());
635 let mut offset = Some(Size::ZERO);
636 let mut max_align = Align::from_bytes(0).unwrap();
637 for &field_type in field_types {
638 let spirv_type = cx.lookup_type(field_type);
639 let field_size = spirv_type.sizeof(cx);
640 let field_align = spirv_type.alignof(cx);
641 let this_offset = offset
642 .expect("Unsized values can only be the last field in a struct")
643 .align_to(field_align);
644
645 field_offsets.push(this_offset);
646 if field_align > max_align {
647 max_align = field_align;
648 }
649 offset = field_size.map(|size| this_offset + size);
650 }
651 (field_offsets, offset, max_align)
652}
653
654fn trans_struct_or_union<'tcx>(
656 cx: &CodegenCx<'tcx>,
657 span: Span,
658 ty: TyAndLayout<'tcx>,
659 union_case: Option<FieldIdx>,
660) -> Word {
661 let size = if ty.is_unsized() { None } else { Some(ty.size) };
662 let align = ty.align.abi;
663 let mut field_types = Vec::new();
665 let mut field_offsets = Vec::new();
666 let mut field_names = Vec::new();
667 for i in ty.fields.index_by_increasing_offset() {
668 if let Some(expected_field_idx) = union_case
669 && i != expected_field_idx.as_usize()
670 {
671 continue;
672 }
673
674 let field_ty = ty.field(cx, i);
675 field_types.push(field_ty.spirv_type(span, cx));
676 let offset = ty.fields.offset(i);
677 field_offsets.push(offset);
678 if let Variants::Single { index } = ty.variants {
679 if let TyKind::Adt(adt, _) = ty.ty.kind() {
680 let field = &adt.variants()[index].fields[FieldIdx::new(i)];
681 field_names.push(field.name);
682 } else {
683 field_names.push(Symbol::intern(&format!("{i}")));
685 }
686 } else {
687 if let TyKind::Adt(_, _) = ty.ty.kind() {
688 } else {
689 span_bug!(span, "Variants::Multiple not TyKind::Adt");
690 }
691 if i == 0 {
692 field_names.push(cx.sym.discriminant);
693 } else {
694 cx.tcx.dcx().fatal("Variants::Multiple has multiple fields")
695 }
696 };
697 }
698 SpirvType::Adt {
699 def_id: def_id_for_spirv_type_adt(ty),
700 size,
701 align,
702 field_types: &field_types,
703 field_offsets: &field_offsets,
704 field_names: Some(&field_names),
705 }
706 .def_with_name(cx, span, TyLayoutNameKey::from(ty))
707}
708
709fn def_id_for_spirv_type_adt(layout: TyAndLayout<'_>) -> Option<DefId> {
713 match *layout.ty.kind() {
714 TyKind::Adt(def, _) => Some(def.did()),
715 TyKind::Foreign(def_id) | TyKind::Closure(def_id, _) | TyKind::Coroutine(def_id, ..) => {
716 Some(def_id)
717 }
718 _ => None,
719 }
720}
721
722fn span_for_spirv_type_adt(cx: &CodegenCx<'_>, layout: TyAndLayout<'_>) -> Option<Span> {
723 def_id_for_spirv_type_adt(layout).map(|did| cx.tcx.def_span(did))
724}
725
/// Key used to name SPIR-V struct types: the type, plus the variant index for
/// `Variants::Single` layouts. Its `Display` impl appends the variant name for
/// enums and coroutines.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct TyLayoutNameKey<'tcx> {
    ty: Ty<'tcx>,
    // `Some` only for `Variants::Single` layouts.
    variant: Option<VariantIdx>,
}
733
734impl<'tcx> From<TyAndLayout<'tcx>> for TyLayoutNameKey<'tcx> {
735 fn from(layout: TyAndLayout<'tcx>) -> Self {
736 TyLayoutNameKey {
737 ty: layout.ty,
738 variant: match layout.variants {
739 Variants::Single { index } => Some(index),
740 _ => None,
741 },
742 }
743 }
744}
745
746impl fmt::Display for TyLayoutNameKey<'_> {
747 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
748 write!(f, "{}", self.ty)?;
749 if let (TyKind::Adt(def, _), Some(index)) = (self.ty.kind(), self.variant)
750 && def.is_enum()
751 && !def.variants().is_empty()
752 {
753 write!(f, "::{}", def.variants()[index].name)?;
754 }
755 if let (TyKind::Coroutine(_, _), Some(index)) = (self.ty.kind(), self.variant) {
756 write!(f, "::{}", CoroutineArgs::variant_name(index))?;
757 }
758 Ok(())
759 }
760}
761
762fn trans_intrinsic_type<'tcx>(
763 cx: &CodegenCx<'tcx>,
764 span: Span,
765 ty: TyAndLayout<'tcx>,
766 args: GenericArgsRef<'tcx>,
767 intrinsic_type_attr: IntrinsicType,
768) -> Result<Word, ErrorGuaranteed> {
769 match intrinsic_type_attr {
770 IntrinsicType::GenericImageType => {
771 if ty.size != Size::from_bytes(4) {
773 return Err(cx
774 .tcx
775 .dcx()
776 .err("#[spirv(generic_image)] type must have size 4"));
777 }
778
779 let sampled_type = match args.type_at(0).kind() {
792 TyKind::Int(int) => match int {
793 IntTy::Isize => {
794 SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, true)
795 .def(span, cx)
796 }
797 IntTy::I8 => SpirvType::Integer(8, true).def(span, cx),
798 IntTy::I16 => SpirvType::Integer(16, true).def(span, cx),
799 IntTy::I32 => SpirvType::Integer(32, true).def(span, cx),
800 IntTy::I64 => SpirvType::Integer(64, true).def(span, cx),
801 IntTy::I128 => SpirvType::Integer(128, true).def(span, cx),
802 },
803 TyKind::Uint(uint) => match uint {
804 UintTy::Usize => {
805 SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, false)
806 .def(span, cx)
807 }
808 UintTy::U8 => SpirvType::Integer(8, false).def(span, cx),
809 UintTy::U16 => SpirvType::Integer(16, false).def(span, cx),
810 UintTy::U32 => SpirvType::Integer(32, false).def(span, cx),
811 UintTy::U64 => SpirvType::Integer(64, false).def(span, cx),
812 UintTy::U128 => SpirvType::Integer(128, false).def(span, cx),
813 },
814 TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
815 TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
816 _ => {
817 return Err(cx
818 .tcx
819 .dcx()
820 .span_err(span, "Invalid sampled type to `Image`."));
821 }
822 };
823
824 trait FromScalarInt: Sized {
833 fn from_scalar_int(n: ScalarInt) -> Option<Self>;
834 }
835
836 impl FromScalarInt for u32 {
837 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
838 Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
839 }
840 }
841
842 impl FromScalarInt for Dim {
843 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
844 Dim::from_u32(u32::from_scalar_int(n)?)
845 }
846 }
847
848 impl FromScalarInt for ImageFormat {
849 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
850 ImageFormat::from_u32(u32::from_scalar_int(n)?)
851 }
852 }
853
854 fn const_int_value<'tcx, P: FromScalarInt>(
855 cx: &CodegenCx<'tcx>,
856 const_: Const<'tcx>,
857 ) -> Result<P, ErrorGuaranteed> {
858 let ty::Value {
859 ty: const_ty,
860 valtree: const_val,
861 } = const_.to_value();
862 assert!(const_ty.is_integral());
863 const_val
864 .try_to_scalar_int()
865 .and_then(P::from_scalar_int)
866 .ok_or_else(|| {
867 cx.tcx
868 .dcx()
869 .err(format!("invalid value for Image const generic: {const_}"))
870 })
871 }
872
873 let dim = const_int_value(cx, args.const_at(1))?;
874 let depth = const_int_value(cx, args.const_at(2))?;
875 let arrayed = const_int_value(cx, args.const_at(3))?;
876 let multisampled = const_int_value(cx, args.const_at(4))?;
877 let sampled = const_int_value(cx, args.const_at(5))?;
878 let image_format = const_int_value(cx, args.const_at(6))?;
879
880 let ty = SpirvType::Image {
881 sampled_type,
882 dim,
883 depth,
884 arrayed,
885 multisampled,
886 sampled,
887 image_format,
888 };
889 Ok(ty.def(span, cx))
890 }
891 IntrinsicType::Sampler => {
892 if ty.size != Size::from_bytes(4) {
894 return Err(cx.tcx.dcx().err("#[spirv(sampler)] type must have size 4"));
895 }
896 Ok(SpirvType::Sampler.def(span, cx))
897 }
898 IntrinsicType::AccelerationStructureKhr => {
899 Ok(SpirvType::AccelerationStructureKhr.def(span, cx))
900 }
901 IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)),
902 IntrinsicType::SampledImage => {
903 if ty.size != Size::from_bytes(4) {
905 return Err(cx
906 .tcx
907 .dcx()
908 .err("#[spirv(sampled_image)] type must have size 4"));
909 }
910
911 if let Some(image_ty) = args.types().next() {
914 let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
916 Ok(SpirvType::SampledImage { image_type }.def(span, cx))
917 } else {
918 Err(cx
919 .tcx
920 .dcx()
921 .err("#[spirv(sampled_image)] type must have a generic image type"))
922 }
923 }
924 IntrinsicType::RuntimeArray => {
925 if ty.size != Size::from_bytes(4) {
926 return Err(cx
927 .tcx
928 .dcx()
929 .err("#[spirv(runtime_array)] type must have size 4"));
930 }
931
932 if let Some(elem_ty) = args.types().next() {
935 Ok(SpirvType::RuntimeArray {
936 element: cx.layout_of(elem_ty).spirv_type(span, cx),
937 }
938 .def(span, cx))
939 } else {
940 Err(cx
941 .tcx
942 .dcx()
943 .err("#[spirv(runtime_array)] type must have a generic element type"))
944 }
945 }
946 IntrinsicType::TypedBuffer => {
947 if ty.size != Size::from_bytes(4) {
948 return Err(cx
949 .tcx
950 .sess
951 .dcx()
952 .err("#[spirv(typed_buffer)] type must have size 4"));
953 }
954
955 if let Some(data_ty) = args.types().next() {
958 Ok(SpirvType::InterfaceBlock {
963 inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
964 }
965 .def(span, cx))
966 } else {
967 Err(cx
968 .tcx
969 .sess
970 .dcx()
971 .err("#[spirv(typed_buffer)] type must have a generic data type"))
972 }
973 }
974 IntrinsicType::Matrix => {
975 let span = span_for_spirv_type_adt(cx, ty).unwrap();
976 let err_attr_name = "`#[spirv(matrix)]`";
977 let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
978 match cx.lookup_type(element) {
979 SpirvType::Vector { .. } => (),
980 ty => {
981 return Err(cx
982 .tcx
983 .dcx()
984 .struct_span_err(
985 span,
986 format!("{err_attr_name} type fields must all be vectors"),
987 )
988 .with_note(format!("field type is {}", ty.debug(element, cx)))
989 .emit());
990 }
991 }
992 Ok(SpirvType::Matrix { element, count }.def(span, cx))
993 }
994 IntrinsicType::Vector => {
995 let span = span_for_spirv_type_adt(cx, ty).unwrap();
996 let err_attr_name = "`#[spirv(vector)]`";
997 let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
998 match cx.lookup_type(element) {
999 SpirvType::Bool | SpirvType::Float { .. } | SpirvType::Integer { .. } => (),
1000 ty => {
1001 return Err(cx
1002 .tcx
1003 .dcx()
1004 .struct_span_err(
1005 span,
1006 format!(
1007 "{err_attr_name} type fields must all be floats, integers or bools"
1008 ),
1009 )
1010 .with_note(format!("field type is {}", ty.debug(element, cx)))
1011 .emit());
1012 }
1013 }
1014 Ok(SpirvType::Vector {
1015 element,
1016 count,
1017 size: ty.size,
1018 align: ty.align.abi,
1019 }
1020 .def(span, cx))
1021 }
1022 }
1023}
1024
1025fn trans_glam_like_struct<'tcx>(
1028 cx: &CodegenCx<'tcx>,
1029 span: Span,
1030 ty: TyAndLayout<'tcx>,
1031 args: GenericArgsRef<'tcx>,
1032 err_attr_name: &str,
1033) -> Result<(Word, u32), ErrorGuaranteed> {
1034 let tcx = cx.tcx;
1035 if let Some(adt) = ty.ty.ty_adt_def()
1036 && adt.is_struct()
1037 {
1038 let (count, element) = adt
1039 .non_enum_variant()
1040 .fields
1041 .iter()
1042 .map(|f| f.ty(tcx, args))
1043 .dedup_with_count()
1044 .exactly_one()
1045 .map_err(|_e| {
1046 tcx.dcx().span_err(
1047 span,
1048 format!("{err_attr_name} member types must all be the same"),
1049 )
1050 })?;
1051
1052 let element = cx.layout_of(element);
1053 let element_word = element.spirv_type(span, cx);
1054 let count = u32::try_from(count)
1055 .ok()
1056 .filter(|count| 2 <= *count && *count <= 4)
1057 .ok_or_else(|| {
1058 tcx.dcx()
1059 .span_err(span, format!("{err_attr_name} must have 2, 3 or 4 members"))
1060 })?;
1061
1062 for i in 0..ty.fields.count() {
1063 let expected = element.size.checked_mul(i as u64, cx).unwrap();
1064 let actual = ty.fields.offset(i);
1065 if actual != expected {
1066 let name: &str = adt
1067 .non_enum_variant()
1068 .fields
1069 .get(FieldIdx::from(i))
1070 .unwrap()
1071 .name
1072 .as_str();
1073 tcx.dcx().span_fatal(
1074 span,
1075 format!(
1076 "Unexpected layout for {err_attr_name} annotated struct: \
1077 Expected member `{name}` at offset {expected:?}, but was at {actual:?}"
1078 ),
1079 )
1080 }
1081 }
1082
1083 Ok((element_word, count))
1084 } else {
1085 Err(tcx
1086 .dcx()
1087 .span_err(span, format!("{err_attr_name} type must be a struct")))
1088 }
1089}