rustc_codegen_spirv/
abi.rs

1//! This file is responsible for translation from rustc tys (`TyAndLayout`) to spir-v types. It's
2//! surprisingly difficult.
3
4use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
5use crate::codegen_cx::CodegenCx;
6use crate::spirv_type::SpirvType;
7use itertools::Itertools;
8use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
9use rustc_abi::ExternAbi as Abi;
10use rustc_abi::{
11    Align, BackendRepr, FieldIdx, FieldsShape, HasDataLayout as _, LayoutData, Primitive,
12    ReprFlags, ReprOptions, Scalar, Size, TagEncoding, VariantIdx, Variants,
13};
14use rustc_data_structures::fx::FxHashMap;
15use rustc_errors::ErrorGuaranteed;
16use rustc_hashes::Hash64;
17use rustc_index::Idx;
18use rustc_middle::query::Providers;
19use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
20use rustc_middle::ty::{
21    self, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty, TyCtxt,
22    TyKind, UintTy,
23};
24use rustc_middle::ty::{GenericArgsRef, ScalarInt};
25use rustc_middle::{bug, span_bug};
26use rustc_span::DUMMY_SP;
27use rustc_span::def_id::DefId;
28use rustc_span::{Span, Symbol};
29use rustc_target::callconv::{ArgAbi, ArgAttributes, FnAbi, PassMode};
30use std::cell::RefCell;
31use std::collections::hash_map::Entry;
32use std::fmt;
33
/// Installs Rust-GPU's overrides of `rustc` queries (`fn_sig`, the `FnAbi`
/// queries, `layout_of`, and some well-formedness checks), mostly so that the
/// ABIs and layouts `rustc` computes stay representable in SPIR-V.
pub(crate) fn provide(providers: &mut Providers) {
    // This is a lil weird: so, we obviously don't support C ABIs at all. However, libcore does declare some extern
    // C functions:
    // https://github.com/rust-lang/rust/blob/5fae56971d8487088c0099c82c0a5ce1638b5f62/library/core/src/slice/cmp.rs#L119
    // However, those functions will be implemented by compiler-builtins:
    // https://github.com/rust-lang/rust/blob/5fae56971d8487088c0099c82c0a5ce1638b5f62/library/core/src/lib.rs#L23-L27
    // This theoretically then should be fine to leave as C, but, there's no backend hook for
    // `FnAbi::adjust_for_cabi`, causing it to panic:
    // https://github.com/rust-lang/rust/blob/5fae56971d8487088c0099c82c0a5ce1638b5f62/compiler/rustc_target/src/abi/call/mod.rs#L603
    // So, treat any `extern "C"` functions as `extern "unadjusted"`, to be able to compile libcore with arch=spirv.
    providers.fn_sig = |tcx, def_id| {
        // We can't capture the old fn_sig and just call that, because fn_sig is a `fn`, not a `Fn`, i.e. it can't
        // capture variables. Fortunately, the defaults are exposed (thanks rustdoc), so use that instead.
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_sig)(tcx, def_id);
        result.map_bound(|outer| {
            outer.map_bound(|mut inner| {
                if let Abi::C { .. } = inner.abi {
                    inner.abi = Abi::Unadjusted;
                }
                inner
            })
        })
    };

    // For the Rust ABI, `FnAbi` adjustments are backend-agnostic, but they will
    // use features like `PassMode::Cast`, that are incompatible with SPIR-V.
    // By hooking the queries computing `FnAbi`s, we can recompute the `FnAbi`
    // from the return/args layouts, to e.g. prefer using `PassMode::Direct`.
    fn readjust_fn_abi<'tcx>(
        tcx: TyCtxt<'tcx>,
        fn_abi: &'tcx FnAbi<'tcx, Ty<'tcx>>,
    ) -> &'tcx FnAbi<'tcx, Ty<'tcx>> {
        // Rebuild each `ArgAbi` from just its layout, discarding whatever
        // `PassMode` the default provider had picked.
        let readjust_arg_abi = |arg: &ArgAbi<'tcx, Ty<'tcx>>| {
            let mut arg = ArgAbi::new(&tcx, arg.layout, |_, _, _| ArgAttributes::new());
            // FIXME: this is bad! https://github.com/rust-lang/rust/issues/115666
            // <https://github.com/rust-lang/rust/commit/eaaa03faf77b157907894a4207d8378ecaec7b45>
            arg.make_direct_deprecated();

            // Avoid pointlessly passing ZSTs, just like the official Rust ABI.
            if arg.layout.is_zst() {
                arg.mode = PassMode::Ignore;
            }

            arg
        };
        tcx.arena.alloc(FnAbi {
            args: fn_abi.args.iter().map(readjust_arg_abi).collect(),
            ret: readjust_arg_abi(&fn_abi.ret),

            // FIXME(eddyb) validate some of these, and report errors - however,
            // we can't just emit errors from here, since we have no `Span`, so
            // we should have instead a check on MIR for e.g. C variadic calls.
            c_variadic: fn_abi.c_variadic,
            fixed_count: fn_abi.fixed_count,
            conv: fn_abi.conv,
            can_unwind: fn_abi.can_unwind,
        })
    }
    providers.fn_abi_of_fn_ptr = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_fn_ptr)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };
    providers.fn_abi_of_instance = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.fn_abi_of_instance)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };

    // FIXME(eddyb) remove this by deriving `Clone` for `LayoutData` upstream.
    // NOTE(review): the exhaustive destructuring below means this fails to
    // compile (rather than silently drop data) if `LayoutData` gains fields.
    fn clone_layout<FieldIdx: Idx, VariantIdx: Idx>(
        layout: &LayoutData<FieldIdx, VariantIdx>,
    ) -> LayoutData<FieldIdx, VariantIdx> {
        let LayoutData {
            ref fields,
            ref variants,
            backend_repr,
            largest_niche,
            uninhabited,
            align,
            size,
            max_repr_align,
            unadjusted_abi_align,
            randomization_seed,
        } = *layout;
        LayoutData {
            fields: match *fields {
                FieldsShape::Primitive => FieldsShape::Primitive,
                FieldsShape::Union(count) => FieldsShape::Union(count),
                FieldsShape::Array { stride, count } => FieldsShape::Array { stride, count },
                FieldsShape::Arbitrary {
                    ref offsets,
                    ref memory_index,
                } => FieldsShape::Arbitrary {
                    offsets: offsets.clone(),
                    memory_index: memory_index.clone(),
                },
            },
            variants: match *variants {
                Variants::Empty => Variants::Empty,
                Variants::Single { index } => Variants::Single { index },
                Variants::Multiple {
                    tag,
                    ref tag_encoding,
                    tag_field,
                    ref variants,
                } => Variants::Multiple {
                    tag,
                    tag_encoding: match *tag_encoding {
                        TagEncoding::Direct => TagEncoding::Direct,
                        TagEncoding::Niche {
                            untagged_variant,
                            ref niche_variants,
                            niche_start,
                        } => TagEncoding::Niche {
                            untagged_variant,
                            niche_variants: niche_variants.clone(),
                            niche_start,
                        },
                    },
                    tag_field,
                    variants: variants.clone(),
                },
            },
            backend_repr,
            largest_niche,
            uninhabited,
            align,
            size,
            max_repr_align,
            unadjusted_abi_align,
            randomization_seed,
        }
    }

    providers.layout_of = |tcx, key| {
        // HACK(eddyb) to special-case any types at all, they must be normalized,
        // but when normalization would be needed, `layout_of`'s default provider
        // recurses (supposedly for caching reasons), i.e. its calls `layout_of`
        // w/ the normalized type in input, which once again reaches this hook,
        // without ever needing any explicit normalization here.
        let ty = key.value;

        // HACK(eddyb) bypassing upstream `#[repr(simd)]` changes (see also
        // the later comment above `check_well_formed`, for more details).
        let reimplement_old_style_repr_simd = match ty.kind() {
            ty::Adt(def, args) if def.repr().simd() && !def.repr().packed() && def.is_struct() => {
                Some(def.non_enum_variant()).and_then(|v| {
                    // All fields must share one type, repeated more than once
                    // (i.e. the old `struct TxN(T, T, ..., T)` shape).
                    let (count, e_ty) = v
                        .fields
                        .iter()
                        .map(|f| f.ty(tcx, args))
                        .dedup_with_count()
                        .exactly_one()
                        .ok()?;
                    let e_len = u64::try_from(count).ok().filter(|&e_len| e_len > 1)?;
                    Some((def, e_ty, e_len))
                })
            }
            _ => None,
        };

        // HACK(eddyb) tweaked copy of the old upstream logic for `#[repr(simd)]`:
        // https://github.com/rust-lang/rust/blob/1.86.0/compiler/rustc_ty_utils/src/layout.rs#L464-L590
        if let Some((adt_def, e_ty, e_len)) = reimplement_old_style_repr_simd {
            let cx = rustc_middle::ty::layout::LayoutCx::new(
                tcx,
                key.typing_env.with_post_analysis_normalized(tcx),
            );
            let dl = cx.data_layout();

            // Compute the ABI of the element type:
            let e_ly = cx.layout_of(e_ty)?;
            let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
                // This error isn't caught in typeck, e.g., if
                // the element type of the vector is generic.
                tcx.dcx().span_fatal(
                    tcx.def_span(adt_def.did()),
                    format!(
                        "SIMD type `{ty}` with a non-primitive-scalar \
                     (integer/float/pointer) element type `{}`",
                        e_ly.ty
                    ),
                );
            };

            // Compute the size and alignment of the vector:
            let size = e_ly.size.checked_mul(e_len, dl).unwrap();
            let align = dl.llvmlike_vector_align(size);
            let size = size.align_to(align.abi);

            let layout = tcx.mk_layout(LayoutData {
                variants: Variants::Single {
                    index: rustc_abi::FIRST_VARIANT,
                },
                fields: FieldsShape::Array {
                    stride: e_ly.size,
                    count: e_len,
                },
                backend_repr: BackendRepr::SimdVector {
                    element: e_repr,
                    count: e_len,
                },
                largest_niche: e_ly.largest_niche,
                uninhabited: false,
                size,
                align,
                max_repr_align: None,
                unadjusted_abi_align: align.abi,
                randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
            });

            return Ok(TyAndLayout { ty, layout });
        }

        let TyAndLayout { ty, mut layout } =
            (rustc_interface::DEFAULT_QUERY_PROVIDERS.layout_of)(tcx, key)?;

        #[allow(clippy::match_like_matches_macro)]
        let hide_niche = match ty.kind() {
            ty::Bool => {
                // HACK(eddyb) we can't bypass e.g. `Option<bool>` being a byte,
                // due to `core` PR https://github.com/rust-lang/rust/pull/138881
                // (which adds a new `transmute`, from `ControlFlow<bool>` to `u8`).
                let libcore_needs_bool_niche = true;

                !libcore_needs_bool_niche
            }
            _ => false,
        };

        if hide_niche {
            layout = tcx.mk_layout(LayoutData {
                largest_niche: None,
                ..clone_layout(layout.0.0)
            });
        }

        Ok(TyAndLayout { ty, layout })
    };

    // HACK(eddyb) work around https://github.com/rust-lang/rust/pull/129403
    // banning "struct-style" `#[repr(simd)]` (in favor of "array-newtype-style"),
    // by simply bypassing "type definition WF checks" for affected types, which:
    // - can only really be sound for types with trivial field types, that are
    //   either completely non-generic (covering most `#[repr(simd)]` `struct`s),
    //   or *at most* one generic type parameter with no bounds/where clause
    // - relies on upstream `layout_of` not having had the non-array logic removed
    //
    // FIXME(eddyb) remove this once migrating beyond `#[repr(simd)]` becomes
    // an option (may require Rust-GPU distinguishing between "SPIR-V interface"
    // and "Rust-facing" types, which is even worse when the `OpTypeVector`s
    // may be e.g. nested in `struct`s/arrays/etc. - at least buffers are easy).
    //
    // FIXME(eddyb) maybe using `#[spirv(vector)]` and `BackendRepr::Memory`,
    // no claims at `rustc`-understood SIMD whatsoever, would be enough?
    // (i.e. only SPIR-V caring about such a type vs a struct/array)
    providers.check_well_formed = |tcx, def_id| {
        // A struct is "trivial" here if it has either no generics at all, or
        // a single non-defaulted, non-synthetic type parameter, and no
        // where-clause predicates whatsoever.
        let trivial_struct = match tcx.hir_node_by_def_id(def_id) {
            rustc_hir::Node::Item(item) => match item.kind {
                rustc_hir::ItemKind::Struct(
                    _,
                    &rustc_hir::Generics {
                        params:
                            &[]
                            | &[
                                rustc_hir::GenericParam {
                                    kind:
                                        rustc_hir::GenericParamKind::Type {
                                            default: None,
                                            synthetic: false,
                                        },
                                    ..
                                },
                            ],
                        predicates: &[],
                        has_where_clause_predicates: false,
                        where_clause_span: _,
                        span: _,
                    },
                    _,
                ) => Some(tcx.adt_def(def_id)),
                _ => None,
            },
            _ => None,
        };
        // Additionally require `#[repr(simd)]` alone (no int/align/pack repr),
        // no destructor, and all fields being one repeated primitive/`Param` type.
        let valid_non_array_simd_struct = trivial_struct.is_some_and(|adt_def| {
            let ReprOptions {
                int: None,
                align: None,
                pack: None,
                flags: ReprFlags::IS_SIMD,
                field_shuffle_seed: _,
            } = adt_def.repr()
            else {
                return false;
            };
            if adt_def.destructor(tcx).is_some() {
                return false;
            }

            let field_types = adt_def
                .non_enum_variant()
                .fields
                .iter()
                .map(|f| tcx.type_of(f.did).instantiate_identity());
            field_types.dedup().exactly_one().is_ok_and(|elem_ty| {
                matches!(
                    elem_ty.kind(),
                    ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Param(_)
                )
            })
        });

        if valid_non_array_simd_struct {
            tcx.dcx()
                .struct_span_warn(
                    tcx.def_span(def_id),
                    "[Rust-GPU] temporarily re-allowing old-style `#[repr(simd)]` (with fields)",
                )
                .with_note("removed upstream by https://github.com/rust-lang/rust/pull/129403")
                .with_note("in favor of the new `#[repr(simd)] struct TxN([T; N]);` style")
                .with_note("(taking effect since `nightly-2024-09-12` / `1.83.0` stable)")
                .emit();
            return Ok(());
        }

        (rustc_interface::DEFAULT_QUERY_PROVIDERS.check_well_formed)(tcx, def_id)
    };

    // HACK(eddyb) work around https://github.com/rust-lang/rust/pull/132173
    // (and further changes from https://github.com/rust-lang/rust/pull/132843)
    // starting to ban SIMD ABI misuse (or at least starting to warn about it).
    //
    // FIXME(eddyb) same as the FIXME comment on `check_well_formed`:
    // need to migrate away from `#[repr(simd)]` ASAP.
    providers.check_mono_item = |_, _| {};
}
370
/// If a struct contains a pointer to itself, even indirectly, then doing a naive recursive walk
/// of the fields will result in an infinite loop. Because pointers are the only thing that are
/// allowed to be recursive, keep track of what pointers we've translated, or are currently in the
/// progress of translating, and break the recursion that way. This struct manages that state
/// tracking.
#[derive(Default)]
pub struct RecursivePointeeCache<'tcx> {
    // Per-pointee translation state (see `PointeeDefState`); `RefCell` because
    // entries are inserted/updated through `&self` during type translation.
    map: RefCell<FxHashMap<PointeeTy<'tcx>, PointeeDefState>>,
}
380
381impl<'tcx> RecursivePointeeCache<'tcx> {
382    fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
383        match self.map.borrow_mut().entry(pointee) {
384            // State: This is the first time we've seen this type. Record that we're beginning to translate this type,
385            // and start doing the translation.
386            Entry::Vacant(entry) => {
387                entry.insert(PointeeDefState::Defining);
388                None
389            }
390            Entry::Occupied(mut entry) => match *entry.get() {
391                // State: This is the second time we've seen this type, and we're already translating this type. If we
392                // were to try to translate the type now, we'd get a stack overflow, due to continually recursing. So,
393                // emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm)
394                PointeeDefState::Defining => {
395                    let new_id = cx.emit_global().id();
396                    // NOTE(eddyb) we emit `StorageClass::Generic` here, but later
397                    // the linker will specialize the entire SPIR-V module to use
398                    // storage classes inferred from `OpVariable`s.
399                    cx.emit_global()
400                        .type_forward_pointer(new_id, StorageClass::Generic);
401                    entry.insert(PointeeDefState::DefiningWithForward(new_id));
402                    cx.zombie_with_span(
403                        new_id,
404                        span,
405                        "cannot create self-referential types, even through pointers",
406                    );
407                    Some(new_id)
408                }
409                // State: This is the third or more time we've seen this type, and we've already emitted an
410                // OpTypeForwardPointer. Just use the ID we've already emitted. (Alternatively, we already defined this
411                // type, so just use that.)
412                PointeeDefState::DefiningWithForward(id) | PointeeDefState::Defined(id) => Some(id),
413            },
414        }
415    }
416
417    fn end(
418        &self,
419        cx: &CodegenCx<'tcx>,
420        span: Span,
421        pointee: PointeeTy<'tcx>,
422        pointee_spv: Word,
423    ) -> Word {
424        match self.map.borrow_mut().entry(pointee) {
425            // We should have hit begin() on this type already, which always inserts an entry.
426            Entry::Vacant(_) => {
427                span_bug!(span, "RecursivePointeeCache::end should always have entry")
428            }
429            Entry::Occupied(mut entry) => match *entry.get() {
430                // State: There have been no recursive references to this type while defining it, and so no
431                // OpTypeForwardPointer has been emitted. This is the most common case.
432                PointeeDefState::Defining => {
433                    let id = SpirvType::Pointer {
434                        pointee: pointee_spv,
435                    }
436                    .def(span, cx);
437                    entry.insert(PointeeDefState::Defined(id));
438                    id
439                }
440                // State: There was a recursive reference to this type, and so an OpTypeForwardPointer has been emitted.
441                // Make sure to use the same ID.
442                PointeeDefState::DefiningWithForward(id) => {
443                    entry.insert(PointeeDefState::Defined(id));
444                    SpirvType::Pointer {
445                        pointee: pointee_spv,
446                    }
447                    .def_with_id(cx, span, id)
448                }
449                PointeeDefState::Defined(_) => {
450                    span_bug!(span, "RecursivePointeeCache::end defined pointer twice")
451                }
452            },
453        }
454    }
455}
456
/// What a pointer points at: either an ordinary (layouted) type, or a
/// function signature (for function pointers).
#[derive(Eq, PartialEq, Hash, Copy, Clone, Debug)]
enum PointeeTy<'tcx> {
    /// A data pointee, with its layout already computed.
    Ty(TyAndLayout<'tcx>),
    /// A function-pointer pointee, as a polymorphic function signature.
    Fn(PolyFnSig<'tcx>),
}
462
463impl fmt::Display for PointeeTy<'_> {
464    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465        match self {
466            PointeeTy::Ty(ty) => write!(f, "{}", ty.ty),
467            PointeeTy::Fn(ty) => write!(f, "{ty}"),
468        }
469    }
470}
471
/// Translation state of one pointee, tracked by `RecursivePointeeCache`.
enum PointeeDefState {
    /// Translation has started; no recursive use seen so far.
    Defining,
    /// Translation is in progress and a recursive use forced an
    /// `OpTypeForwardPointer` with this ID to be emitted.
    DefiningWithForward(Word),
    /// Translation finished; this is the pointer type's final ID.
    Defined(Word),
}
477
/// Various type-like things can be converted to a spirv type - normal types, function types, etc. - and this trait
/// provides a uniform way of translating them.
pub trait ConvSpirvType<'tcx> {
    /// Returns the SPIR-V type ID (`Word`) for `self`, defining it via `cx` if needed.
    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
}
483
484impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
485    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
486        match *self {
487            PointeeTy::Ty(ty) => ty.spirv_type(span, cx),
488            PointeeTy::Fn(ty) => cx
489                .fn_abi_of_fn_ptr(ty, ty::List::empty())
490                .spirv_type(span, cx),
491        }
492    }
493}
494
495impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
496    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
497        // FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
498        let mut argument_types = Vec::new();
499
500        let return_type = match self.ret.mode {
501            PassMode::Ignore => SpirvType::Void.def(span, cx),
502            PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
503            PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
504                span,
505                "query hooks should've made this `PassMode` impossible: {:#?}",
506                self.ret
507            ),
508        };
509
510        for arg in self.args.iter() {
511            let arg_type = match arg.mode {
512                PassMode::Ignore => continue,
513                PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
514                PassMode::Pair(_, _) => {
515                    argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
516                    argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
517                    continue;
518                }
519                PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
520                    span,
521                    "query hooks should've made this `PassMode` impossible: {:#?}",
522                    arg
523                ),
524            };
525            argument_types.push(arg_type);
526        }
527
528        SpirvType::Function {
529            return_type,
530            arguments: &argument_types,
531        }
532        .def(span, cx)
533    }
534}
535
impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
    /// Translates this (type, layout) pair to a SPIR-V type, dispatching on
    /// `self.backend_repr` (scalar, scalar pair, SIMD vector, memory), after
    /// first trying any `#[spirv(...)]` intrinsic-type attribute on ADTs.
    fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
        if let TyKind::Adt(adt, args) = *self.ty.kind() {
            // Prefer pointing diagnostics at the ADT definition over no span at all.
            if span == DUMMY_SP {
                span = cx.tcx.def_span(adt.did());
            }

            let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs_unchecked(adt.did()));

            if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
                && let Ok(spirv_type) =
                    trans_intrinsic_type(cx, span, *self, args, intrinsic_type_attr)
            {
                return spirv_type;
            }
        }

        // Note: ty.layout is orthogonal to ty.ty, e.g. `ManuallyDrop<Result<isize, isize>>` has abi
        // `ScalarPair`.
        // There's a few layers that we go through here. First we inspect layout.backend_repr, then if relevant, layout.fields, etc.
        match self.backend_repr {
            // Uninhabited types become an empty, zero-sized struct.
            _ if self.uninhabited => SpirvType::Adt {
                def_id: def_id_for_spirv_type_adt(*self),
                size: Some(Size::ZERO),
                align: Align::from_bytes(0).unwrap(),
                field_types: &[],
                field_offsets: &[],
                field_names: None,
            }
            .def_with_name(cx, span, TyLayoutNameKey::from(*self)),
            BackendRepr::Scalar(scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
            BackendRepr::ScalarPair(a, b) => {
                // NOTE(eddyb) unlike `BackendRepr::Scalar`'s simpler newtype-unpacking
                // behavior, `BackendRepr::ScalarPair` can be composed in two ways:
                // * two `BackendRepr::Scalar` fields (and any number of ZST fields),
                //   gets handled the same as a `struct { a, b }`, further below
                // * an `BackendRepr::ScalarPair` field (and any number of ZST fields),
                //   which requires more work to allow taking a reference to
                //   that field, and there are two potential approaches:
                //   1. wrapping that field's SPIR-V type in a single-field
                //      `OpTypeStruct` - this has the disadvantage that GEPs
                //      would have to inject an extra `0` field index, and other
                //      field-related operations would also need additional work
                //   2. reusing that field's SPIR-V type, instead of defining
                //      a new one, offering the `(a, b)` shape `rustc_codegen_ssa`
                //      expects, while letting noop pointercasts access the sole
                //      `BackendRepr::ScalarPair` field - this is the approach taken here
                let mut non_zst_fields = (0..self.fields.count())
                    .map(|i| (i, self.field(cx, i)))
                    .filter(|(_, field)| !field.is_zst());
                let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) {
                    (Some(field), None) => Some(field),
                    _ => None,
                };
                if let Some((i, field)) = sole_non_zst_field {
                    // Only unpack a newtype if the field and the newtype line up
                    // perfectly, in every way that could potentially affect ABI.
                    if self.fields.offset(i) == Size::ZERO
                        && field.size == self.size
                        && field.align.abi == self.align.abi
                        && field.backend_repr.eq_up_to_validity(&self.backend_repr)
                    {
                        return field.spirv_type(span, cx);
                    }
                }

                // Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
                // recursive pointer types.
                let a_offset = Size::ZERO;
                let b_offset = a.primitive().size(cx).align_to(b.primitive().align(cx).abi);
                let a = trans_scalar(cx, span, *self, a, a_offset);
                let b = trans_scalar(cx, span, *self, b, b_offset);
                let size = if self.is_unsized() {
                    None
                } else {
                    Some(self.size)
                };
                // FIXME(eddyb) use `ArrayVec` here.
                // Recover the two field names when this is a single-variant ADT,
                // purely to make the emitted struct more readable.
                let mut field_names = Vec::new();
                if let TyKind::Adt(adt, _) = self.ty.kind()
                    && let Variants::Single { index } = self.variants
                {
                    for i in self.fields.index_by_increasing_offset() {
                        let field = &adt.variants()[index].fields[FieldIdx::new(i)];
                        field_names.push(field.name);
                    }
                }
                SpirvType::Adt {
                    def_id: def_id_for_spirv_type_adt(*self),
                    size,
                    align: self.align.abi,
                    field_types: &[a, b],
                    field_offsets: &[a_offset, b_offset],
                    field_names: if field_names.len() == 2 {
                        Some(&field_names)
                    } else {
                        None
                    },
                }
                .def_with_name(cx, span, TyLayoutNameKey::from(*self))
            }
            BackendRepr::SimdVector { element, count } => {
                let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
                SpirvType::Vector {
                    element: elem_spirv,
                    count: count as u32,
                }
                .def(span, cx)
            }
            BackendRepr::Memory { sized: _ } => trans_aggregate(cx, span, *self),
        }
    }
}
649
650/// Only pub for `LayoutTypeCodegenMethods::scalar_pair_element_backend_type`. Think about what you're
651/// doing before calling this.
652pub fn scalar_pair_element_backend_type<'tcx>(
653    cx: &CodegenCx<'tcx>,
654    span: Span,
655    ty: TyAndLayout<'tcx>,
656    index: usize,
657) -> Word {
658    let [a, b] = match ty.layout.backend_repr() {
659        BackendRepr::ScalarPair(a, b) => [a, b],
660        other => span_bug!(
661            span,
662            "scalar_pair_element_backend_type invalid abi: {:?}",
663            other
664        ),
665    };
666    let offset = match index {
667        0 => Size::ZERO,
668        1 => a.primitive().size(cx).align_to(b.primitive().align(cx).abi),
669        _ => unreachable!(),
670    };
671    trans_scalar(cx, span, ty, [a, b][index], offset)
672}
673
674/// A "scalar" is a basic building block: bools, ints, floats, pointers. (i.e. not something complex like a struct)
675/// A "scalar pair" is a bit of a strange concept: if there is a `fn f(x: (u32, u32))`, then what's preferred for
676/// performance is to compile that ABI to `f(x_1: u32, x_2: u32)`, i.e. splitting out the pair into their own arguments,
677/// and pretending that they're one unit. So, there's quite a bit of special handling around these scalar pairs to enable
678/// scenarios like that.
679/// I say it's "preferred", but spirv doesn't really care - only CPU ABIs really care here. However, following rustc's
680/// lead and doing what they want makes things go smoothly, so we'll implement it here too.
681fn trans_scalar<'tcx>(
682    cx: &CodegenCx<'tcx>,
683    span: Span,
684    ty: TyAndLayout<'tcx>,
685    scalar: Scalar,
686    offset: Size,
687) -> Word {
688    if scalar.is_bool() {
689        return SpirvType::Bool.def(span, cx);
690    }
691
692    match scalar.primitive() {
693        Primitive::Int(int_kind, signedness) => {
694            SpirvType::Integer(int_kind.size().bits() as u32, signedness).def(span, cx)
695        }
696        Primitive::Float(float_kind) => {
697            SpirvType::Float(float_kind.size().bits() as u32).def(span, cx)
698        }
699        Primitive::Pointer(_) => {
700            let pointee_ty = dig_scalar_pointee(cx, ty, offset);
701            // Pointers can be recursive. So, record what we're currently translating, and if we're already translating
702            // the same type, emit an OpTypeForwardPointer and use that ID.
703            if let Some(predefined_result) = cx
704                .type_cache
705                .recursive_pointee_cache
706                .begin(cx, span, pointee_ty)
707            {
708                predefined_result
709            } else {
710                let pointee = pointee_ty.spirv_type(span, cx);
711                cx.type_cache
712                    .recursive_pointee_cache
713                    .end(cx, span, pointee_ty, pointee)
714            }
715        }
716    }
717}
718
// This is a really weird function, strap in...
// So, rustc_codegen_ssa is designed around scalar pointers being opaque, you shouldn't know the type behind the
// pointer. Unfortunately, that's impossible for us, we need to know the underlying pointee type for various reasons. In
// some cases, this is pretty easy - if it's a TyKind::Ref, then the pointee will be the pointee of the ref (with
// handling for wide pointers, etc.). Unfortunately, there's some pretty advanced processing going on in cx.layout_of:
// for example, `ManuallyDrop<Result<ptr, ptr>>` has abi `ScalarPair`. This means that to figure out the pointee type,
// we have to replicate the logic of cx.layout_of. Part of that is digging into types that are aggregates: for example,
// ManuallyDrop<T> has a single field of type T. We "dig into" that field, and recurse, trying to find a base case that
// we can handle, like TyKind::Ref.
// If the above didn't make sense, please poke Ashley, it's probably easier to explain via conversation.
fn dig_scalar_pointee<'tcx>(
    cx: &CodegenCx<'tcx>,
    layout: TyAndLayout<'tcx>,
    offset: Size,
) -> PointeeTy<'tcx> {
    // Base case: a primitive layout means `layout.ty` itself is the
    // pointer-shaped type (`&T`, `*T`, or a `fn` pointer).
    if let FieldsShape::Primitive = layout.fields {
        assert_eq!(offset, Size::ZERO);
        let pointee = match *layout.ty.kind() {
            TyKind::Ref(_, pointee_ty, _) | TyKind::RawPtr(pointee_ty, _) => {
                PointeeTy::Ty(cx.layout_of(pointee_ty))
            }
            TyKind::FnPtr(sig_tys, hdr) => PointeeTy::Fn(sig_tys.with(hdr)),
            _ => bug!("Pointer is not `&T`, `*T` or `fn` pointer: {:#?}", layout),
        };
        return pointee;
    }

    // Aggregate case: enumerate every (field layout, field offset) pair,
    // across all variants (each variant of a multi-variant layout contributes
    // its own fields; a single-variant layout contributes just that variant's).
    let all_fields = (match &layout.variants {
        Variants::Empty => 0..0,
        Variants::Multiple { variants, .. } => 0..variants.len(),
        Variants::Single { index } => {
            let i = index.as_usize();
            i..i + 1
        }
    })
    .flat_map(|variant_idx| {
        let variant = layout.for_variant(cx, VariantIdx::new(variant_idx));
        (0..variant.fields.count()).map(move |field_idx| {
            (
                variant.field(cx, field_idx),
                variant.fields.offset(field_idx),
            )
        })
    });

    let mut pointee = None;
    for (field, field_offset) in all_fields {
        // ZSTs occupy no bytes, so they can never contain the pointer scalar.
        if field.is_zst() {
            continue;
        }
        // Recurse into every field whose byte range covers `offset`; all such
        // fields must agree on the pointee type (e.g. both arms of an enum
        // whose variants overlap at this offset), otherwise we can't pick one.
        if (field_offset..field_offset + field.size).contains(&offset) {
            let new_pointee = dig_scalar_pointee(cx, field, offset - field_offset);
            match pointee {
                Some(old_pointee) if old_pointee != new_pointee => {
                    cx.tcx.dcx().fatal(format!(
                        "dig_scalar_pointee: unsupported Pointer with different \
                         pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
                    ));
                }
                _ => pointee = Some(new_pointee),
            }
        }
    }
    pointee.unwrap_or_else(|| {
        bug!(
            "field containing Pointer scalar at offset {:?} not found in {:#?}",
            offset,
            layout
        )
    })
}
790
// FIXME(eddyb) all `ty: TyAndLayout` variables should be `layout: TyAndLayout`,
// the type is really more "Layout with Ty" (`.ty` field + `Deref`s to `Layout`).
/// Lowers aggregate layouts (unions, arrays/slices, and arbitrary structs)
/// to SPIR-V types.
fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
    // Defines an empty `Adt` as a stand-in for a zero-sized type
    // (SPIR-V has no native ZST concept).
    fn create_zst<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
        assert_eq!(ty.size, Size::ZERO);
        SpirvType::Adt {
            def_id: def_id_for_spirv_type_adt(ty),
            size: Some(Size::ZERO),
            align: ty.align.abi,
            field_types: &[],
            field_offsets: &[],
            field_names: None,
        }
        .def_with_name(cx, span, TyLayoutNameKey::from(ty))
    }
    match ty.fields {
        FieldsShape::Primitive => span_bug!(
            span,
            "trans_aggregate called for FieldsShape::Primitive layout {:#?}",
            ty
        ),
        FieldsShape::Union(_) => {
            assert!(!ty.is_unsized(), "{ty:#?}");

            // Represent the `union` with its largest case, which should work
            // for at least `MaybeUninit<T>` (which is between `T` and `()`),
            // but also potentially some other ones as well.
            // NOTE(eddyb) even if long-term this may become a byte array, that
            // only works for "data types" and not "opaque handles" (images etc.).
            let largest_case = (0..ty.fields.count())
                .map(|i| (FieldIdx::from_usize(i), ty.field(cx, i)))
                .max_by_key(|(_, case)| case.size);

            if let Some((case_idx, case)) = largest_case {
                if ty.align != case.align {
                    // HACK(eddyb) mismatched alignment requires a wrapper `struct`.
                    trans_struct_or_union(cx, span, ty, Some(case_idx))
                } else {
                    assert_eq!(ty.size, case.size);
                    case.spirv_type(span, cx)
                }
            } else {
                // A union with no fields is zero-sized.
                create_zst(cx, span, ty)
            }
        }
        FieldsShape::Array { stride, count } => {
            let element_type = ty.field(cx, 0).spirv_type(span, cx);
            if ty.is_unsized() {
                // There's a potential for this array to be sized, but the element to be unsized, e.g. `[[u8]; 5]`.
                // However, I think rust disallows all these cases, so assert this here.
                assert_eq!(count, 0);
                SpirvType::RuntimeArray {
                    element: element_type,
                }
                .def(span, cx)
            } else if count == 0 {
                // spir-v doesn't support zero-sized arrays
                create_zst(cx, span, ty)
            } else {
                let count_const = cx.constant_u32(span, count as u32);
                // Sanity-check that the SPIR-V element's size/alignment implies
                // the same stride rustc computed for this array layout.
                let element_spv = cx.lookup_type(element_type);
                let stride_spv = element_spv
                    .sizeof(cx)
                    .expect("Unexpected unsized type in sized FieldsShape::Array")
                    .align_to(element_spv.alignof(cx));
                assert_eq!(stride_spv, stride);
                SpirvType::Array {
                    element: element_type,
                    count: count_const,
                }
                .def(span, cx)
            }
        }
        FieldsShape::Arbitrary {
            offsets: _,
            memory_index: _,
        } => trans_struct_or_union(cx, span, ty, None),
    }
}
870
871#[cfg_attr(
872    not(rustc_codegen_spirv_disable_pqp_cg_ssa),
873    expect(
874        unused,
875        reason = "actually used from \
876            `<rustc_codegen_ssa::traits::ConstCodegenMethods for CodegenCx<'_>>::const_struct`, \
877            but `rustc_codegen_ssa` being `pqp_cg_ssa` makes that trait unexported"
878    )
879)]
880// returns (field_offsets, size, align)
881pub fn auto_struct_layout(
882    cx: &CodegenCx<'_>,
883    field_types: &[Word],
884) -> (Vec<Size>, Option<Size>, Align) {
885    // FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
886    let mut field_offsets = Vec::with_capacity(field_types.len());
887    let mut offset = Some(Size::ZERO);
888    let mut max_align = Align::from_bytes(0).unwrap();
889    for &field_type in field_types {
890        let spirv_type = cx.lookup_type(field_type);
891        let field_size = spirv_type.sizeof(cx);
892        let field_align = spirv_type.alignof(cx);
893        let this_offset = offset
894            .expect("Unsized values can only be the last field in a struct")
895            .align_to(field_align);
896
897        field_offsets.push(this_offset);
898        if field_align > max_align {
899            max_align = field_align;
900        }
901        offset = field_size.map(|size| this_offset + size);
902    }
903    (field_offsets, offset, max_align)
904}
905
// see struct_llfields in librustc_codegen_llvm for implementation hints
/// Lowers a struct — or, when `union_case` is `Some(idx)`, a union represented
/// solely by its field `idx` — to a SPIR-V `Adt` with named, offset-annotated
/// fields.
fn trans_struct_or_union<'tcx>(
    cx: &CodegenCx<'tcx>,
    span: Span,
    ty: TyAndLayout<'tcx>,
    union_case: Option<FieldIdx>,
) -> Word {
    // Unsized (DST) structs have no statically-known total size.
    let size = if ty.is_unsized() { None } else { Some(ty.size) };
    let align = ty.align.abi;
    // FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
    let mut field_types = Vec::new();
    let mut field_offsets = Vec::new();
    let mut field_names = Vec::new();
    // Visit fields in memory order (increasing offset), not declaration order.
    for i in ty.fields.index_by_increasing_offset() {
        // When lowering a union, keep only the chosen representative field.
        if let Some(expected_field_idx) = union_case
            && i != expected_field_idx.as_usize()
        {
            continue;
        }

        let field_ty = ty.field(cx, i);
        field_types.push(field_ty.spirv_type(span, cx));
        let offset = ty.fields.offset(i);
        field_offsets.push(offset);
        if let Variants::Single { index } = ty.variants {
            if let TyKind::Adt(adt, _) = ty.ty.kind() {
                // Use the source-level field name from the ADT definition.
                let field = &adt.variants()[index].fields[FieldIdx::new(i)];
                field_names.push(field.name);
            } else {
                // FIXME(eddyb) this looks like something that should exist in rustc.
                field_names.push(Symbol::intern(&format!("{i}")));
            }
        } else {
            if let TyKind::Adt(_, _) = ty.ty.kind() {
            } else {
                span_bug!(span, "Variants::Multiple not TyKind::Adt");
            }
            // Multi-variant layouts are expected to expose only their tag here.
            if i == 0 {
                field_names.push(cx.sym.discriminant);
            } else {
                cx.tcx.dcx().fatal("Variants::Multiple has multiple fields")
            }
        };
    }
    SpirvType::Adt {
        def_id: def_id_for_spirv_type_adt(ty),
        size,
        align,
        field_types: &field_types,
        field_offsets: &field_offsets,
        field_names: Some(&field_names),
    }
    .def_with_name(cx, span, TyLayoutNameKey::from(ty))
}
960
961/// Grab a `DefId` from the type if possible to avoid too much deduplication,
962/// which could result in one SPIR-V `OpType*` having many names
963/// (not in itself an issue, but it makes error reporting harder).
964fn def_id_for_spirv_type_adt(layout: TyAndLayout<'_>) -> Option<DefId> {
965    match *layout.ty.kind() {
966        TyKind::Adt(def, _) => Some(def.did()),
967        TyKind::Foreign(def_id) | TyKind::Closure(def_id, _) | TyKind::Coroutine(def_id, ..) => {
968            Some(def_id)
969        }
970        _ => None,
971    }
972}
973
/// Minimal and cheaply comparable/hashable subset of the information contained
/// in `TyLayout` that can be used to generate a name (assuming a nominal type).
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct TyLayoutNameKey<'tcx> {
    /// The Rust type this layout was computed for.
    ty: Ty<'tcx>,
    /// Set only for `Variants::Single` layouts (see the `From` impl below),
    /// so `Display` can append the variant name to the generated type name.
    variant: Option<VariantIdx>,
}
981
982impl<'tcx> From<TyAndLayout<'tcx>> for TyLayoutNameKey<'tcx> {
983    fn from(layout: TyAndLayout<'tcx>) -> Self {
984        TyLayoutNameKey {
985            ty: layout.ty,
986            variant: match layout.variants {
987                Variants::Single { index } => Some(index),
988                _ => None,
989            },
990        }
991    }
992}
993
994impl fmt::Display for TyLayoutNameKey<'_> {
995    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
996        write!(f, "{}", self.ty)?;
997        if let (TyKind::Adt(def, _), Some(index)) = (self.ty.kind(), self.variant)
998            && def.is_enum()
999            && !def.variants().is_empty()
1000        {
1001            write!(f, "::{}", def.variants()[index].name)?;
1002        }
1003        if let (TyKind::Coroutine(_, _), Some(index)) = (self.ty.kind(), self.variant) {
1004            write!(f, "::{}", CoroutineArgs::variant_name(index))?;
1005        }
1006        Ok(())
1007    }
1008}
1009
1010fn trans_intrinsic_type<'tcx>(
1011    cx: &CodegenCx<'tcx>,
1012    span: Span,
1013    ty: TyAndLayout<'tcx>,
1014    args: GenericArgsRef<'tcx>,
1015    intrinsic_type_attr: IntrinsicType,
1016) -> Result<Word, ErrorGuaranteed> {
1017    match intrinsic_type_attr {
1018        IntrinsicType::GenericImageType => {
1019            // see SpirvType::sizeof
1020            if ty.size != Size::from_bytes(4) {
1021                return Err(cx
1022                    .tcx
1023                    .dcx()
1024                    .err("#[spirv(generic_image)] type must have size 4"));
1025            }
1026
1027            // fn type_from_variant_discriminant<'tcx, P: FromPrimitive>(
1028            //     cx: &CodegenCx<'tcx>,
1029            //     const_: Const<'tcx>,
1030            // ) -> P {
1031            //     let adt_def = const_.ty.ty_adt_def().unwrap();
1032            //     assert!(adt_def.is_enum());
1033            //     let destructured = cx.tcx.destructure_const(TypingEnv::fully_monomorphized().and(const_));
1034            //     let idx = destructured.variant.unwrap();
1035            //     let value = const_.ty.discriminant_for_variant(cx.tcx, idx).unwrap().val as u64;
1036            //     <_>::from_u64(value).unwrap()
1037            // }
1038
1039            let sampled_type = match args.type_at(0).kind() {
1040                TyKind::Int(int) => match int {
1041                    IntTy::Isize => {
1042                        SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, true)
1043                            .def(span, cx)
1044                    }
1045                    IntTy::I8 => SpirvType::Integer(8, true).def(span, cx),
1046                    IntTy::I16 => SpirvType::Integer(16, true).def(span, cx),
1047                    IntTy::I32 => SpirvType::Integer(32, true).def(span, cx),
1048                    IntTy::I64 => SpirvType::Integer(64, true).def(span, cx),
1049                    IntTy::I128 => SpirvType::Integer(128, true).def(span, cx),
1050                },
1051                TyKind::Uint(uint) => match uint {
1052                    UintTy::Usize => {
1053                        SpirvType::Integer(cx.tcx.data_layout.pointer_size.bits() as u32, false)
1054                            .def(span, cx)
1055                    }
1056                    UintTy::U8 => SpirvType::Integer(8, false).def(span, cx),
1057                    UintTy::U16 => SpirvType::Integer(16, false).def(span, cx),
1058                    UintTy::U32 => SpirvType::Integer(32, false).def(span, cx),
1059                    UintTy::U64 => SpirvType::Integer(64, false).def(span, cx),
1060                    UintTy::U128 => SpirvType::Integer(128, false).def(span, cx),
1061                },
1062                TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
1063                TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
1064                _ => {
1065                    return Err(cx
1066                        .tcx
1067                        .dcx()
1068                        .span_err(span, "Invalid sampled type to `Image`."));
1069                }
1070            };
1071
1072            // let dim: spirv::Dim = type_from_variant_discriminant(cx, args.const_at(1));
1073            // let depth: u32 = type_from_variant_discriminant(cx, args.const_at(2));
1074            // let arrayed: u32 = type_from_variant_discriminant(cx, args.const_at(3));
1075            // let multisampled: u32 = type_from_variant_discriminant(cx, args.const_at(4));
1076            // let sampled: u32 = type_from_variant_discriminant(cx, args.const_at(5));
1077            // let image_format: spirv::ImageFormat =
1078            //     type_from_variant_discriminant(cx, args.const_at(6));
1079
1080            trait FromScalarInt: Sized {
1081                fn from_scalar_int(n: ScalarInt) -> Option<Self>;
1082            }
1083
1084            impl FromScalarInt for u32 {
1085                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
1086                    Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
1087                }
1088            }
1089
1090            impl FromScalarInt for Dim {
1091                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
1092                    Dim::from_u32(u32::from_scalar_int(n)?)
1093                }
1094            }
1095
1096            impl FromScalarInt for ImageFormat {
1097                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
1098                    ImageFormat::from_u32(u32::from_scalar_int(n)?)
1099                }
1100            }
1101
1102            fn const_int_value<'tcx, P: FromScalarInt>(
1103                cx: &CodegenCx<'tcx>,
1104                const_: Const<'tcx>,
1105            ) -> Result<P, ErrorGuaranteed> {
1106                let ty::Value {
1107                    ty: const_ty,
1108                    valtree: const_val,
1109                } = const_.to_value();
1110                assert!(const_ty.is_integral());
1111                const_val
1112                    .try_to_scalar_int()
1113                    .and_then(P::from_scalar_int)
1114                    .ok_or_else(|| {
1115                        cx.tcx
1116                            .dcx()
1117                            .err(format!("invalid value for Image const generic: {const_}"))
1118                    })
1119            }
1120
1121            let dim = const_int_value(cx, args.const_at(1))?;
1122            let depth = const_int_value(cx, args.const_at(2))?;
1123            let arrayed = const_int_value(cx, args.const_at(3))?;
1124            let multisampled = const_int_value(cx, args.const_at(4))?;
1125            let sampled = const_int_value(cx, args.const_at(5))?;
1126            let image_format = const_int_value(cx, args.const_at(6))?;
1127
1128            let ty = SpirvType::Image {
1129                sampled_type,
1130                dim,
1131                depth,
1132                arrayed,
1133                multisampled,
1134                sampled,
1135                image_format,
1136            };
1137            Ok(ty.def(span, cx))
1138        }
1139        IntrinsicType::Sampler => {
1140            // see SpirvType::sizeof
1141            if ty.size != Size::from_bytes(4) {
1142                return Err(cx.tcx.dcx().err("#[spirv(sampler)] type must have size 4"));
1143            }
1144            Ok(SpirvType::Sampler.def(span, cx))
1145        }
1146        IntrinsicType::AccelerationStructureKhr => {
1147            Ok(SpirvType::AccelerationStructureKhr.def(span, cx))
1148        }
1149        IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)),
1150        IntrinsicType::SampledImage => {
1151            // see SpirvType::sizeof
1152            if ty.size != Size::from_bytes(4) {
1153                return Err(cx
1154                    .tcx
1155                    .dcx()
1156                    .err("#[spirv(sampled_image)] type must have size 4"));
1157            }
1158
1159            // We use a generic to indicate the underlying image type of the sampled image.
1160            // The spirv type of it will be generated by querying the type of the first generic.
1161            if let Some(image_ty) = args.types().next() {
1162                // TODO: enforce that the generic param is an image type?
1163                let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
1164                Ok(SpirvType::SampledImage { image_type }.def(span, cx))
1165            } else {
1166                Err(cx
1167                    .tcx
1168                    .dcx()
1169                    .err("#[spirv(sampled_image)] type must have a generic image type"))
1170            }
1171        }
1172        IntrinsicType::RuntimeArray => {
1173            if ty.size != Size::from_bytes(4) {
1174                return Err(cx
1175                    .tcx
1176                    .dcx()
1177                    .err("#[spirv(runtime_array)] type must have size 4"));
1178            }
1179
1180            // We use a generic param to indicate the underlying element type.
1181            // The SPIR-V element type will be generated from the first generic param.
1182            if let Some(elem_ty) = args.types().next() {
1183                Ok(SpirvType::RuntimeArray {
1184                    element: cx.layout_of(elem_ty).spirv_type(span, cx),
1185                }
1186                .def(span, cx))
1187            } else {
1188                Err(cx
1189                    .tcx
1190                    .dcx()
1191                    .err("#[spirv(runtime_array)] type must have a generic element type"))
1192            }
1193        }
1194        IntrinsicType::TypedBuffer => {
1195            if ty.size != Size::from_bytes(4) {
1196                return Err(cx
1197                    .tcx
1198                    .sess
1199                    .dcx()
1200                    .err("#[spirv(typed_buffer)] type must have size 4"));
1201            }
1202
1203            // We use a generic param to indicate the underlying data type.
1204            // The SPIR-V data type will be generated from the first generic param.
1205            if let Some(data_ty) = args.types().next() {
1206                // HACK(eddyb) this should be a *pointer* to an "interface block",
1207                // but SPIR-V screwed up and used no explicit indirection for the
1208                // descriptor indexing case, and instead made a `RuntimeArray` of
1209                // `InterfaceBlock`s be an "array of typed buffer resources".
1210                Ok(SpirvType::InterfaceBlock {
1211                    inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
1212                }
1213                .def(span, cx))
1214            } else {
1215                Err(cx
1216                    .tcx
1217                    .sess
1218                    .dcx()
1219                    .err("#[spirv(typed_buffer)] type must have a generic data type"))
1220            }
1221        }
1222        IntrinsicType::Matrix => {
1223            let span = def_id_for_spirv_type_adt(ty)
1224                .map(|did| cx.tcx.def_span(did))
1225                .expect("#[spirv(matrix)] must be added to a type which has DefId");
1226
1227            let field_types = (0..ty.fields.count())
1228                .map(|i| ty.field(cx, i).spirv_type(span, cx))
1229                .collect::<Vec<_>>();
1230            if field_types.len() < 2 {
1231                return Err(cx
1232                    .tcx
1233                    .dcx()
1234                    .span_err(span, "#[spirv(matrix)] type must have at least two fields"));
1235            }
1236            let elem_type = field_types[0];
1237            if !field_types.iter().all(|&ty| ty == elem_type) {
1238                return Err(cx.tcx.dcx().span_err(
1239                    span,
1240                    "#[spirv(matrix)] type fields must all be the same type",
1241                ));
1242            }
1243            match cx.lookup_type(elem_type) {
1244                SpirvType::Vector { .. } => (),
1245                ty => {
1246                    return Err(cx
1247                        .tcx
1248                        .dcx()
1249                        .struct_span_err(span, "#[spirv(matrix)] type fields must all be vectors")
1250                        .with_note(format!("field type is {}", ty.debug(elem_type, cx)))
1251                        .emit());
1252                }
1253            }
1254
1255            Ok(SpirvType::Matrix {
1256                element: elem_type,
1257                count: field_types.len() as u32,
1258            }
1259            .def(span, cx))
1260        }
1261    }
1262}