1use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
5use crate::codegen_cx::CodegenCx;
6use crate::spirv_type::SpirvType;
7use itertools::Itertools;
8use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
9use rustc_abi::ExternAbi as Abi;
10use rustc_abi::{
11 Align, BackendRepr, FieldIdx, FieldsShape, HasDataLayout as _, LayoutData, Primitive,
12 ReprFlags, ReprOptions, Scalar, Size, TagEncoding, VariantIdx, Variants,
13};
14use rustc_data_structures::fx::FxHashMap;
15use rustc_errors::ErrorGuaranteed;
16use rustc_hashes::Hash64;
17use rustc_index::Idx;
18use rustc_middle::query::Providers;
19use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
20use rustc_middle::ty::{
21 self, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty, TyCtxt,
22 TyKind, UintTy,
23};
24use rustc_middle::ty::{GenericArgsRef, ScalarInt};
25use rustc_middle::{bug, span_bug};
26use rustc_span::DUMMY_SP;
27use rustc_span::def_id::DefId;
28use rustc_span::{Span, Symbol};
29use rustc_target::callconv::{ArgAbi, ArgAttributes, FnAbi, PassMode};
30use std::cell::RefCell;
31use std::collections::hash_map::Entry;
32use std::fmt;
33
/// Installs this backend's overrides for a subset of rustc's query providers.
///
/// Every override first calls (or, for `check_well_formed` and `layout_of`,
/// may fall back to) the corresponding entry of
/// `rustc_interface::DEFAULT_QUERY_PROVIDERS`, then adjusts the result.
/// `check_mono_item` is the exception: it is replaced with a no-op.
pub(crate) fn provide(providers: &mut Providers) {
    // `fn_sig`: rewrite any `Abi::C { .. }` ABI to `Abi::Unadjusted`, leaving
    // the rest of the (doubly-bound) signature untouched.
    // NOTE(review): presumably this keeps C-ABI-specific argument adjustments
    // from being applied for this target — confirm.
    providers.fn_sig = |tcx, def_id| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_sig)(tcx, def_id);
        result.map_bound(|outer| {
            outer.map_bound(|mut inner| {
                if let Abi::C { .. } = inner.abi {
                    inner.abi = Abi::Unadjusted;
                }
                inner
            })
        })
    };

    /// Rebuilds `fn_abi` from the layouts of its arguments and return value alone,
    /// discarding the pass modes and attributes rustc computed.
    ///
    /// Each argument (and the return value) is recreated with default
    /// `ArgAttributes` and made direct; zero-sized ones become
    /// `PassMode::Ignore`. This aims to leave only `Ignore`, `Direct` and
    /// `Pair` modes, the only ones `<FnAbi as ConvSpirvType>::spirv_type`
    /// accepts (it `span_bug!`s on `Cast`/`Indirect`). The variadic flag,
    /// fixed argument count, calling convention and unwind flag are copied as-is.
    fn readjust_fn_abi<'tcx>(
        tcx: TyCtxt<'tcx>,
        fn_abi: &'tcx FnAbi<'tcx, Ty<'tcx>>,
    ) -> &'tcx FnAbi<'tcx, Ty<'tcx>> {
        let readjust_arg_abi = |arg: &ArgAbi<'tcx, Ty<'tcx>>| {
            // Start over from the layout, dropping any attributes rustc attached.
            let mut arg = ArgAbi::new(&tcx, arg.layout, |_, _, _| ArgAttributes::new());
            arg.make_direct_deprecated();

            // Zero-sized values carry no data, so they are not passed at all.
            if arg.layout.is_zst() {
                arg.mode = PassMode::Ignore;
            }

            arg
        };
        tcx.arena.alloc(FnAbi {
            args: fn_abi.args.iter().map(readjust_arg_abi).collect(),
            ret: readjust_arg_abi(&fn_abi.ret),

            c_variadic: fn_abi.c_variadic,
            fixed_count: fn_abi.fixed_count,
            conv: fn_abi.conv,
            can_unwind: fn_abi.can_unwind,
        })
    }
    // Both `FnAbi` queries go through `readjust_fn_abi`; errors from the
    // default provider are propagated unchanged.
    providers.fn_abi_of_fn_ptr = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_fn_ptr)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };
    providers.fn_abi_of_instance = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_instance)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };

    /// Field-by-field copy of a `LayoutData`, used below (`hide_niche`) to
    /// rebuild a layout with one field overridden.
    ///
    /// NOTE(review): written out by hand, with exhaustive destructuring, rather
    /// than calling `.clone()` — presumably so new `LayoutData` fields cause a
    /// compile error here, or because `Clone` was unavailable when written; confirm.
    fn clone_layout<FieldIdx: Idx, VariantIdx: Idx>(
        layout: &LayoutData<FieldIdx, VariantIdx>,
    ) -> LayoutData<FieldIdx, VariantIdx> {
        let LayoutData {
            ref fields,
            ref variants,
            backend_repr,
            largest_niche,
            uninhabited,
            align,
            size,
            max_repr_align,
            unadjusted_abi_align,
            randomization_seed,
        } = *layout;
        LayoutData {
            fields: match *fields {
                FieldsShape::Primitive => FieldsShape::Primitive,
                FieldsShape::Union(count) => FieldsShape::Union(count),
                FieldsShape::Array { stride, count } => FieldsShape::Array { stride, count },
                FieldsShape::Arbitrary {
                    ref offsets,
                    ref memory_index,
                } => FieldsShape::Arbitrary {
                    offsets: offsets.clone(),
                    memory_index: memory_index.clone(),
                },
            },
            variants: match *variants {
                Variants::Empty => Variants::Empty,
                Variants::Single { index } => Variants::Single { index },
                Variants::Multiple {
                    tag,
                    ref tag_encoding,
                    tag_field,
                    ref variants,
                } => Variants::Multiple {
                    tag,
                    tag_encoding: match *tag_encoding {
                        TagEncoding::Direct => TagEncoding::Direct,
                        TagEncoding::Niche {
                            untagged_variant,
                            ref niche_variants,
                            niche_start,
                        } => TagEncoding::Niche {
                            untagged_variant,
                            niche_variants: niche_variants.clone(),
                            niche_start,
                        },
                    },
                    tag_field,
                    variants: variants.clone(),
                },
            },
            backend_repr,
            largest_niche,
            uninhabited,
            align,
            size,
            max_repr_align,
            unadjusted_abi_align,
            randomization_seed,
        }
    }

    // `layout_of`: computes old-style `#[repr(simd)]` structs itself and
    // defers every other type to the default provider.
    providers.layout_of = |tcx, key| {
        let ty = key.value;

        // Old-style `#[repr(simd)]`: a non-packed struct whose fields all have
        // one and the same type (a single run after `dedup_with_count`), with
        // more than one field. Yields `(adt_def, element_ty, element_count)`.
        let reimplement_old_style_repr_simd = match ty.kind() {
            ty::Adt(def, args) if def.repr().simd() && !def.repr().packed() && def.is_struct() => {
                Some(def.non_enum_variant()).and_then(|v| {
                    let (count, e_ty) = v
                        .fields
                        .iter()
                        .map(|f| f.ty(tcx, args))
                        .dedup_with_count()
                        .exactly_one()
                        .ok()?;
                    let e_len = u64::try_from(count).ok().filter(|&e_len| e_len > 1)?;
                    Some((def, e_ty, e_len))
                })
            }
            _ => None,
        };

        if let Some((adt_def, e_ty, e_len)) = reimplement_old_style_repr_simd {
            let cx = rustc_middle::ty::layout::LayoutCx::new(
                tcx,
                key.typing_env.with_post_analysis_normalized(tcx),
            );
            let dl = cx.data_layout();

            // The element must itself be a single scalar; anything else is a
            // fatal error reported at the struct's definition.
            let e_ly = cx.layout_of(e_ty)?;
            let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
                tcx.dcx().span_fatal(
                    tcx.def_span(adt_def.did()),
                    format!(
                        "SIMD type `{ty}` with a non-primitive-scalar \
                         (integer/float/pointer) element type `{}`",
                        e_ly.ty
                    ),
                );
            };

            // Size is `element size * count`, then rounded up to the
            // data layout's LLVM-like vector alignment for that size.
            let size = e_ly.size.checked_mul(e_len, dl).unwrap();
            let align = dl.llvmlike_vector_align(size);
            let size = size.align_to(align.abi);

            // Element `i` lives at `i * element size`, so element 0 is at
            // offset 0 and its niche (if any) is reused as the vector's.
            let layout = tcx.mk_layout(LayoutData {
                variants: Variants::Single {
                    index: rustc_abi::FIRST_VARIANT,
                },
                fields: FieldsShape::Array {
                    stride: e_ly.size,
                    count: e_len,
                },
                backend_repr: BackendRepr::SimdVector {
                    element: e_repr,
                    count: e_len,
                },
                largest_niche: e_ly.largest_niche,
                uninhabited: false,
                size,
                align,
                max_repr_align: None,
                unadjusted_abi_align: align.abi,
                randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
            });

            return Ok(TyAndLayout { ty, layout });
        }

        let TyAndLayout { ty, mut layout } =
            (rustc_interface::DEFAULT_QUERY_PROVIDERS.layout_of)(tcx, key)?;

        // Optionally strip the niche from `bool`'s layout. With
        // `libcore_needs_bool_niche` hard-coded to `true`, `hide_niche` is
        // currently always `false` and the rebuild below never runs.
        #[allow(clippy::match_like_matches_macro)]
        let hide_niche = match ty.kind() {
            ty::Bool => {
                let libcore_needs_bool_niche = true;

                !libcore_needs_bool_niche
            }
            _ => false,
        };

        if hide_niche {
            layout = tcx.mk_layout(LayoutData {
                largest_niche: None,
                ..clone_layout(layout.0.0)
            });
        }

        Ok(TyAndLayout { ty, layout })
    };

    // `check_well_formed`: re-allows old-style `#[repr(simd)]` structs (with
    // fields), which upstream rustc now rejects, emitting a warning instead of
    // running the default well-formedness check for them.
    providers.check_well_formed = |tcx, def_id| {
        // Only structs with no generics, or exactly one non-defaulted,
        // non-synthetic type parameter, and no where-clause predicates qualify.
        let trivial_struct = match tcx.hir_node_by_def_id(def_id) {
            rustc_hir::Node::Item(item) => match item.kind {
                rustc_hir::ItemKind::Struct(
                    _,
                    &rustc_hir::Generics {
                        params:
                            &[]
                            | &[
                                rustc_hir::GenericParam {
                                    kind:
                                        rustc_hir::GenericParamKind::Type {
                                            default: None,
                                            synthetic: false,
                                        },
                                    ..
                                },
                            ],
                        predicates: &[],
                        has_where_clause_predicates: false,
                        where_clause_span: _,
                        span: _,
                    },
                    _,
                ) => Some(tcx.adt_def(def_id)),
                _ => None,
            },
            _ => None,
        };
        // Additionally require: exactly `#[repr(simd)]` (no int/align/pack
        // reprs or other flags), no destructor, and all fields of one type
        // that is `bool`, an integer, a float, or a type parameter.
        let valid_non_array_simd_struct = trivial_struct.is_some_and(|adt_def| {
            let ReprOptions {
                int: None,
                align: None,
                pack: None,
                flags: ReprFlags::IS_SIMD,
                field_shuffle_seed: _,
            } = adt_def.repr()
            else {
                return false;
            };
            if adt_def.destructor(tcx).is_some() {
                return false;
            }

            let field_types = adt_def
                .non_enum_variant()
                .fields
                .iter()
                .map(|f| tcx.type_of(f.did).instantiate_identity());
            field_types.dedup().exactly_one().is_ok_and(|elem_ty| {
                matches!(
                    elem_ty.kind(),
                    ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Param(_)
                )
            })
        });

        if valid_non_array_simd_struct {
            tcx.dcx()
                .struct_span_warn(
                    tcx.def_span(def_id),
                    "[Rust-GPU] temporarily re-allowing old-style `#[repr(simd)]` (with fields)",
                )
                .with_note("removed upstream by https://github.com/rust-lang/rust/pull/129403")
                .with_note("in favor of the new `#[repr(simd)] struct TxN([T; N]);` style")
                .with_note("(taking effect since `nightly-2024-09-12` / `1.83.0` stable)")
                .emit();
            return Ok(());
        }

        (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_well_formed)(tcx, def_id)
    };

    // `check_mono_item`: replaced with a no-op, so the default per-mono-item
    // checks never run.
    // NOTE(review): presumably these checks reject patterns this backend relies
    // on (e.g. the old-style `#[repr(simd)]` types above) — confirm.
    providers.check_mono_item = |_, _| {};
}
370
371#[derive(Default)]
377pub struct RecursivePointeeCache<'tcx> {
378 map: RefCell<FxHashMap<PointeeTy<'tcx>, PointeeDefState>>,
379}
380
381impl<'tcx> RecursivePointeeCache<'tcx> {
382 fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
383 match self.map.borrow_mut().entry(pointee) {
384 Entry::Vacant(entry) => {
387 entry.insert(PointeeDefState::Defining);
388 None
389 }
390 Entry::Occupied(mut entry) => match *entry.get() {
391 PointeeDefState::Defining => {
395 let new_id = cx.emit_global().id();
396 cx.emit_global()
400 .type_forward_pointer(new_id, StorageClass::Generic);
401 entry.insert(PointeeDefState::DefiningWithForward(new_id));
402 cx.zombie_with_span(
403 new_id,
404 span,
405 "cannot create self-referential types, even through pointers",
406 );
407 Some(new_id)
408 }
409 PointeeDefState::DefiningWithForward(id) | PointeeDefState::Defined(id) => Some(id),
413 },
414 }
415 }
416
417 fn end(
418 &self,
419 cx: &CodegenCx<'tcx>,
420 span: Span,
421 pointee: PointeeTy<'tcx>,
422 pointee_spv: Word,
423 ) -> Word {
424 match self.map.borrow_mut().entry(pointee) {
425 Entry::Vacant(_) => {
427 span_bug!(span, "RecursivePointeeCache::end should always have entry")
428 }
429 Entry::Occupied(mut entry) => match *entry.get() {
430 PointeeDefState::Defining => {
433 let id = SpirvType::Pointer {
434 pointee: pointee_spv,
435 }
436 .def(span, cx);
437 entry.insert(PointeeDefState::Defined(id));
438 id
439 }
440 PointeeDefState::DefiningWithForward(id) => {
443 entry.insert(PointeeDefState::Defined(id));
444 SpirvType::Pointer {
445 pointee: pointee_spv,
446 }
447 .def_with_id(cx, span, id)
448 }
449 PointeeDefState::Defined(_) => {
450 span_bug!(span, "RecursivePointeeCache::end defined pointer twice")
451 }
452 },
453 }
454 }
455}
456
457#[derive(Eq, PartialEq, Hash, Copy, Clone, Debug)]
458enum PointeeTy<'tcx> {
459 Ty(TyAndLayout<'tcx>),
460 Fn(PolyFnSig<'tcx>),
461}
462
463impl fmt::Display for PointeeTy<'_> {
464 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465 match self {
466 PointeeTy::Ty(ty) => write!(f, "{}", ty.ty),
467 PointeeTy::Fn(ty) => write!(f, "{ty}"),
468 }
469 }
470}
471
472enum PointeeDefState {
473 Defining,
474 DefiningWithForward(Word),
475 Defined(Word),
476}
477
478pub trait ConvSpirvType<'tcx> {
481 fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
482}
483
484impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
485 fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
486 match *self {
487 PointeeTy::Ty(ty) => ty.spirv_type(span, cx),
488 PointeeTy::Fn(ty) => cx
489 .fn_abi_of_fn_ptr(ty, ty::List::empty())
490 .spirv_type(span, cx),
491 }
492 }
493}
494
495impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
496 fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
497 let mut argument_types = Vec::new();
499
500 let return_type = match self.ret.mode {
501 PassMode::Ignore => SpirvType::Void.def(span, cx),
502 PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
503 PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
504 span,
505 "query hooks should've made this `PassMode` impossible: {:#?}",
506 self.ret
507 ),
508 };
509
510 for arg in self.args.iter() {
511 let arg_type = match arg.mode {
512 PassMode::Ignore => continue,
513 PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
514 PassMode::Pair(_, _) => {
515 argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
516 argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
517 continue;
518 }
519 PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
520 span,
521 "query hooks should've made this `PassMode` impossible: {:#?}",
522 arg
523 ),
524 };
525 argument_types.push(arg_type);
526 }
527
528 SpirvType::Function {
529 return_type,
530 arguments: &argument_types,
531 }
532 .def(span, cx)
533 }
534}
535
536impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
537 fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
538 if let TyKind::Adt(adt, args) = *self.ty.kind() {
539 if span == DUMMY_SP {
540 span = cx.tcx.def_span(adt.did());
541 }
542
543 let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs_unchecked(adt.did()));
544
545 if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
546 && let Ok(spirv_type) =
547 trans_intrinsic_type(cx, span, *self, args, intrinsic_type_attr)
548 {
549 return spirv_type;
550 }
551 }
552
553 match self.backend_repr {
557 _ if self.uninhabited => SpirvType::Adt {
558 def_id: def_id_for_spirv_type_adt(*self),
559 size: Some(Size::ZERO),
560 align: Align::from_bytes(0).unwrap(),
561 field_types: &[],
562 field_offsets: &[],
563 field_names: None,
564 }
565 .def_with_name(cx, span, TyLayoutNameKey::from(*self)),
566 BackendRepr::Scalar(scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
567 BackendRepr::ScalarPair(a, b) => {
568 let mut non_zst_fields = (0..self.fields.count())
584 .map(|i| (i, self.field(cx, i)))
585 .filter(|(_, field)| !field.is_zst());
586 let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) {
587 (Some(field), None) => Some(field),
588 _ => None,
589 };
590 if let Some((i, field)) = sole_non_zst_field {
591 if self.fields.offset(i) == Size::ZERO
594 && field.size == self.size
595 && field.align.abi == self.align.abi
596 && field.backend_repr.eq_up_to_validity(&self.backend_repr)
597 {
598 return field.spirv_type(span, cx);
599 }
600 }
601
602 let a_offset = Size::ZERO;
605 let b_offset = a.primitive().size(cx).align_to(b.primitive().align(cx).abi);
606 let a = trans_scalar(cx, span, *self, a, a_offset);
607 let b = trans_scalar(cx, span, *self, b, b_offset);
608 let size = if self.is_unsized() {
609 None
610 } else {
611 Some(self.size)
612 };
613 let mut field_names = Vec::new();
615 if let TyKind::Adt(adt, _) = self.ty.kind()
616 && let Variants::Single { index } = self.variants
617 {
618 for i in self.fields.index_by_increasing_offset() {
619 let field = &adt.variants()[index].fields[FieldIdx::new(i)];
620 field_names.push(field.name);
621 }
622 }
623 SpirvType::Adt {
624 def_id: def_id_for_spirv_type_adt(*self),
625 size,
626 align: self.align.abi,
627 field_types: &[a, b],
628 field_offsets: &[a_offset, b_offset],
629 field_names: if field_names.len() == 2 {
630 Some(&field_names)
631 } else {
632 None
633 },
634 }
635 .def_with_name(cx, span, TyLayoutNameKey::from(*self))
636 }
637 BackendRepr::SimdVector { element, count } => {
638 let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
639 SpirvType::Vector {
640 element: elem_spirv,
641 count: count as u32,
642 }
643 .def(span, cx)
644 }
645 BackendRepr::Memory { sized: _ } => trans_aggregate(cx, span, *self),
646 }
647 }
648}
649
650pub fn scalar_pair_element_backend_type<'tcx>(
653 cx: &CodegenCx<'tcx>,
654 span: Span,
655 ty: TyAndLayout<'tcx>,
656 index: usize,
657) -> Word {
658 let [a, b] = match ty.layout.backend_repr() {
659 BackendRepr::ScalarPair(a, b) => [a, b],
660 other => span_bug!(
661 span,
662 "scalar_pair_element_backend_type invalid abi: {:?}",
663 other
664 ),
665 };
666 let offset = match index {
667 0 => Size::ZERO,
668 1 => a.primitive().size(cx).align_to(b.primitive().align(cx).abi),
669 _ => unreachable!(),
670 };
671 trans_scalar(cx, span, ty, [a, b][index], offset)
672}
673
674fn trans_scalar<'tcx>(
682 cx: &CodegenCx<'tcx>,
683 span: Span,
684 ty: TyAndLayout<'tcx>,
685 scalar: Scalar,
686 offset: Size,
687) -> Word {
688 if scalar.is_bool() {
689 return SpirvType::Bool.def(span, cx);
690 }
691
692 match scalar.primitive() {
693 Primitive::Int(int_kind, signedness) => {
694 SpirvType::Integer(int_kind.size().bits() as u32, signedness).def(span, cx)
695 }
696 Primitive::Float(float_kind) => {
697 SpirvType::Float(float_kind.size().bits() as u32).def(span, cx)
698 }
699 Primitive::Pointer(_) => {
700 let pointee_ty = dig_scalar_pointee(cx, ty, offset);
701 if let Some(predefined_result) = cx
704 .type_cache
705 .recursive_pointee_cache
706 .begin(cx, span, pointee_ty)
707 {
708 predefined_result
709 } else {
710 let pointee = pointee_ty.spirv_type(span, cx);
711 cx.type_cache
712 .recursive_pointee_cache
713 .end(cx, span, pointee_ty, pointee)
714 }
715 }
716 }
717}
718
719fn dig_scalar_pointee<'tcx>(
730 cx: &CodegenCx<'tcx>,
731 layout: TyAndLayout<'tcx>,
732 offset: Size,
733) -> PointeeTy<'tcx> {
734 if let FieldsShape::Primitive = layout.fields {
735 assert_eq!(offset, Size::ZERO);
736 let pointee = match *layout.ty.kind() {
737 TyKind::Ref(_, pointee_ty, _) | TyKind::RawPtr(pointee_ty, _) => {
738 PointeeTy::Ty(cx.layout_of(pointee_ty))
739 }
740 TyKind::FnPtr(sig_tys, hdr) => PointeeTy::Fn(sig_tys.with(hdr)),
741 _ => bug!("Pointer is not `&T`, `*T` or `fn` pointer: {:#?}", layout),
742 };
743 return pointee;
744 }
745
746 let all_fields = (match &layout.variants {
747 Variants::Empty => 0..0,
748 Variants::Multiple { variants, .. } => 0..variants.len(),
749 Variants::Single { index } => {
750 let i = index.as_usize();
751 i..i + 1
752 }
753 })
754 .flat_map(|variant_idx| {
755 let variant = layout.for_variant(cx, VariantIdx::new(variant_idx));
756 (0..variant.fields.count()).map(move |field_idx| {
757 (
758 variant.field(cx, field_idx),
759 variant.fields.offset(field_idx),
760 )
761 })
762 });
763
764 let mut pointee = None;
765 for (field, field_offset) in all_fields {
766 if field.is_zst() {
767 continue;
768 }
769 if (field_offset..field_offset + field.size).contains(&offset) {
770 let new_pointee = dig_scalar_pointee(cx, field, offset - field_offset);
771 match pointee {
772 Some(old_pointee) if old_pointee != new_pointee => {
773 cx.tcx.dcx().fatal(format!(
774 "dig_scalar_pointee: unsupported Pointer with different \
775 pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
776 ));
777 }
778 _ => pointee = Some(new_pointee),
779 }
780 }
781 }
782 pointee.unwrap_or_else(|| {
783 bug!(
784 "field containing Pointer scalar at offset {:?} not found in {:#?}",
785 offset,
786 layout
787 )
788 })
789}
790
791fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
794 fn create_zst<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
795 assert_eq!(ty.size, Size::ZERO);
796 SpirvType::Adt {
797 def_id: def_id_for_spirv_type_adt(ty),
798 size: Some(Size::ZERO),
799 align: ty.align.abi,
800 field_types: &[],
801 field_offsets: &[],
802 field_names: None,
803 }
804 .def_with_name(cx, span, TyLayoutNameKey::from(ty))
805 }
806 match ty.fields {
807 FieldsShape::Primitive => span_bug!(
808 span,
809 "trans_aggregate called for FieldsShape::Primitive layout {:#?}",
810 ty
811 ),
812 FieldsShape::Union(_) => {
813 assert!(!ty.is_unsized(), "{ty:#?}");
814
815 let largest_case = (0..ty.fields.count())
821 .map(|i| (FieldIdx::from_usize(i), ty.field(cx, i)))
822 .max_by_key(|(_, case)| case.size);
823
824 if let Some((case_idx, case)) = largest_case {
825 if ty.align != case.align {
826 trans_struct_or_union(cx, span, ty, Some(case_idx))
828 } else {
829 assert_eq!(ty.size, case.size);
830 case.spirv_type(span, cx)
831 }
832 } else {
833 create_zst(cx, span, ty)
834 }
835 }
836 FieldsShape::Array { stride, count } => {
837 let element_type = ty.field(cx, 0).spirv_type(span, cx);
838 if ty.is_unsized() {
839 assert_eq!(count, 0);
842 SpirvType::RuntimeArray {
843 element: element_type,
844 }
845 .def(span, cx)
846 } else if count == 0 {
847 create_zst(cx, span, ty)
849 } else {
850 let count_const = cx.constant_u32(span, count as u32);
851 let element_spv = cx.lookup_type(element_type);
852 let stride_spv = element_spv
853 .sizeof(cx)
854 .expect("Unexpected unsized type in sized FieldsShape::Array")
855 .align_to(element_spv.alignof(cx));
856 assert_eq!(stride_spv, stride);
857 SpirvType::Array {
858 element: element_type,
859 count: count_const,
860 }
861 .def(span, cx)
862 }
863 }
864 FieldsShape::Arbitrary {
865 offsets: _,
866 memory_index: _,
867 } => trans_struct_or_union(cx, span, ty, None),
868 }
869}
870
871#[cfg_attr(
872 not(rustc_codegen_spirv_disable_pqp_cg_ssa),
873 expect(
874 unused,
875 reason = "actually used from \
876 `<rustc_codegen_ssa::traits::ConstCodegenMethods for CodegenCx<'_>>::const_struct`, \
877 but `rustc_codegen_ssa` being `pqp_cg_ssa` makes that trait unexported"
878 )
879)]
880pub fn auto_struct_layout(
882 cx: &CodegenCx<'_>,
883 field_types: &[Word],
884) -> (Vec<Size>, Option<Size>, Align) {
885 let mut field_offsets = Vec::with_capacity(field_types.len());
887 let mut offset = Some(Size::ZERO);
888 let mut max_align = Align::from_bytes(0).unwrap();
889 for &field_type in field_types {
890 let spirv_type = cx.lookup_type(field_type);
891 let field_size = spirv_type.sizeof(cx);
892 let field_align = spirv_type.alignof(cx);
893 let this_offset = offset
894 .expect("Unsized values can only be the last field in a struct")
895 .align_to(field_align);
896
897 field_offsets.push(this_offset);
898 if field_align > max_align {
899 max_align = field_align;
900 }
901 offset = field_size.map(|size| this_offset + size);
902 }
903 (field_offsets, offset, max_align)
904}
905
906fn trans_struct_or_union<'tcx>(
908 cx: &CodegenCx<'tcx>,
909 span: Span,
910 ty: TyAndLayout<'tcx>,
911 union_case: Option<FieldIdx>,
912) -> Word {
913 let size = if ty.is_unsized() { None } else { Some(ty.size) };
914 let align = ty.align.abi;
915 let mut field_types = Vec::new();
917 let mut field_offsets = Vec::new();
918 let mut field_names = Vec::new();
919 for i in ty.fields.index_by_increasing_offset() {
920 if let Some(expected_field_idx) = union_case
921 && i != expected_field_idx.as_usize()
922 {
923 continue;
924 }
925
926 let field_ty = ty.field(cx, i);
927 field_types.push(field_ty.spirv_type(span, cx));
928 let offset = ty.fields.offset(i);
929 field_offsets.push(offset);
930 if let Variants::Single { index } = ty.variants {
931 if let TyKind::Adt(adt, _) = ty.ty.kind() {
932 let field = &adt.variants()[index].fields[FieldIdx::new(i)];
933 field_names.push(field.name);
934 } else {
935 field_names.push(Symbol::intern(&format!("{i}")));
937 }
938 } else {
939 if let TyKind::Adt(_, _) = ty.ty.kind() {
940 } else {
941 span_bug!(span, "Variants::Multiple not TyKind::Adt");
942 }
943 if i == 0 {
944 field_names.push(cx.sym.discriminant);
945 } else {
946 cx.tcx.dcx().fatal("Variants::Multiple has multiple fields")
947 }
948 };
949 }
950 SpirvType::Adt {
951 def_id: def_id_for_spirv_type_adt(ty),
952 size,
953 align,
954 field_types: &field_types,
955 field_offsets: &field_offsets,
956 field_names: Some(&field_names),
957 }
958 .def_with_name(cx, span, TyLayoutNameKey::from(ty))
959}
960
961fn def_id_for_spirv_type_adt(layout: TyAndLayout<'_>) -> Option<DefId> {
965 match *layout.ty.kind() {
966 TyKind::Adt(def, _) => Some(def.did()),
967 TyKind::Foreign(def_id) | TyKind::Closure(def_id, _) | TyKind::Coroutine(def_id, ..) => {
968 Some(def_id)
969 }
970 _ => None,
971 }
972}
973
974#[derive(Copy, Clone, PartialEq, Eq, Hash)]
977pub struct TyLayoutNameKey<'tcx> {
978 ty: Ty<'tcx>,
979 variant: Option<VariantIdx>,
980}
981
982impl<'tcx> From<TyAndLayout<'tcx>> for TyLayoutNameKey<'tcx> {
983 fn from(layout: TyAndLayout<'tcx>) -> Self {
984 TyLayoutNameKey {
985 ty: layout.ty,
986 variant: match layout.variants {
987 Variants::Single { index } => Some(index),
988 _ => None,
989 },
990 }
991 }
992}
993
994impl fmt::Display for TyLayoutNameKey<'_> {
995 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
996 write!(f, "{}", self.ty)?;
997 if let (TyKind::Adt(def, _), Some(index)) = (self.ty.kind(), self.variant)
998 && def.is_enum()
999 && !def.variants().is_empty()
1000 {
1001 write!(f, "::{}", def.variants()[index].name)?;
1002 }
1003 if let (TyKind::Coroutine(_, _), Some(index)) = (self.ty.kind(), self.variant) {
1004 write!(f, "::{}", CoroutineArgs::variant_name(index))?;
1005 }
1006 Ok(())
1007 }
1008}
1009
1010fn trans_intrinsic_type<'tcx>(
1011 cx: &CodegenCx<'tcx>,
1012 span: Span,
1013 ty: TyAndLayout<'tcx>,
1014 args: GenericArgsRef<'tcx>,
1015 intrinsic_type_attr: IntrinsicType,
1016) -> Result<Word, ErrorGuaranteed> {
1017 match intrinsic_type_attr {
1018 IntrinsicType::GenericImageType => {
1019 if ty.size != Size::from_bytes(4) {
1021 return Err(cx
1022 .tcx
1023 .dcx()
1024 .err("#[spirv(generic_image)] type must have size 4"));
1025 }
1026
1027 let sampled_type = match args.type_at(0).kind() {
1040 TyKind::Int(int) => match int {
1041 IntTy::Isize => {
1042 SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, true)
1043 .def(span, cx)
1044 }
1045 IntTy::I8 => SpirvType::Integer(8, true).def(span, cx),
1046 IntTy::I16 => SpirvType::Integer(16, true).def(span, cx),
1047 IntTy::I32 => SpirvType::Integer(32, true).def(span, cx),
1048 IntTy::I64 => SpirvType::Integer(64, true).def(span, cx),
1049 IntTy::I128 => SpirvType::Integer(128, true).def(span, cx),
1050 },
1051 TyKind::Uint(uint) => match uint {
1052 UintTy::Usize => {
1053 SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, false)
1054 .def(span, cx)
1055 }
1056 UintTy::U8 => SpirvType::Integer(8, false).def(span, cx),
1057 UintTy::U16 => SpirvType::Integer(16, false).def(span, cx),
1058 UintTy::U32 => SpirvType::Integer(32, false).def(span, cx),
1059 UintTy::U64 => SpirvType::Integer(64, false).def(span, cx),
1060 UintTy::U128 => SpirvType::Integer(128, false).def(span, cx),
1061 },
1062 TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
1063 TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
1064 _ => {
1065 return Err(cx
1066 .tcx
1067 .dcx()
1068 .span_err(span, "Invalid sampled type to `Image`."));
1069 }
1070 };
1071
1072 trait FromScalarInt: Sized {
1081 fn from_scalar_int(n: ScalarInt) -> Option<Self>;
1082 }
1083
1084 impl FromScalarInt for u32 {
1085 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
1086 Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
1087 }
1088 }
1089
1090 impl FromScalarInt for Dim {
1091 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
1092 Dim::from_u32(u32::from_scalar_int(n)?)
1093 }
1094 }
1095
1096 impl FromScalarInt for ImageFormat {
1097 fn from_scalar_int(n: ScalarInt) -> Option<Self> {
1098 ImageFormat::from_u32(u32::from_scalar_int(n)?)
1099 }
1100 }
1101
1102 fn const_int_value<'tcx, P: FromScalarInt>(
1103 cx: &CodegenCx<'tcx>,
1104 const_: Const<'tcx>,
1105 ) -> Result<P, ErrorGuaranteed> {
1106 let ty::Value {
1107 ty: const_ty,
1108 valtree: const_val,
1109 } = const_.to_value();
1110 assert!(const_ty.is_integral());
1111 const_val
1112 .try_to_scalar_int()
1113 .and_then(P::from_scalar_int)
1114 .ok_or_else(|| {
1115 cx.tcx
1116 .dcx()
1117 .err(format!("invalid value for Image const generic: {const_}"))
1118 })
1119 }
1120
1121 let dim = const_int_value(cx, args.const_at(1))?;
1122 let depth = const_int_value(cx, args.const_at(2))?;
1123 let arrayed = const_int_value(cx, args.const_at(3))?;
1124 let multisampled = const_int_value(cx, args.const_at(4))?;
1125 let sampled = const_int_value(cx, args.const_at(5))?;
1126 let image_format = const_int_value(cx, args.const_at(6))?;
1127
1128 let ty = SpirvType::Image {
1129 sampled_type,
1130 dim,
1131 depth,
1132 arrayed,
1133 multisampled,
1134 sampled,
1135 image_format,
1136 };
1137 Ok(ty.def(span, cx))
1138 }
1139 IntrinsicType::Sampler => {
1140 if ty.size != Size::from_bytes(4) {
1142 return Err(cx.tcx.dcx().err("#[spirv(sampler)] type must have size 4"));
1143 }
1144 Ok(SpirvType::Sampler.def(span, cx))
1145 }
1146 IntrinsicType::AccelerationStructureKhr => {
1147 Ok(SpirvType::AccelerationStructureKhr.def(span, cx))
1148 }
1149 IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)),
1150 IntrinsicType::SampledImage => {
1151 if ty.size != Size::from_bytes(4) {
1153 return Err(cx
1154 .tcx
1155 .dcx()
1156 .err("#[spirv(sampled_image)] type must have size 4"));
1157 }
1158
1159 if let Some(image_ty) = args.types().next() {
1162 let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
1164 Ok(SpirvType::SampledImage { image_type }.def(span, cx))
1165 } else {
1166 Err(cx
1167 .tcx
1168 .dcx()
1169 .err("#[spirv(sampled_image)] type must have a generic image type"))
1170 }
1171 }
1172 IntrinsicType::RuntimeArray => {
1173 if ty.size != Size::from_bytes(4) {
1174 return Err(cx
1175 .tcx
1176 .dcx()
1177 .err("#[spirv(runtime_array)] type must have size 4"));
1178 }
1179
1180 if let Some(elem_ty) = args.types().next() {
1183 Ok(SpirvType::RuntimeArray {
1184 element: cx.layout_of(elem_ty).spirv_type(span, cx),
1185 }
1186 .def(span, cx))
1187 } else {
1188 Err(cx
1189 .tcx
1190 .dcx()
1191 .err("#[spirv(runtime_array)] type must have a generic element type"))
1192 }
1193 }
1194 IntrinsicType::TypedBuffer => {
1195 if ty.size != Size::from_bytes(4) {
1196 return Err(cx
1197 .tcx
1198 .sess
1199 .dcx()
1200 .err("#[spirv(typed_buffer)] type must have size 4"));
1201 }
1202
1203 if let Some(data_ty) = args.types().next() {
1206 Ok(SpirvType::InterfaceBlock {
1211 inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
1212 }
1213 .def(span, cx))
1214 } else {
1215 Err(cx
1216 .tcx
1217 .sess
1218 .dcx()
1219 .err("#[spirv(typed_buffer)] type must have a generic data type"))
1220 }
1221 }
1222 IntrinsicType::Matrix => {
1223 let span = def_id_for_spirv_type_adt(ty)
1224 .map(|did| cx.tcx.def_span(did))
1225 .expect("#[spirv(matrix)] must be added to a type which has DefId");
1226
1227 let field_types = (0..ty.fields.count())
1228 .map(|i| ty.field(cx, i).spirv_type(span, cx))
1229 .collect::<Vec<_>>();
1230 if field_types.len() < 2 {
1231 return Err(cx
1232 .tcx
1233 .dcx()
1234 .span_err(span, "#[spirv(matrix)] type must have at least two fields"));
1235 }
1236 let elem_type = field_types[0];
1237 if !field_types.iter().all(|&ty| ty == elem_type) {
1238 return Err(cx.tcx.dcx().span_err(
1239 span,
1240 "#[spirv(matrix)] type fields must all be the same type",
1241 ));
1242 }
1243 match cx.lookup_type(elem_type) {
1244 SpirvType::Vector { .. } => (),
1245 ty => {
1246 return Err(cx
1247 .tcx
1248 .dcx()
1249 .struct_span_err(span, "#[spirv(matrix)] type fields must all be vectors")
1250 .with_note(format!("field type is {}", ty.debug(elem_type, cx)))
1251 .emit());
1252 }
1253 }
1254
1255 Ok(SpirvType::Matrix {
1256 element: elem_type,
1257 count: field_types.len() as u32,
1258 }
1259 .def(span, cx))
1260 }
1261 }
1262}