1use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
5use crate::codegen_cx::CodegenCx;
6use crate::spirv_type::SpirvType;
7use itertools::Itertools;
8use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
9use rustc_abi::ExternAbi as Abi;
10use rustc_abi::{
11 Align, BackendRepr, FieldIdx, FieldsShape, Primitive, Scalar, Size, VariantIdx, Variants,
12};
13use rustc_data_structures::fx::FxHashMap;
14use rustc_errors::ErrorGuaranteed;
15use rustc_index::Idx;
16use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
17use rustc_middle::ty::{
18 self, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty, TyCtxt,
19 TyKind, UintTy, ValTreeKindExt,
20};
21use rustc_middle::ty::{GenericArgsRef, ScalarInt};
22use rustc_middle::util::Providers;
23use rustc_middle::{bug, span_bug};
24use rustc_session::config::OptLevel;
25use rustc_span::DUMMY_SP;
26use rustc_span::def_id::DefId;
27use rustc_span::{Span, Symbol};
28use rustc_target::callconv::{ArgAbi, ArgAttributes, FnAbi, PassMode};
29use std::cell::RefCell;
30use std::collections::hash_map::Entry;
31use std::fmt;
32
33fn rewrite_c_abi_to_rust<'tcx>(
34 fn_sig: ty::EarlyBinder<'tcx, ty::PolyFnSig<'tcx>>,
35) -> ty::EarlyBinder<'tcx, ty::PolyFnSig<'tcx>> {
36 fn_sig.map_bound(|outer| {
37 outer.map_bound(|mut inner| {
38 if let Abi::C { .. } = inner.abi {
39 inner.abi = Abi::Rust;
40 }
41 inner
42 })
43 })
44}
45
/// Overrides several rustc queries so that function signatures and ABIs reach the
/// SPIR-V backend in a shape it can lower directly.
pub(crate) fn provide(providers: &mut Providers) {
    // Signatures of both local and external functions have any `Abi::C` rewritten
    // to `Abi::Rust` (see `rewrite_c_abi_to_rust`).
    providers.queries.fn_sig = |tcx, def_id| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.queries.fn_sig)(tcx, def_id);
        rewrite_c_abi_to_rust(result)
    };
    providers.extern_queries.fn_sig = |tcx, def_id| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .extern_queries
            .fn_sig)(tcx, def_id);
        rewrite_c_abi_to_rust(result)
    };

    /// Rebuilds a default-computed `FnAbi`, re-deriving every argument (and the return
    /// value) from its layout alone:
    /// - attributes are reset to `ArgAttributes::new()`;
    /// - the pass mode is forced direct via `make_direct_deprecated`;
    /// - scalar pairs whose type is an ADT are passed as one `Direct` value;
    /// - zero-sized values are `PassMode::Ignore`d.
    ///
    /// `c_variadic`, `fixed_count`, `conv` and `can_unwind` are copied unchanged.
    fn readjust_fn_abi<'tcx>(
        tcx: TyCtxt<'tcx>,
        fn_abi: &'tcx FnAbi<'tcx, Ty<'tcx>>,
    ) -> &'tcx FnAbi<'tcx, Ty<'tcx>> {
        let readjust_arg_abi = |arg: &ArgAbi<'tcx, Ty<'tcx>>| {
            // Start over from the layout, discarding whatever the default ABI chose.
            let mut arg = ArgAbi::new(&tcx, arg.layout, |_, _| ArgAttributes::new());
            arg.make_direct_deprecated();

            // ADTs are kept whole instead of being split into two scalars
            // (non-ADT pairs keep `PassMode::Pair`).
            if let PassMode::Pair(..) = arg.mode {
                if let TyKind::Adt(..) = arg.layout.ty.kind() {
                    arg.mode = PassMode::Direct(ArgAttributes::new());
                }
            }

            // Zero-sized values are not passed at all.
            if arg.layout.is_zst() {
                arg.mode = PassMode::Ignore;
            }

            arg
        };
        tcx.arena.alloc(FnAbi {
            args: fn_abi.args.iter().map(readjust_arg_abi).collect(),
            ret: readjust_arg_abi(&fn_abi.ret),

            c_variadic: fn_abi.c_variadic,
            fixed_count: fn_abi.fixed_count,
            conv: fn_abi.conv,
            can_unwind: fn_abi.can_unwind,
        })
    }
    providers.queries.fn_abi_of_fn_ptr = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .queries
            .fn_abi_of_fn_ptr)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };
    providers.queries.fn_abi_of_instance_no_deduced_attrs = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .queries
            .fn_abi_of_instance_no_deduced_attrs)(tcx, key);
        // When optimizing without incremental compilation, the default result is
        // passed through unchanged; otherwise it is readjusted.
        // NOTE(review): the reason for this split is not visible here (presumably it
        // relates to deduced parameter attributes, given the query name) — confirm
        // and document.
        if tcx.sess.opts.optimize != OptLevel::No && tcx.sess.opts.incremental.is_none() {
            result
        } else {
            Ok(readjust_fn_abi(tcx, result?))
        }
    };
    providers.queries.fn_abi_of_instance_raw = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .queries
            .fn_abi_of_instance_raw)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };

    // Replaced with a no-op, so rustc's default per-mono-item check is skipped.
    // NOTE(review): the motivation is not visible in this file — TODO document.
    providers.queries.check_mono_item = |_, _| {};
}
154
/// Tracks, per pointee, how far the definition of its SPIR-V pointer type has got.
///
/// A pointer type can be reached again while its pointee is still being translated
/// (a self-referential type). The cache detects this so that a forward pointer can be
/// emitted instead of recursing forever.
#[derive(Default)]
pub struct RecursivePointeeCache<'tcx> {
    map: RefCell<FxHashMap<PointeeTy<'tcx>, PointeeDefState>>,
}
164
impl<'tcx> RecursivePointeeCache<'tcx> {
    /// Starts translating the pointer type whose pointee is `pointee`.
    ///
    /// Returns:
    /// - `None` on the first visit. The pointee is marked as in progress; the caller must
    ///   translate the pointee and then call `end`.
    /// - `Some(id)` if the pointer type is already defined, or already forward-declared.
    /// - `Some(id)` if the pointee is currently being translated (recursion). In that
    ///   case a forward pointer (`type_forward_pointer`, `StorageClass::Generic`) is
    ///   emitted, and its id is zombied with an error message.
    fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
        match self.map.borrow_mut().entry(pointee) {
            // First visit: mark the pointee as in progress.
            Entry::Vacant(entry) => {
                entry.insert(PointeeDefState::Defining);
                None
            }
            Entry::Occupied(mut entry) => match *entry.get() {
                // Reached again while the pointee is still being translated: the
                // type is self-referential. Reserve an id, forward-declare the
                // pointer, and attach an error to it via `zombie_with_span`.
                PointeeDefState::Defining => {
                    let new_id = cx.emit_global().id();
                    cx.emit_global()
                        .type_forward_pointer(new_id, StorageClass::Generic);
                    entry.insert(PointeeDefState::DefiningWithForward(new_id));
                    cx.zombie_with_span(
                        new_id,
                        span,
                        "cannot create self-referential types, even through pointers",
                    );
                    Some(new_id)
                }
                PointeeDefState::DefiningWithForward(id) | PointeeDefState::Defined(id) => Some(id),
            },
        }
    }

    /// Finishes translating the pointer type for `pointee`, whose SPIR-V type is
    /// `pointee_spv`, and returns the pointer type's id.
    ///
    /// If a forward pointer was created in the meantime, the pointer type is defined
    /// under that forward id. Otherwise a fresh pointer type is defined.
    ///
    /// Span-bugs if `begin` was not called for `pointee`, or if `end` is called twice.
    fn end(
        &self,
        cx: &CodegenCx<'tcx>,
        span: Span,
        pointee: PointeeTy<'tcx>,
        pointee_spv: Word,
    ) -> Word {
        match self.map.borrow_mut().entry(pointee) {
            Entry::Vacant(_) => {
                span_bug!(span, "RecursivePointeeCache::end should always have entry")
            }
            Entry::Occupied(mut entry) => match *entry.get() {
                // No recursion happened: define the pointer type normally.
                PointeeDefState::Defining => {
                    let id = SpirvType::Pointer {
                        pointee: pointee_spv,
                    }
                    .def(span, cx);
                    entry.insert(PointeeDefState::Defined(id));
                    id
                }
                // Recursion happened: complete the forward-declared id.
                PointeeDefState::DefiningWithForward(id) => {
                    entry.insert(PointeeDefState::Defined(id));
                    SpirvType::Pointer {
                        pointee: pointee_spv,
                    }
                    .def_with_id(cx, span, id)
                }
                PointeeDefState::Defined(_) => {
                    span_bug!(span, "RecursivePointeeCache::end defined pointer twice")
                }
            },
        }
    }
}
240
/// What a pointer scalar points to. This is either a Rust type together with its
/// layout (for references and raw pointers) or a function signature (for `fn` pointers).
#[derive(Eq, PartialEq, Hash, Copy, Clone, Debug)]
enum PointeeTy<'tcx> {
    Ty(TyAndLayout<'tcx>),
    Fn(PolyFnSig<'tcx>),
}
246
247impl fmt::Display for PointeeTy<'_> {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 match self {
250 PointeeTy::Ty(ty) => write!(f, "{}", ty.ty),
251 PointeeTy::Fn(ty) => write!(f, "{ty}"),
252 }
253 }
254}
255
/// Progress of defining the pointer type for one pointee (tracked by
/// `RecursivePointeeCache`).
enum PointeeDefState {
    /// The pointee is being translated; no pointer type id exists yet.
    Defining,
    /// The pointee is being translated and was reached again recursively, so a
    /// forward pointer with this id has already been emitted.
    DefiningWithForward(Word),
    /// The pointer type has been fully defined with this id.
    Defined(Word),
}
261
/// Conversion of a Rust-side type description (a layout, a pointee or a function ABI)
/// into the id of a SPIR-V type.
pub trait ConvSpirvType<'tcx> {
    /// Returns the SPIR-V type id for `self`. Implementations define the type through
    /// `cx` and use `span` when defining types and reporting diagnostics.
    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
}
267
268impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
269 fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
270 match *self {
271 PointeeTy::Ty(ty) => ty.spirv_type(span, cx),
272 PointeeTy::Fn(ty) => cx
273 .fn_abi_of_fn_ptr(ty, ty::List::empty())
274 .spirv_type(span, cx),
275 }
276 }
277}
278
279impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
280 fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
281 let mut argument_types = Vec::new();
283
284 let return_type = match self.ret.mode {
285 PassMode::Ignore => SpirvType::Void.def(span, cx),
286 PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
287 PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
288 span,
289 "query hooks should've made this `PassMode` impossible: {:#?}",
290 self.ret
291 ),
292 };
293
294 for arg in self.args.iter() {
295 let arg_type = match arg.mode {
296 PassMode::Ignore => continue,
297 PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
298 PassMode::Pair(_, _) => {
299 argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
300 argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
301 continue;
302 }
303 PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
304 span,
305 "query hooks should've made this `PassMode` impossible: {:#?}",
306 arg
307 ),
308 };
309 argument_types.push(arg_type);
310 }
311
312 SpirvType::Function {
313 return_type,
314 arguments: &argument_types,
315 }
316 .def(span, cx)
317 }
318}
319
impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
    /// Translates a Rust type, together with its computed layout, into a SPIR-V type id.
    ///
    /// ADTs carrying an intrinsic-type `#[spirv(...)]` attribute are first offered to
    /// `trans_intrinsic_type`. Everything else (and any ADT whose intrinsic translation
    /// failed) is translated from `self.backend_repr`.
    fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
        if let TyKind::Adt(adt, args) = *self.ty.kind() {
            // Without a caller-provided span, point at the ADT's definition.
            if span == DUMMY_SP {
                span = cx.tcx.def_span(adt.did());
            }

            // Gather the ADT's attributes found under the
            // `[rust_gpu, spirv_attr_with_version]` and `[rust_gpu, vector, v1]` paths
            // (symbols from `cx.sym`).
            let attrs = AggregatedSpirvAttributes::parse(
                cx,
                cx.tcx
                    .get_attrs_by_path(
                        adt.did(),
                        &[cx.sym.rust_gpu, cx.sym.spirv_attr_with_version],
                    )
                    .chain(cx.tcx.get_attrs_by_path(
                        adt.did(),
                        &[cx.sym.rust_gpu, cx.sym.vector, cx.sym.v1],
                    )),
            );

            // On `Err` (a diagnostic has already been emitted), fall through to the
            // ordinary layout-based translation below.
            if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
                && let Ok(spirv_type) =
                    trans_intrinsic_type(cx, span, *self, args, intrinsic_type_attr)
            {
                return spirv_type;
            }
        }

        match self.backend_repr {
            // Uninhabited types become an empty, zero-sized ADT whatever their repr.
            _ if self.uninhabited => SpirvType::Adt {
                def_id: def_id_for_spirv_type_adt(*self),
                size: Some(Size::ZERO),
                align: Align::from_bytes(0).unwrap(),
                field_types: &[],
                field_offsets: &[],
                field_names: None,
            }
            .def_with_name(cx, span, TyLayoutNameKey::from(*self)),
            BackendRepr::Scalar(scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
            BackendRepr::ScalarPair(a, b) => {
                // If exactly one field is non-ZST and it sits at offset 0 with the same
                // size, alignment and backend repr as the whole type, the type is
                // effectively a wrapper around that field: reuse the field's type.
                let mut non_zst_fields = (0..self.fields.count())
                    .map(|i| (i, self.field(cx, i)))
                    .filter(|(_, field)| !field.is_zst());
                let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) {
                    (Some(field), None) => Some(field),
                    _ => None,
                };
                if let Some((i, field)) = sole_non_zst_field {
                    if self.fields.offset(i) == Size::ZERO
                        && field.size == self.size
                        && field.align.abi == self.align.abi
                        && field.backend_repr.eq_up_to_validity(&self.backend_repr)
                    {
                        return field.spirv_type(span, cx);
                    }
                }

                // Otherwise emit a two-field struct: `a` at offset 0, and `b` at `a`'s
                // size rounded up to `b`'s ABI alignment.
                let a_offset = Size::ZERO;
                let b_offset = a.primitive().size(cx).align_to(b.primitive().align(cx).abi);
                let a = trans_scalar(cx, span, *self, a, a_offset);
                let b = trans_scalar(cx, span, *self, b, b_offset);
                let size = if self.is_unsized() {
                    None
                } else {
                    Some(self.size)
                };
                // Field names are only attached when the ADT has exactly two fields
                // (checked below), listed in increasing-offset order.
                let mut field_names = Vec::new();
                if let TyKind::Adt(adt, _) = self.ty.kind()
                    && let Variants::Single { index } = self.variants
                {
                    for i in self.fields.index_by_increasing_offset() {
                        let field = &adt.variants()[index].fields[FieldIdx::new(i)];
                        field_names.push(field.name);
                    }
                }
                SpirvType::Adt {
                    def_id: def_id_for_spirv_type_adt(*self),
                    size,
                    align: self.align.abi,
                    field_types: &[a, b],
                    field_offsets: &[a_offset, b_offset],
                    field_names: if field_names.len() == 2 {
                        Some(&field_names)
                    } else {
                        None
                    },
                }
                .def_with_name(cx, span, TyLayoutNameKey::from(*self))
            }
            BackendRepr::SimdVector { element, count } => {
                let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
                SpirvType::Vector {
                    element: elem_spirv,
                    count: count as u32,
                    size: self.size,
                    align: self.align.abi,
                }
                .def(span, cx)
            }
            BackendRepr::SimdScalableVector { .. } => cx
                .tcx
                .dcx()
                .fatal("scalable vectors are not supported in SPIR-V backend"),
            BackendRepr::Memory { sized: _ } => trans_aggregate(cx, span, *self),
        }
    }
}
450
451pub fn scalar_pair_element_backend_type<'tcx>(
454 cx: &CodegenCx<'tcx>,
455 span: Span,
456 ty: TyAndLayout<'tcx>,
457 index: usize,
458) -> Word {
459 let [a, b] = match ty.layout.backend_repr() {
460 BackendRepr::ScalarPair(a, b) => [a, b],
461 other => span_bug!(
462 span,
463 "scalar_pair_element_backend_type invalid abi: {:?}",
464 other
465 ),
466 };
467 let offset = match index {
468 0 => Size::ZERO,
469 1 => a.primitive().size(cx).align_to(b.primitive().align(cx).abi),
470 _ => unreachable!(),
471 };
472 trans_scalar(cx, span, ty, [a, b][index], offset)
473}
474
475fn trans_scalar<'tcx>(
483 cx: &CodegenCx<'tcx>,
484 span: Span,
485 ty: TyAndLayout<'tcx>,
486 scalar: Scalar,
487 offset: Size,
488) -> Word {
489 if scalar.is_bool() {
490 return SpirvType::Bool.def(span, cx);
491 }
492
493 match scalar.primitive() {
494 Primitive::Int(int_kind, signedness) => {
495 SpirvType::Integer(int_kind.size().bits() as u32, signedness).def(span, cx)
496 }
497 Primitive::Float(float_kind) => {
498 SpirvType::Float(float_kind.size().bits() as u32).def(span, cx)
499 }
500 Primitive::Pointer(_) => {
501 let pointee_ty = dig_scalar_pointee(cx, ty, offset);
502 if let Some(predefined_result) = cx
505 .type_cache
506 .recursive_pointee_cache
507 .begin(cx, span, pointee_ty)
508 {
509 predefined_result
510 } else {
511 let pointee = pointee_ty.spirv_type(span, cx);
512 cx.type_cache
513 .recursive_pointee_cache
514 .end(cx, span, pointee_ty, pointee)
515 }
516 }
517 }
518}
519
520fn dig_scalar_pointee<'tcx>(
531 cx: &CodegenCx<'tcx>,
532 layout: TyAndLayout<'tcx>,
533 offset: Size,
534) -> PointeeTy<'tcx> {
535 if let FieldsShape::Primitive = layout.fields {
536 assert_eq!(offset, Size::ZERO);
537 let pointee = match *layout.ty.kind() {
538 TyKind::Ref(_, pointee_ty, _) | TyKind::RawPtr(pointee_ty, _) => {
539 PointeeTy::Ty(cx.layout_of(pointee_ty))
540 }
541 TyKind::FnPtr(sig_tys, hdr) => PointeeTy::Fn(sig_tys.with(hdr)),
542 _ => bug!("Pointer is not `&T`, `*T` or `fn` pointer: {:#?}", layout),
543 };
544 return pointee;
545 }
546
547 let all_fields = (match &layout.variants {
548 Variants::Empty => 0..0,
549 Variants::Multiple { variants, .. } => 0..variants.len(),
550 Variants::Single { index } => {
551 let i = index.as_usize();
552 i..i + 1
553 }
554 })
555 .flat_map(|variant_idx| {
556 let variant = layout.for_variant(cx, VariantIdx::new(variant_idx));
557 (0..variant.fields.count()).map(move |field_idx| {
558 (
559 variant.field(cx, field_idx),
560 variant.fields.offset(field_idx),
561 )
562 })
563 });
564
565 let mut pointee = None;
566 for (field, field_offset) in all_fields {
567 if field.is_zst() {
568 continue;
569 }
570 if (field_offset..field_offset + field.size).contains(&offset) {
571 let new_pointee = dig_scalar_pointee(cx, field, offset - field_offset);
572 match pointee {
573 Some(old_pointee) if old_pointee != new_pointee => {
574 cx.tcx.dcx().fatal(format!(
575 "dig_scalar_pointee: unsupported Pointer with different \
576 pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
577 ));
578 }
579 _ => pointee = Some(new_pointee),
580 }
581 }
582 }
583 pointee.unwrap_or_else(|| {
584 bug!(
585 "field containing Pointer scalar at offset {:?} not found in {:#?}",
586 offset,
587 layout
588 )
589 })
590}
591
592fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
595 fn create_zst<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
596 assert_eq!(ty.size, Size::ZERO);
597 SpirvType::Adt {
598 def_id: def_id_for_spirv_type_adt(ty),
599 size: Some(Size::ZERO),
600 align: ty.align.abi,
601 field_types: &[],
602 field_offsets: &[],
603 field_names: None,
604 }
605 .def_with_name(cx, span, TyLayoutNameKey::from(ty))
606 }
607 match ty.fields {
608 FieldsShape::Primitive => span_bug!(
609 span,
610 "trans_aggregate called for FieldsShape::Primitive layout {:#?}",
611 ty
612 ),
613 FieldsShape::Union(_) => {
614 assert!(!ty.is_unsized(), "{ty:#?}");
615
616 let largest_case = (0..ty.fields.count())
622 .map(|i| (FieldIdx::from_usize(i), ty.field(cx, i)))
623 .max_by_key(|(_, case)| case.size);
624
625 if let Some((case_idx, case)) = largest_case {
626 if ty.align != case.align {
627 trans_struct_or_union(cx, span, ty, Some(case_idx))
629 } else {
630 assert_eq!(ty.size, case.size);
631 case.spirv_type(span, cx)
632 }
633 } else {
634 create_zst(cx, span, ty)
635 }
636 }
637 FieldsShape::Array { stride, count } => {
638 let element_type = ty.field(cx, 0).spirv_type(span, cx);
639 if ty.is_unsized() {
640 assert_eq!(count, 0);
643 SpirvType::RuntimeArray {
644 element: element_type,
645 }
646 .def(span, cx)
647 } else if count == 0 {
648 create_zst(cx, span, ty)
650 } else {
651 let count_const = cx.constant_u32(span, count as u32);
652 let element_spv = cx.lookup_type(element_type);
653 let stride_spv = element_spv
654 .sizeof(cx)
655 .expect("Unexpected unsized type in sized FieldsShape::Array")
656 .align_to(element_spv.alignof(cx));
657 assert_eq!(stride_spv, stride);
658 SpirvType::Array {
659 element: element_type,
660 count: count_const,
661 }
662 .def(span, cx)
663 }
664 }
665 FieldsShape::Arbitrary { .. } => trans_struct_or_union(cx, span, ty, None),
666 }
667}
668
669pub fn auto_struct_layout(
671 cx: &CodegenCx<'_>,
672 field_types: &[Word],
673) -> (Vec<Size>, Option<Size>, Align) {
674 let mut field_offsets = Vec::with_capacity(field_types.len());
676 let mut offset = Some(Size::ZERO);
677 let mut max_align = Align::from_bytes(0).unwrap();
678 for &field_type in field_types {
679 let spirv_type = cx.lookup_type(field_type);
680 let field_size = spirv_type.sizeof(cx);
681 let field_align = spirv_type.alignof(cx);
682 let this_offset = offset
683 .expect("Unsized values can only be the last field in a struct")
684 .align_to(field_align);
685
686 field_offsets.push(this_offset);
687 if field_align > max_align {
688 max_align = field_align;
689 }
690 offset = field_size.map(|size| this_offset + size);
691 }
692 (field_offsets, offset, max_align)
693}
694
695fn trans_struct_or_union<'tcx>(
697 cx: &CodegenCx<'tcx>,
698 span: Span,
699 ty: TyAndLayout<'tcx>,
700 union_case: Option<FieldIdx>,
701) -> Word {
702 let size = if ty.is_unsized() { None } else { Some(ty.size) };
703 let align = ty.align.abi;
704 let mut field_types = Vec::new();
706 let mut field_offsets = Vec::new();
707 let mut field_names = Vec::new();
708 for i in ty.fields.index_by_increasing_offset() {
709 if let Some(expected_field_idx) = union_case
710 && i != expected_field_idx.as_usize()
711 {
712 continue;
713 }
714
715 let field_ty = ty.field(cx, i);
716 field_types.push(field_ty.spirv_type(span, cx));
717 let offset = ty.fields.offset(i);
718 field_offsets.push(offset);
719 if let Variants::Single { index } = ty.variants {
720 if let TyKind::Adt(adt, _) = ty.ty.kind() {
721 let field = &adt.variants()[index].fields[FieldIdx::new(i)];
722 field_names.push(field.name);
723 } else {
724 field_names.push(Symbol::intern(&format!("{i}")));
726 }
727 } else {
728 if let TyKind::Adt(_, _) = ty.ty.kind() {
729 } else {
730 span_bug!(span, "Variants::Multiple not TyKind::Adt");
731 }
732 if i == 0 {
733 field_names.push(cx.sym.discriminant);
734 } else {
735 cx.tcx.dcx().fatal("Variants::Multiple has multiple fields")
736 }
737 };
738 }
739 SpirvType::Adt {
740 def_id: def_id_for_spirv_type_adt(ty),
741 size,
742 align,
743 field_types: &field_types,
744 field_offsets: &field_offsets,
745 field_names: Some(&field_names),
746 }
747 .def_with_name(cx, span, TyLayoutNameKey::from(ty))
748}
749
750fn def_id_for_spirv_type_adt(layout: TyAndLayout<'_>) -> Option<DefId> {
754 match *layout.ty.kind() {
755 TyKind::Adt(def, _) => Some(def.did()),
756 TyKind::Foreign(def_id) | TyKind::Closure(def_id, _) | TyKind::Coroutine(def_id, ..) => {
757 Some(def_id)
758 }
759 _ => None,
760 }
761}
762
763fn span_for_spirv_type_adt(cx: &CodegenCx<'_>, layout: TyAndLayout<'_>) -> Option<Span> {
764 def_id_for_spirv_type_adt(layout).map(|did| cx.tcx.def_span(did))
765}
766
/// Key used to name a SPIR-V ADT type: the Rust type, plus the variant index when the
/// layout is `Variants::Single`. For enums and coroutines, the `Display` output then
/// includes the variant name.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct TyLayoutNameKey<'tcx> {
    ty: Ty<'tcx>,
    variant: Option<VariantIdx>,
}
774
775impl<'tcx> From<TyAndLayout<'tcx>> for TyLayoutNameKey<'tcx> {
776 fn from(layout: TyAndLayout<'tcx>) -> Self {
777 TyLayoutNameKey {
778 ty: layout.ty,
779 variant: match layout.variants {
780 Variants::Single { index } => Some(index),
781 _ => None,
782 },
783 }
784 }
785}
786
787impl fmt::Display for TyLayoutNameKey<'_> {
788 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
789 write!(f, "{}", self.ty)?;
790 if let (TyKind::Adt(def, _), Some(index)) = (self.ty.kind(), self.variant)
791 && def.is_enum()
792 && !def.variants().is_empty()
793 {
794 write!(f, "::{}", def.variants()[index].name)?;
795 }
796 if let (TyKind::Coroutine(_, _), Some(index)) = (self.ty.kind(), self.variant) {
797 write!(f, "::{}", CoroutineArgs::variant_name(index))?;
798 }
799 Ok(())
800 }
801}
802
803fn trans_intrinsic_type<'tcx>(
804 cx: &CodegenCx<'tcx>,
805 span: Span,
806 ty: TyAndLayout<'tcx>,
807 args: GenericArgsRef<'tcx>,
808 intrinsic_type_attr: IntrinsicType,
809) -> Result<Word, ErrorGuaranteed> {
810 match intrinsic_type_attr {
811 IntrinsicType::GenericImageType => {
812 if ty.size != Size::from_bytes(4) {
814 return Err(cx
815 .tcx
816 .dcx()
817 .err("#[spirv(generic_image)] type must have size 4"));
818 }
819
820 let sampled_type = match args.type_at(0).kind() {
833 TyKind::Int(int) => match int {
834 IntTy::Isize => {
835 SpirvType::Integer(cx.tcx.data_layout.pointer_size().bits() as u32, true)
836 .def(span, cx)
837 }
838 IntTy::I8 => SpirvType::Integer(8, true).def(span, cx),
839 IntTy::I16 => SpirvType::Integer(16, true).def(span, cx),
840 IntTy::I32 => SpirvType::Integer(32, true).def(span, cx),
841 IntTy::I64 => SpirvType::Integer(64, true).def(span, cx),
842 IntTy::I128 => SpirvType::Integer(128, true).def(span, cx),
843 },
844 TyKind::Uint(uint) => match uint {
845 UintTy::Usize => {
846 SpirvType::Integer(cx.tcx.data_layout.pointer_size().bits() as u32, false)
847 .def(span, cx)
848 }
849 UintTy::U8 => SpirvType::Integer(8, false).def(span, cx),
850 UintTy::U16 => SpirvType::Integer(16, false).def(span, cx),
851 UintTy::U32 => SpirvType::Integer(32, false).def(span, cx),
852 UintTy::U64 => SpirvType::Integer(64, false).def(span, cx),
853 UintTy::U128 => SpirvType::Integer(128, false).def(span, cx),
854 },
855 TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
856 TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
857 _ => {
858 return Err(cx
859 .tcx
860 .dcx()
861 .span_err(span, "Invalid sampled type to `Image`."));
862 }
863 };
864
865 trait FromScalarInt: Sized {
874 fn from_scalar_int(n: ScalarInt) -> Option<Self>;
875 }
876
877 impl FromScalarInt for u32 {
878 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
879 Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
880 }
881 }
882
883 impl FromScalarInt for Dim {
884 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
885 Dim::from_u32(u32::from_scalar_int(n)?)
886 }
887 }
888
889 impl FromScalarInt for ImageFormat {
890 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
891 ImageFormat::from_u32(u32::from_scalar_int(n)?)
892 }
893 }
894
895 fn const_int_value<'tcx, P: FromScalarInt>(
896 cx: &CodegenCx<'tcx>,
897 const_: Const<'tcx>,
898 ) -> Result<P, ErrorGuaranteed> {
899 let ty::Value {
900 ty: const_ty,
901 valtree: const_val,
902 } = const_.to_value();
903 assert!(const_ty.is_integral());
904 const_val
905 .try_to_scalar()
906 .and_then(|scalar| scalar.try_to_scalar_int().ok())
907 .and_then(P::from_scalar_int)
908 .ok_or_else(|| {
909 cx.tcx
910 .dcx()
911 .err(format!("invalid value for Image const generic: {const_}"))
912 })
913 }
914
915 let dim = const_int_value(cx, args.const_at(1))?;
916 let depth = const_int_value(cx, args.const_at(2))?;
917 let arrayed = const_int_value(cx, args.const_at(3))?;
918 let multisampled = const_int_value(cx, args.const_at(4))?;
919 let sampled = const_int_value(cx, args.const_at(5))?;
920 let image_format = const_int_value(cx, args.const_at(6))?;
921
922 let ty = SpirvType::Image {
923 sampled_type,
924 dim,
925 depth,
926 arrayed,
927 multisampled,
928 sampled,
929 image_format,
930 };
931 Ok(ty.def(span, cx))
932 }
933 IntrinsicType::Sampler => {
934 if ty.size != Size::from_bytes(4) {
936 return Err(cx.tcx.dcx().err("#[spirv(sampler)] type must have size 4"));
937 }
938 Ok(SpirvType::Sampler.def(span, cx))
939 }
940 IntrinsicType::AccelerationStructureKhr => {
941 Ok(SpirvType::AccelerationStructureKhr.def(span, cx))
942 }
943 IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)),
944 IntrinsicType::SampledImage => {
945 if ty.size != Size::from_bytes(4) {
947 return Err(cx
948 .tcx
949 .dcx()
950 .err("#[spirv(sampled_image)] type must have size 4"));
951 }
952
953 if let Some(image_ty) = args.types().next() {
956 let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
958 Ok(SpirvType::SampledImage { image_type }.def(span, cx))
959 } else {
960 Err(cx
961 .tcx
962 .dcx()
963 .err("#[spirv(sampled_image)] type must have a generic image type"))
964 }
965 }
966 IntrinsicType::RuntimeArray => {
967 if ty.size != Size::from_bytes(4) {
968 return Err(cx
969 .tcx
970 .dcx()
971 .err("#[spirv(runtime_array)] type must have size 4"));
972 }
973
974 if let Some(elem_ty) = args.types().next() {
977 Ok(SpirvType::RuntimeArray {
978 element: cx.layout_of(elem_ty).spirv_type(span, cx),
979 }
980 .def(span, cx))
981 } else {
982 Err(cx
983 .tcx
984 .dcx()
985 .err("#[spirv(runtime_array)] type must have a generic element type"))
986 }
987 }
988 IntrinsicType::TypedBuffer => {
989 if ty.size != Size::from_bytes(4) {
990 return Err(cx
991 .tcx
992 .sess
993 .dcx()
994 .err("#[spirv(typed_buffer)] type must have size 4"));
995 }
996
997 if let Some(data_ty) = args.types().next() {
1000 Ok(SpirvType::InterfaceBlock {
1005 inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
1006 }
1007 .def(span, cx))
1008 } else {
1009 Err(cx
1010 .tcx
1011 .sess
1012 .dcx()
1013 .err("#[spirv(typed_buffer)] type must have a generic data type"))
1014 }
1015 }
1016 IntrinsicType::Matrix => {
1017 let span = span_for_spirv_type_adt(cx, ty).unwrap();
1018 let err_attr_name = "`#[spirv(matrix)]`";
1019 let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
1020 match cx.lookup_type(element) {
1021 SpirvType::Vector { .. } => (),
1022 ty => {
1023 return Err(cx
1024 .tcx
1025 .dcx()
1026 .struct_span_err(
1027 span,
1028 format!("{err_attr_name} type fields must all be vectors"),
1029 )
1030 .with_note(format!("field type is {}", ty.debug(element, cx)))
1031 .emit());
1032 }
1033 }
1034 Ok(SpirvType::Matrix { element, count }.def(span, cx))
1035 }
1036 IntrinsicType::Vector => {
1037 let span = span_for_spirv_type_adt(cx, ty).unwrap();
1038 let err_attr_name = "`#[spirv(vector)]`";
1039 let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
1040 match cx.lookup_type(element) {
1041 SpirvType::Bool | SpirvType::Float { .. } | SpirvType::Integer { .. } => (),
1042 ty => {
1043 return Err(cx
1044 .tcx
1045 .dcx()
1046 .struct_span_err(
1047 span,
1048 format!(
1049 "{err_attr_name} type fields must all be floats, integers or bools"
1050 ),
1051 )
1052 .with_note(format!("field type is {}", ty.debug(element, cx)))
1053 .emit());
1054 }
1055 }
1056 Ok(SpirvType::Vector {
1057 element,
1058 count,
1059 size: ty.size,
1060 align: ty.align.abi,
1061 }
1062 .def(span, cx))
1063 }
1064 }
1065}
1066
1067fn trans_glam_like_struct<'tcx>(
1070 cx: &CodegenCx<'tcx>,
1071 span: Span,
1072 ty: TyAndLayout<'tcx>,
1073 args: GenericArgsRef<'tcx>,
1074 err_attr_name: &str,
1075) -> Result<(Word, u32), ErrorGuaranteed> {
1076 let tcx = cx.tcx;
1077 if let Some(adt) = ty.ty.ty_adt_def()
1078 && adt.is_struct()
1079 {
1080 let (count, element) = adt
1081 .non_enum_variant()
1082 .fields
1083 .iter()
1084 .map(|f| f.ty(tcx, args))
1085 .dedup_with_count()
1086 .exactly_one()
1087 .map_err(|_e| {
1088 tcx.dcx().span_err(
1089 span,
1090 format!("{err_attr_name} member types must all be the same"),
1091 )
1092 })?;
1093
1094 let element = cx.layout_of(element);
1095 let element_word = element.spirv_type(span, cx);
1096 let count = u32::try_from(count)
1097 .ok()
1098 .filter(|count| 2 <= *count && *count <= 4)
1099 .ok_or_else(|| {
1100 tcx.dcx()
1101 .span_err(span, format!("{err_attr_name} must have 2, 3 or 4 members"))
1102 })?;
1103
1104 for i in 0..ty.fields.count() {
1105 let expected = element.size.checked_mul(i as u64, cx).unwrap();
1106 let actual = ty.fields.offset(i);
1107 if actual != expected {
1108 let name: &str = adt
1109 .non_enum_variant()
1110 .fields
1111 .get(FieldIdx::from(i))
1112 .unwrap()
1113 .name
1114 .as_str();
1115 tcx.dcx().span_fatal(
1116 span,
1117 format!(
1118 "Unexpected layout for {err_attr_name} annotated struct: \
1119 Expected member `{name}` at offset {expected:?}, but was at {actual:?}"
1120 ),
1121 )
1122 }
1123 }
1124
1125 Ok((element_word, count))
1126 } else {
1127 Err(tcx
1128 .dcx()
1129 .span_err(span, format!("{err_attr_name} type must be a struct")))
1130 }
1131}