1use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
5use crate::codegen_cx::CodegenCx;
6use crate::spirv_type::SpirvType;
7use itertools::Itertools;
8use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
9use rustc_abi::ExternAbi as Abi;
10use rustc_abi::{
11    Align, BackendRepr, FieldIdx, FieldsShape, Primitive, Scalar, Size, VariantIdx, Variants,
12};
13use rustc_data_structures::fx::FxHashMap;
14use rustc_errors::ErrorGuaranteed;
15use rustc_index::Idx;
16use rustc_middle::query::Providers;
17use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
18use rustc_middle::ty::{
19    self, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty, TyCtxt,
20    TyKind, UintTy,
21};
22use rustc_middle::ty::{GenericArgsRef, ScalarInt};
23use rustc_middle::{bug, span_bug};
24use rustc_span::DUMMY_SP;
25use rustc_span::def_id::DefId;
26use rustc_span::{Span, Symbol};
27use rustc_target::callconv::{ArgAbi, ArgAttributes, FnAbi, PassMode};
28use std::cell::RefCell;
29use std::collections::hash_map::Entry;
30use std::fmt;
31
/// Installs this backend's overrides of several rustc query providers.
///
/// * `fn_sig` — delegates to rustc's default provider, then rewrites any
///   `Abi::C { .. }` ABI on the resulting signature to `Abi::Unadjusted`.
/// * `fn_abi_of_fn_ptr` / `fn_abi_of_instance` — delegate to rustc's default
///   providers (propagating their errors with `?`) and pass the successful
///   result through `readjust_fn_abi` (defined below).
/// * `check_mono_item` — replaced with a provider that does nothing.
pub(crate) fn provide(providers: &mut Providers) {
    // NOTE(review): presumably the C ABI is rewritten to `Unadjusted` so rustc
    // does not apply C-calling-convention argument adjustments for SPIR-V —
    // confirm the original rationale.
    providers.fn_sig = |tcx, def_id| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_sig)(tcx, def_id);
        // The signature is wrapped in two binder-like layers (hence two nested
        // `map_bound`s); only the innermost value's `abi` field is adjusted.
        result.map_bound(|outer| {
            outer.map_bound(|mut inner| {
                if let Abi::C { .. } = inner.abi {
                    inner.abi = Abi::Unadjusted;
                }
                inner
            })
        })
    };

    /// Rebuilds `fn_abi` so the argument and return pass modes are ones this
    /// backend handles.
    ///
    /// Each `ArgAbi` is recomputed from its layout with empty attributes and
    /// switched to a direct mode (`make_direct_deprecated`). Scalar pairs whose
    /// type is an ADT are then forced to a single `PassMode::Direct`, and
    /// zero-sized values become `PassMode::Ignore`. All remaining `FnAbi` fields
    /// are copied unchanged. The new `FnAbi` is allocated in `tcx.arena`.
    fn readjust_fn_abi<'tcx>(
        tcx: TyCtxt<'tcx>,
        fn_abi: &'tcx FnAbi<'tcx, Ty<'tcx>>,
    ) -> &'tcx FnAbi<'tcx, Ty<'tcx>> {
        let readjust_arg_abi = |arg: &ArgAbi<'tcx, Ty<'tcx>>| {
            // Start over from the layout alone, discarding rustc's chosen
            // attributes and pass mode.
            let mut arg = ArgAbi::new(&tcx, arg.layout, |_, _, _| ArgAttributes::new());
            arg.make_direct_deprecated();

            // ADTs that would be split into a scalar pair are instead passed
            // as one direct value.
            if let PassMode::Pair(..) = arg.mode {
                if let TyKind::Adt(..) = arg.layout.ty.kind() {
                    arg.mode = PassMode::Direct(ArgAttributes::new());
                }
            }

            // Zero-sized values are not passed at all.
            if arg.layout.is_zst() {
                arg.mode = PassMode::Ignore;
            }

            arg
        };
        tcx.arena.alloc(FnAbi {
            args: fn_abi.args.iter().map(readjust_arg_abi).collect(),
            ret: readjust_arg_abi(&fn_abi.ret),

            // Everything below is carried over from the original ABI as-is.
            c_variadic: fn_abi.c_variadic,
            fixed_count: fn_abi.fixed_count,
            conv: fn_abi.conv,
            can_unwind: fn_abi.can_unwind,
        })
    }
    providers.fn_abi_of_fn_ptr = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_fn_ptr)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };
    providers.fn_abi_of_instance = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_instance)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };

    // Disabled: the override performs no checks. NOTE(review): the reason is
    // not visible in this chunk — confirm.
    providers.check_mono_item = |_, _| {};
}
116
/// Per-pointee bookkeeping used while translating pointer types.
///
/// The cache lets a pointer type whose pointee (directly or indirectly) refers
/// back to itself be resolved. On a recursive visit, a forward-declared pointer
/// id is handed out instead of recursing forever. That id is also zombied as an
/// error (see `RecursivePointeeCache::begin`).
#[derive(Default)]
pub struct RecursivePointeeCache<'tcx> {
    // Definition state of the SPIR-V pointer type for each pointee.
    map: RefCell<FxHashMap<PointeeTy<'tcx>, PointeeDefState>>,
}
126
127impl<'tcx> RecursivePointeeCache<'tcx> {
128    fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
129        match self.map.borrow_mut().entry(pointee) {
130            Entry::Vacant(entry) => {
133                entry.insert(PointeeDefState::Defining);
134                None
135            }
136            Entry::Occupied(mut entry) => match *entry.get() {
137                PointeeDefState::Defining => {
141                    let new_id = cx.emit_global().id();
142                    cx.emit_global()
146                        .type_forward_pointer(new_id, StorageClass::Generic);
147                    entry.insert(PointeeDefState::DefiningWithForward(new_id));
148                    cx.zombie_with_span(
149                        new_id,
150                        span,
151                        "cannot create self-referential types, even through pointers",
152                    );
153                    Some(new_id)
154                }
155                PointeeDefState::DefiningWithForward(id) | PointeeDefState::Defined(id) => Some(id),
159            },
160        }
161    }
162
163    fn end(
164        &self,
165        cx: &CodegenCx<'tcx>,
166        span: Span,
167        pointee: PointeeTy<'tcx>,
168        pointee_spv: Word,
169    ) -> Word {
170        match self.map.borrow_mut().entry(pointee) {
171            Entry::Vacant(_) => {
173                span_bug!(span, "RecursivePointeeCache::end should always have entry")
174            }
175            Entry::Occupied(mut entry) => match *entry.get() {
176                PointeeDefState::Defining => {
179                    let id = SpirvType::Pointer {
180                        pointee: pointee_spv,
181                    }
182                    .def(span, cx);
183                    entry.insert(PointeeDefState::Defined(id));
184                    id
185                }
186                PointeeDefState::DefiningWithForward(id) => {
189                    entry.insert(PointeeDefState::Defined(id));
190                    SpirvType::Pointer {
191                        pointee: pointee_spv,
192                    }
193                    .def_with_id(cx, span, id)
194                }
195                PointeeDefState::Defined(_) => {
196                    span_bug!(span, "RecursivePointeeCache::end defined pointer twice")
197                }
198            },
199        }
200    }
201}
202
/// What a SPIR-V pointer type points at, as discovered by `dig_scalar_pointee`.
///
/// Used as the key of `RecursivePointeeCache`, hence the `Eq`/`Hash` derives.
#[derive(Eq, PartialEq, Hash, Copy, Clone, Debug)]
enum PointeeTy<'tcx> {
    /// Pointee of a `&T` / `*T`: the laid-out Rust type `T`.
    Ty(TyAndLayout<'tcx>),
    /// Pointee of a `fn` pointer: its (late-bound) signature.
    Fn(PolyFnSig<'tcx>),
}
208
209impl fmt::Display for PointeeTy<'_> {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        match self {
212            PointeeTy::Ty(ty) => write!(f, "{}", ty.ty),
213            PointeeTy::Fn(ty) => write!(f, "{ty}"),
214        }
215    }
216}
217
/// Progress of defining the SPIR-V pointer type for one pointee in
/// `RecursivePointeeCache`.
enum PointeeDefState {
    /// `begin` has been called, but `end` has not, and no recursive reference
    /// has been seen yet.
    Defining,
    /// A recursive reference was seen while defining. The payload is the id
    /// that was forward-declared (and zombied) for the pointer type.
    DefiningWithForward(Word),
    /// `end` has completed. The payload is the final pointer type id.
    Defined(Word),
}
223
/// Translation of a Rust-side type description into a SPIR-V type id.
///
/// Implemented in this file for `PointeeTy`, `FnAbi` and `TyAndLayout`.
pub trait ConvSpirvType<'tcx> {
    /// Returns the id of the SPIR-V type for `self`, defining it in `cx` if
    /// needed. `span` is used for diagnostics.
    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
}
229
230impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
231    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
232        match *self {
233            PointeeTy::Ty(ty) => ty.spirv_type(span, cx),
234            PointeeTy::Fn(ty) => cx
235                .fn_abi_of_fn_ptr(ty, ty::List::empty())
236                .spirv_type(span, cx),
237        }
238    }
239}
240
241impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
242    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
243        let mut argument_types = Vec::new();
245
246        let return_type = match self.ret.mode {
247            PassMode::Ignore => SpirvType::Void.def(span, cx),
248            PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
249            PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
250                span,
251                "query hooks should've made this `PassMode` impossible: {:#?}",
252                self.ret
253            ),
254        };
255
256        for arg in self.args.iter() {
257            let arg_type = match arg.mode {
258                PassMode::Ignore => continue,
259                PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
260                PassMode::Pair(_, _) => {
261                    argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
262                    argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
263                    continue;
264                }
265                PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
266                    span,
267                    "query hooks should've made this `PassMode` impossible: {:#?}",
268                    arg
269                ),
270            };
271            argument_types.push(arg_type);
272        }
273
274        SpirvType::Function {
275            return_type,
276            arguments: &argument_types,
277        }
278        .def(span, cx)
279    }
280}
281
impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
    /// Translates a laid-out Rust type into a SPIR-V type id.
    ///
    /// ADTs carrying a `#[spirv(...)]` intrinsic-type attribute are first
    /// offered to `trans_intrinsic_type`. On success its result is used. On
    /// `Err` (an error has already been reported, per `ErrorGuaranteed`),
    /// translation falls back to the generic layout-driven path below, which
    /// dispatches on `backend_repr`.
    fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
        if let TyKind::Adt(adt, args) = *self.ty.kind() {
            // Prefer the ADT's own definition span over a dummy span for
            // diagnostics.
            if span == DUMMY_SP {
                span = cx.tcx.def_span(adt.did());
            }

            let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs_unchecked(adt.did()));

            if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
                && let Ok(spirv_type) =
                    trans_intrinsic_type(cx, span, *self, args, intrinsic_type_attr)
            {
                return spirv_type;
            }
        }

        match self.backend_repr {
            // Uninhabited layouts become an empty, zero-sized struct regardless
            // of their representation.
            _ if self.uninhabited => SpirvType::Adt {
                def_id: def_id_for_spirv_type_adt(*self),
                size: Some(Size::ZERO),
                align: Align::from_bytes(0).unwrap(),
                field_types: &[],
                field_offsets: &[],
                field_names: None,
            }
            .def_with_name(cx, span, TyLayoutNameKey::from(*self)),
            BackendRepr::Scalar(scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
            BackendRepr::ScalarPair(a, b) => {
                // If exactly one field is non-zero-sized (only the first two
                // non-ZST fields are inspected) and it occupies the whole layout
                // at offset 0, with the same alignment and an equivalent repr,
                // reuse that field's SPIR-V type instead of wrapping it.
                let mut non_zst_fields = (0..self.fields.count())
                    .map(|i| (i, self.field(cx, i)))
                    .filter(|(_, field)| !field.is_zst());
                let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) {
                    (Some(field), None) => Some(field),
                    _ => None,
                };
                if let Some((i, field)) = sole_non_zst_field {
                    if self.fields.offset(i) == Size::ZERO
                        && field.size == self.size
                        && field.align.abi == self.align.abi
                        && field.backend_repr.eq_up_to_validity(&self.backend_repr)
                    {
                        return field.spirv_type(span, cx);
                    }
                }

                // Otherwise emit a two-member struct: `a` at offset 0, and `b`
                // at `a`'s size rounded up to `b`'s ABI alignment.
                let a_offset = Size::ZERO;
                let b_offset = a.primitive().size(cx).align_to(b.primitive().align(cx).abi);
                let a = trans_scalar(cx, span, *self, a, a_offset);
                let b = trans_scalar(cx, span, *self, b, b_offset);
                let size = if self.is_unsized() {
                    None
                } else {
                    Some(self.size)
                };
                // Collect source field names (single-variant ADTs only). They
                // are attached only if there are exactly two, matching the two
                // members emitted.
                let mut field_names = Vec::new();
                if let TyKind::Adt(adt, _) = self.ty.kind()
                    && let Variants::Single { index } = self.variants
                {
                    for i in self.fields.index_by_increasing_offset() {
                        let field = &adt.variants()[index].fields[FieldIdx::new(i)];
                        field_names.push(field.name);
                    }
                }
                SpirvType::Adt {
                    def_id: def_id_for_spirv_type_adt(*self),
                    size,
                    align: self.align.abi,
                    field_types: &[a, b],
                    field_offsets: &[a_offset, b_offset],
                    field_names: if field_names.len() == 2 {
                        Some(&field_names)
                    } else {
                        None
                    },
                }
                .def_with_name(cx, span, TyLayoutNameKey::from(*self))
            }
            // SIMD layouts map onto a SPIR-V vector of the element scalar.
            // NOTE(review): `count as u32` truncates silently if `count` does
            // not fit in a u32.
            BackendRepr::SimdVector { element, count } => {
                let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
                SpirvType::Vector {
                    element: elem_spirv,
                    count: count as u32,
                    size: self.size,
                    align: self.align.abi,
                }
                .def(span, cx)
            }
            // Everything held in memory goes through field-shape-based translation.
            BackendRepr::Memory { sized: _ } => trans_aggregate(cx, span, *self),
        }
    }
}
397
398pub fn scalar_pair_element_backend_type<'tcx>(
401    cx: &CodegenCx<'tcx>,
402    span: Span,
403    ty: TyAndLayout<'tcx>,
404    index: usize,
405) -> Word {
406    let [a, b] = match ty.layout.backend_repr() {
407        BackendRepr::ScalarPair(a, b) => [a, b],
408        other => span_bug!(
409            span,
410            "scalar_pair_element_backend_type invalid abi: {:?}",
411            other
412        ),
413    };
414    let offset = match index {
415        0 => Size::ZERO,
416        1 => a.primitive().size(cx).align_to(b.primitive().align(cx).abi),
417        _ => unreachable!(),
418    };
419    trans_scalar(cx, span, ty, [a, b][index], offset)
420}
421
422fn trans_scalar<'tcx>(
430    cx: &CodegenCx<'tcx>,
431    span: Span,
432    ty: TyAndLayout<'tcx>,
433    scalar: Scalar,
434    offset: Size,
435) -> Word {
436    if scalar.is_bool() {
437        return SpirvType::Bool.def(span, cx);
438    }
439
440    match scalar.primitive() {
441        Primitive::Int(int_kind, signedness) => {
442            SpirvType::Integer(int_kind.size().bits() as u32, signedness).def(span, cx)
443        }
444        Primitive::Float(float_kind) => {
445            SpirvType::Float(float_kind.size().bits() as u32).def(span, cx)
446        }
447        Primitive::Pointer(_) => {
448            let pointee_ty = dig_scalar_pointee(cx, ty, offset);
449            if let Some(predefined_result) = cx
452                .type_cache
453                .recursive_pointee_cache
454                .begin(cx, span, pointee_ty)
455            {
456                predefined_result
457            } else {
458                let pointee = pointee_ty.spirv_type(span, cx);
459                cx.type_cache
460                    .recursive_pointee_cache
461                    .end(cx, span, pointee_ty, pointee)
462            }
463        }
464    }
465}
466
467fn dig_scalar_pointee<'tcx>(
478    cx: &CodegenCx<'tcx>,
479    layout: TyAndLayout<'tcx>,
480    offset: Size,
481) -> PointeeTy<'tcx> {
482    if let FieldsShape::Primitive = layout.fields {
483        assert_eq!(offset, Size::ZERO);
484        let pointee = match *layout.ty.kind() {
485            TyKind::Ref(_, pointee_ty, _) | TyKind::RawPtr(pointee_ty, _) => {
486                PointeeTy::Ty(cx.layout_of(pointee_ty))
487            }
488            TyKind::FnPtr(sig_tys, hdr) => PointeeTy::Fn(sig_tys.with(hdr)),
489            _ => bug!("Pointer is not `&T`, `*T` or `fn` pointer: {:#?}", layout),
490        };
491        return pointee;
492    }
493
494    let all_fields = (match &layout.variants {
495        Variants::Empty => 0..0,
496        Variants::Multiple { variants, .. } => 0..variants.len(),
497        Variants::Single { index } => {
498            let i = index.as_usize();
499            i..i + 1
500        }
501    })
502    .flat_map(|variant_idx| {
503        let variant = layout.for_variant(cx, VariantIdx::new(variant_idx));
504        (0..variant.fields.count()).map(move |field_idx| {
505            (
506                variant.field(cx, field_idx),
507                variant.fields.offset(field_idx),
508            )
509        })
510    });
511
512    let mut pointee = None;
513    for (field, field_offset) in all_fields {
514        if field.is_zst() {
515            continue;
516        }
517        if (field_offset..field_offset + field.size).contains(&offset) {
518            let new_pointee = dig_scalar_pointee(cx, field, offset - field_offset);
519            match pointee {
520                Some(old_pointee) if old_pointee != new_pointee => {
521                    cx.tcx.dcx().fatal(format!(
522                        "dig_scalar_pointee: unsupported Pointer with different \
523                         pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
524                    ));
525                }
526                _ => pointee = Some(new_pointee),
527            }
528        }
529    }
530    pointee.unwrap_or_else(|| {
531        bug!(
532            "field containing Pointer scalar at offset {:?} not found in {:#?}",
533            offset,
534            layout
535        )
536    })
537}
538
539fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
542    fn create_zst<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
543        assert_eq!(ty.size, Size::ZERO);
544        SpirvType::Adt {
545            def_id: def_id_for_spirv_type_adt(ty),
546            size: Some(Size::ZERO),
547            align: ty.align.abi,
548            field_types: &[],
549            field_offsets: &[],
550            field_names: None,
551        }
552        .def_with_name(cx, span, TyLayoutNameKey::from(ty))
553    }
554    match ty.fields {
555        FieldsShape::Primitive => span_bug!(
556            span,
557            "trans_aggregate called for FieldsShape::Primitive layout {:#?}",
558            ty
559        ),
560        FieldsShape::Union(_) => {
561            assert!(!ty.is_unsized(), "{ty:#?}");
562
563            let largest_case = (0..ty.fields.count())
569                .map(|i| (FieldIdx::from_usize(i), ty.field(cx, i)))
570                .max_by_key(|(_, case)| case.size);
571
572            if let Some((case_idx, case)) = largest_case {
573                if ty.align != case.align {
574                    trans_struct_or_union(cx, span, ty, Some(case_idx))
576                } else {
577                    assert_eq!(ty.size, case.size);
578                    case.spirv_type(span, cx)
579                }
580            } else {
581                create_zst(cx, span, ty)
582            }
583        }
584        FieldsShape::Array { stride, count } => {
585            let element_type = ty.field(cx, 0).spirv_type(span, cx);
586            if ty.is_unsized() {
587                assert_eq!(count, 0);
590                SpirvType::RuntimeArray {
591                    element: element_type,
592                }
593                .def(span, cx)
594            } else if count == 0 {
595                create_zst(cx, span, ty)
597            } else {
598                let count_const = cx.constant_u32(span, count as u32);
599                let element_spv = cx.lookup_type(element_type);
600                let stride_spv = element_spv
601                    .sizeof(cx)
602                    .expect("Unexpected unsized type in sized FieldsShape::Array")
603                    .align_to(element_spv.alignof(cx));
604                assert_eq!(stride_spv, stride);
605                SpirvType::Array {
606                    element: element_type,
607                    count: count_const,
608                }
609                .def(span, cx)
610            }
611        }
612        FieldsShape::Arbitrary {
613            offsets: _,
614            memory_index: _,
615        } => trans_struct_or_union(cx, span, ty, None),
616    }
617}
618
619#[cfg_attr(
620    not(rustc_codegen_spirv_disable_pqp_cg_ssa),
621    expect(
622        unused,
623        reason = "actually used from \
624            `<rustc_codegen_ssa::traits::ConstCodegenMethods for CodegenCx<'_>>::const_struct`, \
625            but `rustc_codegen_ssa` being `pqp_cg_ssa` makes that trait unexported"
626    )
627)]
628pub fn auto_struct_layout(
630    cx: &CodegenCx<'_>,
631    field_types: &[Word],
632) -> (Vec<Size>, Option<Size>, Align) {
633    let mut field_offsets = Vec::with_capacity(field_types.len());
635    let mut offset = Some(Size::ZERO);
636    let mut max_align = Align::from_bytes(0).unwrap();
637    for &field_type in field_types {
638        let spirv_type = cx.lookup_type(field_type);
639        let field_size = spirv_type.sizeof(cx);
640        let field_align = spirv_type.alignof(cx);
641        let this_offset = offset
642            .expect("Unsized values can only be the last field in a struct")
643            .align_to(field_align);
644
645        field_offsets.push(this_offset);
646        if field_align > max_align {
647            max_align = field_align;
648        }
649        offset = field_size.map(|size| this_offset + size);
650    }
651    (field_offsets, offset, max_align)
652}
653
654fn trans_struct_or_union<'tcx>(
656    cx: &CodegenCx<'tcx>,
657    span: Span,
658    ty: TyAndLayout<'tcx>,
659    union_case: Option<FieldIdx>,
660) -> Word {
661    let size = if ty.is_unsized() { None } else { Some(ty.size) };
662    let align = ty.align.abi;
663    let mut field_types = Vec::new();
665    let mut field_offsets = Vec::new();
666    let mut field_names = Vec::new();
667    for i in ty.fields.index_by_increasing_offset() {
668        if let Some(expected_field_idx) = union_case
669            && i != expected_field_idx.as_usize()
670        {
671            continue;
672        }
673
674        let field_ty = ty.field(cx, i);
675        field_types.push(field_ty.spirv_type(span, cx));
676        let offset = ty.fields.offset(i);
677        field_offsets.push(offset);
678        if let Variants::Single { index } = ty.variants {
679            if let TyKind::Adt(adt, _) = ty.ty.kind() {
680                let field = &adt.variants()[index].fields[FieldIdx::new(i)];
681                field_names.push(field.name);
682            } else {
683                field_names.push(Symbol::intern(&format!("{i}")));
685            }
686        } else {
687            if let TyKind::Adt(_, _) = ty.ty.kind() {
688            } else {
689                span_bug!(span, "Variants::Multiple not TyKind::Adt");
690            }
691            if i == 0 {
692                field_names.push(cx.sym.discriminant);
693            } else {
694                cx.tcx.dcx().fatal("Variants::Multiple has multiple fields")
695            }
696        };
697    }
698    SpirvType::Adt {
699        def_id: def_id_for_spirv_type_adt(ty),
700        size,
701        align,
702        field_types: &field_types,
703        field_offsets: &field_offsets,
704        field_names: Some(&field_names),
705    }
706    .def_with_name(cx, span, TyLayoutNameKey::from(ty))
707}
708
709fn def_id_for_spirv_type_adt(layout: TyAndLayout<'_>) -> Option<DefId> {
713    match *layout.ty.kind() {
714        TyKind::Adt(def, _) => Some(def.did()),
715        TyKind::Foreign(def_id) | TyKind::Closure(def_id, _) | TyKind::Coroutine(def_id, ..) => {
716            Some(def_id)
717        }
718        _ => None,
719    }
720}
721
722fn span_for_spirv_type_adt(cx: &CodegenCx<'_>, layout: TyAndLayout<'_>) -> Option<Span> {
723    def_id_for_spirv_type_adt(layout).map(|did| cx.tcx.def_span(did))
724}
725
/// Key passed to `def_with_name` when naming SPIR-V struct types built from
/// layouts: the Rust type plus, for single-variant layouts, the variant index.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct TyLayoutNameKey<'tcx> {
    ty: Ty<'tcx>,
    // `Some` only when the layout's variants are `Variants::Single`.
    variant: Option<VariantIdx>,
}
733
734impl<'tcx> From<TyAndLayout<'tcx>> for TyLayoutNameKey<'tcx> {
735    fn from(layout: TyAndLayout<'tcx>) -> Self {
736        TyLayoutNameKey {
737            ty: layout.ty,
738            variant: match layout.variants {
739                Variants::Single { index } => Some(index),
740                _ => None,
741            },
742        }
743    }
744}
745
746impl fmt::Display for TyLayoutNameKey<'_> {
747    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
748        write!(f, "{}", self.ty)?;
749        if let (TyKind::Adt(def, _), Some(index)) = (self.ty.kind(), self.variant)
750            && def.is_enum()
751            && !def.variants().is_empty()
752        {
753            write!(f, "::{}", def.variants()[index].name)?;
754        }
755        if let (TyKind::Coroutine(_, _), Some(index)) = (self.ty.kind(), self.variant) {
756            write!(f, "::{}", CoroutineArgs::variant_name(index))?;
757        }
758        Ok(())
759    }
760}
761
762fn trans_intrinsic_type<'tcx>(
763    cx: &CodegenCx<'tcx>,
764    span: Span,
765    ty: TyAndLayout<'tcx>,
766    args: GenericArgsRef<'tcx>,
767    intrinsic_type_attr: IntrinsicType,
768) -> Result<Word, ErrorGuaranteed> {
769    match intrinsic_type_attr {
770        IntrinsicType::GenericImageType => {
771            if ty.size != Size::from_bytes(4) {
773                return Err(cx
774                    .tcx
775                    .dcx()
776                    .err("#[spirv(generic_image)] type must have size 4"));
777            }
778
779            let sampled_type = match args.type_at(0).kind() {
792                TyKind::Int(int) => match int {
793                    IntTy::Isize => {
794                        SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, true)
795                            .def(span, cx)
796                    }
797                    IntTy::I8 => SpirvType::Integer(8, true).def(span, cx),
798                    IntTy::I16 => SpirvType::Integer(16, true).def(span, cx),
799                    IntTy::I32 => SpirvType::Integer(32, true).def(span, cx),
800                    IntTy::I64 => SpirvType::Integer(64, true).def(span, cx),
801                    IntTy::I128 => SpirvType::Integer(128, true).def(span, cx),
802                },
803                TyKind::Uint(uint) => match uint {
804                    UintTy::Usize => {
805                        SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, false)
806                            .def(span, cx)
807                    }
808                    UintTy::U8 => SpirvType::Integer(8, false).def(span, cx),
809                    UintTy::U16 => SpirvType::Integer(16, false).def(span, cx),
810                    UintTy::U32 => SpirvType::Integer(32, false).def(span, cx),
811                    UintTy::U64 => SpirvType::Integer(64, false).def(span, cx),
812                    UintTy::U128 => SpirvType::Integer(128, false).def(span, cx),
813                },
814                TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
815                TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
816                _ => {
817                    return Err(cx
818                        .tcx
819                        .dcx()
820                        .span_err(span, "Invalid sampled type to `Image`."));
821                }
822            };
823
824            trait FromScalarInt: Sized {
833                fn from_scalar_int(n: ScalarInt) -> Option<Self>;
834            }
835
836            impl FromScalarInt for u32 {
837                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
838                    Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
839                }
840            }
841
842            impl FromScalarInt for Dim {
843                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
844                    Dim::from_u32(u32::from_scalar_int(n)?)
845                }
846            }
847
848            impl FromScalarInt for ImageFormat {
849                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
850                    ImageFormat::from_u32(u32::from_scalar_int(n)?)
851                }
852            }
853
854            fn const_int_value<'tcx, P: FromScalarInt>(
855                cx: &CodegenCx<'tcx>,
856                const_: Const<'tcx>,
857            ) -> Result<P, ErrorGuaranteed> {
858                let ty::Value {
859                    ty: const_ty,
860                    valtree: const_val,
861                } = const_.to_value();
862                assert!(const_ty.is_integral());
863                const_val
864                    .try_to_scalar_int()
865                    .and_then(P::from_scalar_int)
866                    .ok_or_else(|| {
867                        cx.tcx
868                            .dcx()
869                            .err(format!("invalid value for Image const generic: {const_}"))
870                    })
871            }
872
873            let dim = const_int_value(cx, args.const_at(1))?;
874            let depth = const_int_value(cx, args.const_at(2))?;
875            let arrayed = const_int_value(cx, args.const_at(3))?;
876            let multisampled = const_int_value(cx, args.const_at(4))?;
877            let sampled = const_int_value(cx, args.const_at(5))?;
878            let image_format = const_int_value(cx, args.const_at(6))?;
879
880            let ty = SpirvType::Image {
881                sampled_type,
882                dim,
883                depth,
884                arrayed,
885                multisampled,
886                sampled,
887                image_format,
888            };
889            Ok(ty.def(span, cx))
890        }
891        IntrinsicType::Sampler => {
892            if ty.size != Size::from_bytes(4) {
894                return Err(cx.tcx.dcx().err("#[spirv(sampler)] type must have size 4"));
895            }
896            Ok(SpirvType::Sampler.def(span, cx))
897        }
898        IntrinsicType::AccelerationStructureKhr => {
899            Ok(SpirvType::AccelerationStructureKhr.def(span, cx))
900        }
901        IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)),
902        IntrinsicType::SampledImage => {
903            if ty.size != Size::from_bytes(4) {
905                return Err(cx
906                    .tcx
907                    .dcx()
908                    .err("#[spirv(sampled_image)] type must have size 4"));
909            }
910
911            if let Some(image_ty) = args.types().next() {
914                let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
916                Ok(SpirvType::SampledImage { image_type }.def(span, cx))
917            } else {
918                Err(cx
919                    .tcx
920                    .dcx()
921                    .err("#[spirv(sampled_image)] type must have a generic image type"))
922            }
923        }
924        IntrinsicType::RuntimeArray => {
925            if ty.size != Size::from_bytes(4) {
926                return Err(cx
927                    .tcx
928                    .dcx()
929                    .err("#[spirv(runtime_array)] type must have size 4"));
930            }
931
932            if let Some(elem_ty) = args.types().next() {
935                Ok(SpirvType::RuntimeArray {
936                    element: cx.layout_of(elem_ty).spirv_type(span, cx),
937                }
938                .def(span, cx))
939            } else {
940                Err(cx
941                    .tcx
942                    .dcx()
943                    .err("#[spirv(runtime_array)] type must have a generic element type"))
944            }
945        }
946        IntrinsicType::TypedBuffer => {
947            if ty.size != Size::from_bytes(4) {
948                return Err(cx
949                    .tcx
950                    .sess
951                    .dcx()
952                    .err("#[spirv(typed_buffer)] type must have size 4"));
953            }
954
955            if let Some(data_ty) = args.types().next() {
958                Ok(SpirvType::InterfaceBlock {
963                    inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
964                }
965                .def(span, cx))
966            } else {
967                Err(cx
968                    .tcx
969                    .sess
970                    .dcx()
971                    .err("#[spirv(typed_buffer)] type must have a generic data type"))
972            }
973        }
974        IntrinsicType::Matrix => {
975            let span = span_for_spirv_type_adt(cx, ty).unwrap();
976            let err_attr_name = "`#[spirv(matrix)]`";
977            let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
978            match cx.lookup_type(element) {
979                SpirvType::Vector { .. } => (),
980                ty => {
981                    return Err(cx
982                        .tcx
983                        .dcx()
984                        .struct_span_err(
985                            span,
986                            format!("{err_attr_name} type fields must all be vectors"),
987                        )
988                        .with_note(format!("field type is {}", ty.debug(element, cx)))
989                        .emit());
990                }
991            }
992            Ok(SpirvType::Matrix { element, count }.def(span, cx))
993        }
994        IntrinsicType::Vector => {
995            let span = span_for_spirv_type_adt(cx, ty).unwrap();
996            let err_attr_name = "`#[spirv(vector)]`";
997            let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
998            match cx.lookup_type(element) {
999                SpirvType::Bool | SpirvType::Float { .. } | SpirvType::Integer { .. } => (),
1000                ty => {
1001                    return Err(cx
1002                        .tcx
1003                        .dcx()
1004                        .struct_span_err(
1005                            span,
1006                            format!(
1007                                "{err_attr_name} type fields must all be floats, integers or bools"
1008                            ),
1009                        )
1010                        .with_note(format!("field type is {}", ty.debug(element, cx)))
1011                        .emit());
1012                }
1013            }
1014            Ok(SpirvType::Vector {
1015                element,
1016                count,
1017                size: ty.size,
1018                align: ty.align.abi,
1019            }
1020            .def(span, cx))
1021        }
1022    }
1023}
1024
1025fn trans_glam_like_struct<'tcx>(
1028    cx: &CodegenCx<'tcx>,
1029    span: Span,
1030    ty: TyAndLayout<'tcx>,
1031    args: GenericArgsRef<'tcx>,
1032    err_attr_name: &str,
1033) -> Result<(Word, u32), ErrorGuaranteed> {
1034    let tcx = cx.tcx;
1035    if let Some(adt) = ty.ty.ty_adt_def()
1036        && adt.is_struct()
1037    {
1038        let (count, element) = adt
1039            .non_enum_variant()
1040            .fields
1041            .iter()
1042            .map(|f| f.ty(tcx, args))
1043            .dedup_with_count()
1044            .exactly_one()
1045            .map_err(|_e| {
1046                tcx.dcx().span_err(
1047                    span,
1048                    format!("{err_attr_name} member types must all be the same"),
1049                )
1050            })?;
1051
1052        let element = cx.layout_of(element);
1053        let element_word = element.spirv_type(span, cx);
1054        let count = u32::try_from(count)
1055            .ok()
1056            .filter(|count| 2 <= *count && *count <= 4)
1057            .ok_or_else(|| {
1058                tcx.dcx()
1059                    .span_err(span, format!("{err_attr_name} must have 2, 3 or 4 members"))
1060            })?;
1061
1062        for i in 0..ty.fields.count() {
1063            let expected = element.size.checked_mul(i as u64, cx).unwrap();
1064            let actual = ty.fields.offset(i);
1065            if actual != expected {
1066                let name: &str = adt
1067                    .non_enum_variant()
1068                    .fields
1069                    .get(FieldIdx::from(i))
1070                    .unwrap()
1071                    .name
1072                    .as_str();
1073                tcx.dcx().span_fatal(
1074                    span,
1075                    format!(
1076                        "Unexpected layout for {err_attr_name} annotated struct: \
1077                    Expected member `{name}` at offset {expected:?}, but was at {actual:?}"
1078                    ),
1079                )
1080            }
1081        }
1082
1083        Ok((element_word, count))
1084    } else {
1085        Err(tcx
1086            .dcx()
1087            .span_err(span, format!("{err_attr_name} type must be a struct")))
1088    }
1089}