Skip to main content

rustc_codegen_spirv/
abi.rs

//! This file is responsible for translating rustc types (`TyAndLayout`) into SPIR-V types. It's
//! surprisingly difficult.
3
4use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
5use crate::codegen_cx::CodegenCx;
6use crate::spirv_type::SpirvType;
7use itertools::Itertools;
8use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
9use rustc_abi::ExternAbi as Abi;
10use rustc_abi::{
11    Align, BackendRepr, FieldIdx, FieldsShape, Primitive, Scalar, Size, VariantIdx, Variants,
12};
13use rustc_data_structures::fx::FxHashMap;
14use rustc_errors::ErrorGuaranteed;
15use rustc_index::Idx;
16use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
17use rustc_middle::ty::{
18    self, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty, TyCtxt,
19    TyKind, UintTy, ValTreeKindExt,
20};
21use rustc_middle::ty::{GenericArgsRef, ScalarInt};
22use rustc_middle::util::Providers;
23use rustc_middle::{bug, span_bug};
24use rustc_session::config::OptLevel;
25use rustc_span::DUMMY_SP;
26use rustc_span::def_id::DefId;
27use rustc_span::{Span, Symbol};
28use rustc_target::callconv::{ArgAbi, ArgAttributes, FnAbi, PassMode};
29use std::cell::RefCell;
30use std::collections::hash_map::Entry;
31use std::fmt;
32
33fn rewrite_c_abi_to_rust<'tcx>(
34    fn_sig: ty::EarlyBinder<'tcx, ty::PolyFnSig<'tcx>>,
35) -> ty::EarlyBinder<'tcx, ty::PolyFnSig<'tcx>> {
36    fn_sig.map_bound(|outer| {
37        outer.map_bound(|mut inner| {
38            if let Abi::C { .. } = inner.abi {
39                inner.abi = Abi::Rust;
40            }
41            inner
42        })
43    })
44}
45
/// Installs rust-gpu's overrides of several rustc query providers:
/// `fn_sig` (local and extern) to rewrite `extern "C"` to the Rust ABI, the
/// `fn_abi_of_*` queries to readjust `FnAbi`s for SPIR-V, and a no-op
/// `check_mono_item`.
pub(crate) fn provide(providers: &mut Providers) {
    // This is a bit unusual: we obviously don't support C ABIs at all. However, libcore does
    // declare some extern C functions:
    // https://github.com/rust-lang/rust/blob/5fae56971d8487088c0099c82c0a5ce1638b5f62/library/core/src/slice/cmp.rs#L119
    // However, those functions will be implemented by compiler-builtins:
    // https://github.com/rust-lang/rust/blob/5fae56971d8487088c0099c82c0a5ce1638b5f62/library/core/src/lib.rs#L23-L27
    // In theory it should then be fine to leave them as C, but there's no backend hook for
    // `FnAbi::adjust_for_cabi`, causing it to panic:
    // https://github.com/rust-lang/rust/blob/5fae56971d8487088c0099c82c0a5ce1638b5f62/compiler/rustc_target/src/abi/call/mod.rs#L603
    // So, treat any `extern "C"` functions as `extern "Rust"`, to be able to
    // compile libcore with arch=spirv.
    //
    // NOTE: this used to rewrite to `extern "unadjusted"`, but rustc now
    // validates `#[rustc_pass_indirectly_in_non_rustic_abis]` for non-Rust ABIs,
    // and `Unadjusted` does not satisfy that requirement.
    providers.queries.fn_sig = |tcx, def_id| {
        // We can't capture the old fn_sig and just call that, because fn_sig is a `fn`, not a `Fn`, i.e. it can't
        // capture variables. Fortunately, the defaults are exposed (thanks rustdoc), so use that instead.
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS.queries.fn_sig)(tcx, def_id);
        rewrite_c_abi_to_rust(result)
    };
    providers.extern_queries.fn_sig = |tcx, def_id| {
        // Same as above, but for items from other crates: delegate to the default
        // extern provider, then rewrite the ABI.
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .extern_queries
            .fn_sig)(tcx, def_id);
        rewrite_c_abi_to_rust(result)
    };

    // For the Rust ABI, `FnAbi` adjustments are backend-agnostic, but they will
    // use features like `PassMode::Cast`, that are incompatible with SPIR-V.
    // By hooking the queries computing `FnAbi`s, we can recompute the `FnAbi`
    // from the return/args layouts, to e.g. prefer using `PassMode::Direct`.
    /// Rebuilds `fn_abi` from only the layouts of its return value and arguments:
    /// each `ArgAbi` is recreated with empty attributes and made direct, ADT pairs
    /// become `Direct`, and ZSTs become `Ignore`. The remaining `FnAbi` fields are
    /// copied unchanged. The result is arena-allocated in `tcx`.
    fn readjust_fn_abi<'tcx>(
        tcx: TyCtxt<'tcx>,
        fn_abi: &'tcx FnAbi<'tcx, Ty<'tcx>>,
    ) -> &'tcx FnAbi<'tcx, Ty<'tcx>> {
        let readjust_arg_abi = |arg: &ArgAbi<'tcx, Ty<'tcx>>| {
            // Start over from the layout alone, discarding the original attributes/mode.
            let mut arg = ArgAbi::new(&tcx, arg.layout, |_, _| ArgAttributes::new());
            // FIXME: this is bad! https://github.com/rust-lang/rust/issues/115666
            // <https://github.com/rust-lang/rust/commit/eaaa03faf77b157907894a4207d8378ecaec7b45>
            arg.make_direct_deprecated();

            // FIXME(eddyb) detect `#[rust_gpu::vector::v1]` more specifically,
            // to avoid affecting anything should actually be passed as a pair.
            if let PassMode::Pair(..) = arg.mode {
                // HACK(eddyb) this avoids breaking e.g. `&[T]` pairs.
                if let TyKind::Adt(..) = arg.layout.ty.kind() {
                    arg.mode = PassMode::Direct(ArgAttributes::new());
                }
            }

            // Avoid pointlessly passing ZSTs, just like the official Rust ABI.
            if arg.layout.is_zst() {
                arg.mode = PassMode::Ignore;
            }

            arg
        };
        tcx.arena.alloc(FnAbi {
            args: fn_abi.args.iter().map(readjust_arg_abi).collect(),
            ret: readjust_arg_abi(&fn_abi.ret),

            // FIXME(eddyb) validate some of these, and report errors - however,
            // we can't just emit errors from here, since we have no `Span`, so
            // we should have instead a check on MIR for e.g. C variadic calls.
            c_variadic: fn_abi.c_variadic,
            fixed_count: fn_abi.fixed_count,
            conv: fn_abi.conv,
            can_unwind: fn_abi.can_unwind,
        })
    }
    providers.queries.fn_abi_of_fn_ptr = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .queries
            .fn_abi_of_fn_ptr)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };
    providers.queries.fn_abi_of_instance_no_deduced_attrs = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .queries
            .fn_abi_of_instance_no_deduced_attrs)(tcx, key);
        // Keep this query in its original shape while `fn_abi_of_instance_raw`
        // is being computed: rustc validates strict invariants there.
        // Otherwise, if `fn_abi_of_instance` would route through this query
        // directly (e.g. incremental or opt-level=0), apply SPIR-V readjustment.
        // NOTE(review): this condition mirrors rustc-internal query routing;
        // re-check it whenever the pinned toolchain changes.
        if tcx.sess.opts.optimize != OptLevel::No && tcx.sess.opts.incremental.is_none() {
            result
        } else {
            Ok(readjust_fn_abi(tcx, result?))
        }
    };
    providers.queries.fn_abi_of_instance_raw = |tcx, key| {
        let result = (rustc_interface::DEFAULT_QUERY_PROVIDERS
            .queries
            .fn_abi_of_instance_raw)(tcx, key);
        Ok(readjust_fn_abi(tcx, result?))
    };

    // HACK(eddyb) work around https://github.com/rust-lang/rust/pull/132173
    // (and further changes from https://github.com/rust-lang/rust/pull/132843)
    // starting to ban SIMD ABI misuse (or at least starting to warn about it).
    //
    // FIXME(eddyb) same as the FIXME comment on `check_well_formed`:
    // need to migrate away from `#[repr(simd)]` ASAP.
    providers.queries.check_mono_item = |_, _| {};
}
154
/// If a struct contains a pointer to itself, even indirectly, then doing a naive recursive walk
/// of the fields will result in an infinite loop. Because pointers are the only thing that are
/// allowed to be recursive, keep track of what pointers we've translated, or are currently in the
/// process of translating, and break the recursion that way. This struct manages that state
/// tracking.
#[derive(Default)]
pub struct RecursivePointeeCache<'tcx> {
    /// Translation state of every pointee seen so far, keyed by the pointee
    /// (updated by `begin`/`end`).
    map: RefCell<FxHashMap<PointeeTy<'tcx>, PointeeDefState>>,
}
164
165impl<'tcx> RecursivePointeeCache<'tcx> {
166    fn begin(&self, cx: &CodegenCx<'tcx>, span: Span, pointee: PointeeTy<'tcx>) -> Option<Word> {
167        match self.map.borrow_mut().entry(pointee) {
168            // State: This is the first time we've seen this type. Record that we're beginning to translate this type,
169            // and start doing the translation.
170            Entry::Vacant(entry) => {
171                entry.insert(PointeeDefState::Defining);
172                None
173            }
174            Entry::Occupied(mut entry) => match *entry.get() {
175                // State: This is the second time we've seen this type, and we're already translating this type. If we
176                // were to try to translate the type now, we'd get a stack overflow, due to continually recursing. So,
177                // emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm)
178                PointeeDefState::Defining => {
179                    let new_id = cx.emit_global().id();
180                    // NOTE(eddyb) we emit `StorageClass::Generic` here, but later
181                    // the linker will specialize the entire SPIR-V module to use
182                    // storage classes inferred from `OpVariable`s.
183                    cx.emit_global()
184                        .type_forward_pointer(new_id, StorageClass::Generic);
185                    entry.insert(PointeeDefState::DefiningWithForward(new_id));
186                    cx.zombie_with_span(
187                        new_id,
188                        span,
189                        "cannot create self-referential types, even through pointers",
190                    );
191                    Some(new_id)
192                }
193                // State: This is the third or more time we've seen this type, and we've already emitted an
194                // OpTypeForwardPointer. Just use the ID we've already emitted. (Alternatively, we already defined this
195                // type, so just use that.)
196                PointeeDefState::DefiningWithForward(id) | PointeeDefState::Defined(id) => Some(id),
197            },
198        }
199    }
200
201    fn end(
202        &self,
203        cx: &CodegenCx<'tcx>,
204        span: Span,
205        pointee: PointeeTy<'tcx>,
206        pointee_spv: Word,
207    ) -> Word {
208        match self.map.borrow_mut().entry(pointee) {
209            // We should have hit begin() on this type already, which always inserts an entry.
210            Entry::Vacant(_) => {
211                span_bug!(span, "RecursivePointeeCache::end should always have entry")
212            }
213            Entry::Occupied(mut entry) => match *entry.get() {
214                // State: There have been no recursive references to this type while defining it, and so no
215                // OpTypeForwardPointer has been emitted. This is the most common case.
216                PointeeDefState::Defining => {
217                    let id = SpirvType::Pointer {
218                        pointee: pointee_spv,
219                    }
220                    .def(span, cx);
221                    entry.insert(PointeeDefState::Defined(id));
222                    id
223                }
224                // State: There was a recursive reference to this type, and so an OpTypeForwardPointer has been emitted.
225                // Make sure to use the same ID.
226                PointeeDefState::DefiningWithForward(id) => {
227                    entry.insert(PointeeDefState::Defined(id));
228                    SpirvType::Pointer {
229                        pointee: pointee_spv,
230                    }
231                    .def_with_id(cx, span, id)
232                }
233                PointeeDefState::Defined(_) => {
234                    span_bug!(span, "RecursivePointeeCache::end defined pointer twice")
235                }
236            },
237        }
238    }
239}
240
/// What a pointer scalar points at: a Rust type together with its layout (for
/// `&T` / `*T` pointers), or a function signature (for `fn` pointers).
///
/// Used as the key of [`RecursivePointeeCache`].
#[derive(Eq, PartialEq, Hash, Copy, Clone, Debug)]
enum PointeeTy<'tcx> {
    /// Pointee of a `&T` or `*T` pointer.
    Ty(TyAndLayout<'tcx>),
    /// Signature of a `fn` pointer.
    Fn(PolyFnSig<'tcx>),
}
246
247impl fmt::Display for PointeeTy<'_> {
248    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249        match self {
250            PointeeTy::Ty(ty) => write!(f, "{}", ty.ty),
251            PointeeTy::Fn(ty) => write!(f, "{ty}"),
252        }
253    }
254}
255
/// Translation progress of a single pointee tracked by [`RecursivePointeeCache`].
enum PointeeDefState {
    /// The pointee is being translated and no recursive reference to it has been seen yet.
    Defining,
    /// The pointee is being translated, and a recursive reference caused an
    /// `OpTypeForwardPointer` with this ID to be emitted.
    DefiningWithForward(Word),
    /// The pointer type to this pointee has been fully defined with this ID.
    Defined(Word),
}
261
/// Various type-like things can be converted to a SPIR-V type (normal types, function types,
/// etc.), and this trait provides a uniform way of translating them.
pub trait ConvSpirvType<'tcx> {
    /// Returns the ID of the SPIR-V type for `self`. `span` is passed along to
    /// the type definitions and any diagnostics emitted during translation.
    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
}
267
268impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
269    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
270        match *self {
271            PointeeTy::Ty(ty) => ty.spirv_type(span, cx),
272            PointeeTy::Fn(ty) => cx
273                .fn_abi_of_fn_ptr(ty, ty::List::empty())
274                .spirv_type(span, cx),
275        }
276    }
277}
278
279impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
280    fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
281        // FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
282        let mut argument_types = Vec::new();
283
284        let return_type = match self.ret.mode {
285            PassMode::Ignore => SpirvType::Void.def(span, cx),
286            PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
287            PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
288                span,
289                "query hooks should've made this `PassMode` impossible: {:#?}",
290                self.ret
291            ),
292        };
293
294        for arg in self.args.iter() {
295            let arg_type = match arg.mode {
296                PassMode::Ignore => continue,
297                PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
298                PassMode::Pair(_, _) => {
299                    argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
300                    argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
301                    continue;
302                }
303                PassMode::Cast { .. } | PassMode::Indirect { .. } => span_bug!(
304                    span,
305                    "query hooks should've made this `PassMode` impossible: {:#?}",
306                    arg
307                ),
308            };
309            argument_types.push(arg_type);
310        }
311
312        SpirvType::Function {
313            return_type,
314            arguments: &argument_types,
315        }
316        .def(span, cx)
317    }
318}
319
impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
    /// Translates a Rust type (with its layout) into a SPIR-V type.
    ///
    /// For ADTs, `#[rust_gpu::spirv(...)]` and `#[rust_gpu::vector::v1]` attributes
    /// are checked first. If an intrinsic-type attribute translates successfully,
    /// that result is returned. Otherwise translation is driven by
    /// `self.backend_repr`, and for `Memory` layouts by `self.fields` (via
    /// `trans_aggregate`).
    fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
        if let TyKind::Adt(adt, args) = *self.ty.kind() {
            // Prefer pointing diagnostics at the ADT definition when no better span was given.
            if span == DUMMY_SP {
                span = cx.tcx.def_span(adt.did());
            }

            let attrs = AggregatedSpirvAttributes::parse(
                cx,
                cx.tcx
                    .get_attrs_by_path(
                        adt.did(),
                        &[cx.sym.rust_gpu, cx.sym.spirv_attr_with_version],
                    )
                    .chain(cx.tcx.get_attrs_by_path(
                        adt.did(),
                        &[cx.sym.rust_gpu, cx.sym.vector, cx.sym.v1],
                    )),
            );

            // If `trans_intrinsic_type` fails (`Err`), fall through to the
            // regular layout-based translation below.
            if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value)
                && let Ok(spirv_type) =
                    trans_intrinsic_type(cx, span, *self, args, intrinsic_type_attr)
            {
                return spirv_type;
            }
        }

        // Note: ty.layout is orthogonal to ty.ty, e.g. `ManuallyDrop<Result<isize, isize>>` has abi
        // `ScalarPair`.
        // There's a few layers that we go through here. First we inspect layout.backend_repr, then if relevant, layout.fields, etc.
        match self.backend_repr {
            // Uninhabited types become an empty, zero-sized ADT regardless of their repr.
            _ if self.uninhabited => SpirvType::Adt {
                def_id: def_id_for_spirv_type_adt(*self),
                size: Some(Size::ZERO),
                align: Align::from_bytes(0).unwrap(),
                field_types: &[],
                field_offsets: &[],
                field_names: None,
            }
            .def_with_name(cx, span, TyLayoutNameKey::from(*self)),
            BackendRepr::Scalar(scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
            BackendRepr::ScalarPair(a, b) => {
                // NOTE(eddyb) unlike `BackendRepr::Scalar`'s simpler newtype-unpacking
                // behavior, `BackendRepr::ScalarPair` can be composed in two ways:
                // * two `BackendRepr::Scalar` fields (and any number of ZST fields),
                //   gets handled the same as a `struct { a, b }`, further below
                // * an `BackendRepr::ScalarPair` field (and any number of ZST fields),
                //   which requires more work to allow taking a reference to
                //   that field, and there are two potential approaches:
                //   1. wrapping that field's SPIR-V type in a single-field
                //      `OpTypeStruct` - this has the disadvantage that GEPs
                //      would have to inject an extra `0` field index, and other
                //      field-related operations would also need additional work
                //   2. reusing that field's SPIR-V type, instead of defining
                //      a new one, offering the `(a, b)` shape `rustc_codegen_ssa`
                //      expects, while letting noop pointercasts access the sole
                //      `BackendRepr::ScalarPair` field - this is the approach taken here
                let mut non_zst_fields = (0..self.fields.count())
                    .map(|i| (i, self.field(cx, i)))
                    .filter(|(_, field)| !field.is_zst());
                // `Some` only if there is exactly one non-ZST field.
                let sole_non_zst_field = match (non_zst_fields.next(), non_zst_fields.next()) {
                    (Some(field), None) => Some(field),
                    _ => None,
                };
                if let Some((i, field)) = sole_non_zst_field {
                    // Only unpack a newtype if the field and the newtype line up
                    // perfectly, in every way that could potentially affect ABI.
                    if self.fields.offset(i) == Size::ZERO
                        && field.size == self.size
                        && field.align.abi == self.align.abi
                        && field.backend_repr.eq_up_to_validity(&self.backend_repr)
                    {
                        return field.spirv_type(span, cx);
                    }
                }

                // Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
                // recursive pointer types.
                // The second scalar starts right after the first, rounded up to its own alignment.
                let a_offset = Size::ZERO;
                let b_offset = a.primitive().size(cx).align_to(b.primitive().align(cx).abi);
                let a = trans_scalar(cx, span, *self, a, a_offset);
                let b = trans_scalar(cx, span, *self, b, b_offset);
                let size = if self.is_unsized() {
                    None
                } else {
                    Some(self.size)
                };
                // FIXME(eddyb) use `ArrayVec` here.
                // Field names (ordered by offset) are only attached when the ADT has
                // exactly two fields in total. ZST fields count here too, so any ZST
                // field present means no names are attached.
                let mut field_names = Vec::new();
                if let TyKind::Adt(adt, _) = self.ty.kind()
                    && let Variants::Single { index } = self.variants
                {
                    for i in self.fields.index_by_increasing_offset() {
                        let field = &adt.variants()[index].fields[FieldIdx::new(i)];
                        field_names.push(field.name);
                    }
                }
                SpirvType::Adt {
                    def_id: def_id_for_spirv_type_adt(*self),
                    size,
                    align: self.align.abi,
                    field_types: &[a, b],
                    field_offsets: &[a_offset, b_offset],
                    field_names: if field_names.len() == 2 {
                        Some(&field_names)
                    } else {
                        None
                    },
                }
                .def_with_name(cx, span, TyLayoutNameKey::from(*self))
            }
            BackendRepr::SimdVector { element, count } => {
                let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
                SpirvType::Vector {
                    element: elem_spirv,
                    count: count as u32,
                    size: self.size,
                    align: self.align.abi,
                }
                .def(span, cx)
            }
            BackendRepr::SimdScalableVector { .. } => cx
                .tcx
                .dcx()
                .fatal("scalable vectors are not supported in SPIR-V backend"),
            BackendRepr::Memory { sized: _ } => trans_aggregate(cx, span, *self),
        }
    }
}
450
451/// Only pub for `LayoutTypeCodegenMethods::scalar_pair_element_backend_type`. Think about what you're
452/// doing before calling this.
453pub fn scalar_pair_element_backend_type<'tcx>(
454    cx: &CodegenCx<'tcx>,
455    span: Span,
456    ty: TyAndLayout<'tcx>,
457    index: usize,
458) -> Word {
459    let [a, b] = match ty.layout.backend_repr() {
460        BackendRepr::ScalarPair(a, b) => [a, b],
461        other => span_bug!(
462            span,
463            "scalar_pair_element_backend_type invalid abi: {:?}",
464            other
465        ),
466    };
467    let offset = match index {
468        0 => Size::ZERO,
469        1 => a.primitive().size(cx).align_to(b.primitive().align(cx).abi),
470        _ => unreachable!(),
471    };
472    trans_scalar(cx, span, ty, [a, b][index], offset)
473}
474
475/// A "scalar" is a basic building block: bools, ints, floats, pointers. (i.e. not something complex like a struct)
476/// A "scalar pair" is a bit of a strange concept: if there is a `fn f(x: (u32, u32))`, then what's preferred for
477/// performance is to compile that ABI to `f(x_1: u32, x_2: u32)`, i.e. splitting out the pair into their own arguments,
478/// and pretending that they're one unit. So, there's quite a bit of special handling around these scalar pairs to enable
479/// scenarios like that.
480/// I say it's "preferred", but spirv doesn't really care - only CPU ABIs really care here. However, following rustc's
481/// lead and doing what they want makes things go smoothly, so we'll implement it here too.
482fn trans_scalar<'tcx>(
483    cx: &CodegenCx<'tcx>,
484    span: Span,
485    ty: TyAndLayout<'tcx>,
486    scalar: Scalar,
487    offset: Size,
488) -> Word {
489    if scalar.is_bool() {
490        return SpirvType::Bool.def(span, cx);
491    }
492
493    match scalar.primitive() {
494        Primitive::Int(int_kind, signedness) => {
495            SpirvType::Integer(int_kind.size().bits() as u32, signedness).def(span, cx)
496        }
497        Primitive::Float(float_kind) => {
498            SpirvType::Float(float_kind.size().bits() as u32).def(span, cx)
499        }
500        Primitive::Pointer(_) => {
501            let pointee_ty = dig_scalar_pointee(cx, ty, offset);
502            // Pointers can be recursive. So, record what we're currently translating, and if we're already translating
503            // the same type, emit an OpTypeForwardPointer and use that ID.
504            if let Some(predefined_result) = cx
505                .type_cache
506                .recursive_pointee_cache
507                .begin(cx, span, pointee_ty)
508            {
509                predefined_result
510            } else {
511                let pointee = pointee_ty.spirv_type(span, cx);
512                cx.type_cache
513                    .recursive_pointee_cache
514                    .end(cx, span, pointee_ty, pointee)
515            }
516        }
517    }
518}
519
520// This is a really weird function, strap in...
521// So, rustc_codegen_ssa is designed around scalar pointers being opaque, you shouldn't know the type behind the
522// pointer. Unfortunately, that's impossible for us, we need to know the underlying pointee type for various reasons. In
523// some cases, this is pretty easy - if it's a TyKind::Ref, then the pointee will be the pointee of the ref (with
524// handling for wide pointers, etc.). Unfortunately, there's some pretty advanced processing going on in cx.layout_of:
525// for example, `ManuallyDrop<Result<ptr, ptr>>` has abi `ScalarPair`. This means that to figure out the pointee type,
526// we have to replicate the logic of cx.layout_of. Part of that is digging into types that are aggregates: for example,
527// ManuallyDrop<T> has a single field of type T. We "dig into" that field, and recurse, trying to find a base case that
528// we can handle, like TyKind::Ref.
529// If the above didn't make sense, please poke Ashley, it's probably easier to explain via conversation.
530fn dig_scalar_pointee<'tcx>(
531    cx: &CodegenCx<'tcx>,
532    layout: TyAndLayout<'tcx>,
533    offset: Size,
534) -> PointeeTy<'tcx> {
535    if let FieldsShape::Primitive = layout.fields {
536        assert_eq!(offset, Size::ZERO);
537        let pointee = match *layout.ty.kind() {
538            TyKind::Ref(_, pointee_ty, _) | TyKind::RawPtr(pointee_ty, _) => {
539                PointeeTy::Ty(cx.layout_of(pointee_ty))
540            }
541            TyKind::FnPtr(sig_tys, hdr) => PointeeTy::Fn(sig_tys.with(hdr)),
542            _ => bug!("Pointer is not `&T`, `*T` or `fn` pointer: {:#?}", layout),
543        };
544        return pointee;
545    }
546
547    let all_fields = (match &layout.variants {
548        Variants::Empty => 0..0,
549        Variants::Multiple { variants, .. } => 0..variants.len(),
550        Variants::Single { index } => {
551            let i = index.as_usize();
552            i..i + 1
553        }
554    })
555    .flat_map(|variant_idx| {
556        let variant = layout.for_variant(cx, VariantIdx::new(variant_idx));
557        (0..variant.fields.count()).map(move |field_idx| {
558            (
559                variant.field(cx, field_idx),
560                variant.fields.offset(field_idx),
561            )
562        })
563    });
564
565    let mut pointee = None;
566    for (field, field_offset) in all_fields {
567        if field.is_zst() {
568            continue;
569        }
570        if (field_offset..field_offset + field.size).contains(&offset) {
571            let new_pointee = dig_scalar_pointee(cx, field, offset - field_offset);
572            match pointee {
573                Some(old_pointee) if old_pointee != new_pointee => {
574                    cx.tcx.dcx().fatal(format!(
575                        "dig_scalar_pointee: unsupported Pointer with different \
576                         pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
577                    ));
578                }
579                _ => pointee = Some(new_pointee),
580            }
581        }
582    }
583    pointee.unwrap_or_else(|| {
584        bug!(
585            "field containing Pointer scalar at offset {:?} not found in {:#?}",
586            offset,
587            layout
588        )
589    })
590}
591
592// FIXME(eddyb) all `ty: TyAndLayout` variables should be `layout: TyAndLayout`,
593// the type is really more "Layout with Ty" (`.ty` field + `Deref`s to `Layout`).
594fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
595    fn create_zst<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
596        assert_eq!(ty.size, Size::ZERO);
597        SpirvType::Adt {
598            def_id: def_id_for_spirv_type_adt(ty),
599            size: Some(Size::ZERO),
600            align: ty.align.abi,
601            field_types: &[],
602            field_offsets: &[],
603            field_names: None,
604        }
605        .def_with_name(cx, span, TyLayoutNameKey::from(ty))
606    }
607    match ty.fields {
608        FieldsShape::Primitive => span_bug!(
609            span,
610            "trans_aggregate called for FieldsShape::Primitive layout {:#?}",
611            ty
612        ),
613        FieldsShape::Union(_) => {
614            assert!(!ty.is_unsized(), "{ty:#?}");
615
616            // Represent the `union` with its largest case, which should work
617            // for at least `MaybeUninit<T>` (which is between `T` and `()`),
618            // but also potentially some other ones as well.
619            // NOTE(eddyb) even if long-term this may become a byte array, that
620            // only works for "data types" and not "opaque handles" (images etc.).
621            let largest_case = (0..ty.fields.count())
622                .map(|i| (FieldIdx::from_usize(i), ty.field(cx, i)))
623                .max_by_key(|(_, case)| case.size);
624
625            if let Some((case_idx, case)) = largest_case {
626                if ty.align != case.align {
627                    // HACK(eddyb) mismatched alignment requires a wrapper `struct`.
628                    trans_struct_or_union(cx, span, ty, Some(case_idx))
629                } else {
630                    assert_eq!(ty.size, case.size);
631                    case.spirv_type(span, cx)
632                }
633            } else {
634                create_zst(cx, span, ty)
635            }
636        }
637        FieldsShape::Array { stride, count } => {
638            let element_type = ty.field(cx, 0).spirv_type(span, cx);
639            if ty.is_unsized() {
640                // There's a potential for this array to be sized, but the element to be unsized, e.g. `[[u8]; 5]`.
641                // However, I think rust disallows all these cases, so assert this here.
642                assert_eq!(count, 0);
643                SpirvType::RuntimeArray {
644                    element: element_type,
645                }
646                .def(span, cx)
647            } else if count == 0 {
648                // spir-v doesn't support zero-sized arrays
649                create_zst(cx, span, ty)
650            } else {
651                let count_const = cx.constant_u32(span, count as u32);
652                let element_spv = cx.lookup_type(element_type);
653                let stride_spv = element_spv
654                    .sizeof(cx)
655                    .expect("Unexpected unsized type in sized FieldsShape::Array")
656                    .align_to(element_spv.alignof(cx));
657                assert_eq!(stride_spv, stride);
658                SpirvType::Array {
659                    element: element_type,
660                    count: count_const,
661                }
662                .def(span, cx)
663            }
664        }
665        FieldsShape::Arbitrary { .. } => trans_struct_or_union(cx, span, ty, None),
666    }
667}
668
669// returns (field_offsets, size, align)
670pub fn auto_struct_layout(
671    cx: &CodegenCx<'_>,
672    field_types: &[Word],
673) -> (Vec<Size>, Option<Size>, Align) {
674    // FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
675    let mut field_offsets = Vec::with_capacity(field_types.len());
676    let mut offset = Some(Size::ZERO);
677    let mut max_align = Align::from_bytes(0).unwrap();
678    for &field_type in field_types {
679        let spirv_type = cx.lookup_type(field_type);
680        let field_size = spirv_type.sizeof(cx);
681        let field_align = spirv_type.alignof(cx);
682        let this_offset = offset
683            .expect("Unsized values can only be the last field in a struct")
684            .align_to(field_align);
685
686        field_offsets.push(this_offset);
687        if field_align > max_align {
688            max_align = field_align;
689        }
690        offset = field_size.map(|size| this_offset + size);
691    }
692    (field_offsets, offset, max_align)
693}
694
695// see struct_llfields in librustc_codegen_llvm for implementation hints
696fn trans_struct_or_union<'tcx>(
697    cx: &CodegenCx<'tcx>,
698    span: Span,
699    ty: TyAndLayout<'tcx>,
700    union_case: Option<FieldIdx>,
701) -> Word {
702    let size = if ty.is_unsized() { None } else { Some(ty.size) };
703    let align = ty.align.abi;
704    // FIXME(eddyb) use `AccumulateVec`s just like `rustc` itself does.
705    let mut field_types = Vec::new();
706    let mut field_offsets = Vec::new();
707    let mut field_names = Vec::new();
708    for i in ty.fields.index_by_increasing_offset() {
709        if let Some(expected_field_idx) = union_case
710            && i != expected_field_idx.as_usize()
711        {
712            continue;
713        }
714
715        let field_ty = ty.field(cx, i);
716        field_types.push(field_ty.spirv_type(span, cx));
717        let offset = ty.fields.offset(i);
718        field_offsets.push(offset);
719        if let Variants::Single { index } = ty.variants {
720            if let TyKind::Adt(adt, _) = ty.ty.kind() {
721                let field = &adt.variants()[index].fields[FieldIdx::new(i)];
722                field_names.push(field.name);
723            } else {
724                // FIXME(eddyb) this looks like something that should exist in rustc.
725                field_names.push(Symbol::intern(&format!("{i}")));
726            }
727        } else {
728            if let TyKind::Adt(_, _) = ty.ty.kind() {
729            } else {
730                span_bug!(span, "Variants::Multiple not TyKind::Adt");
731            }
732            if i == 0 {
733                field_names.push(cx.sym.discriminant);
734            } else {
735                cx.tcx.dcx().fatal("Variants::Multiple has multiple fields")
736            }
737        };
738    }
739    SpirvType::Adt {
740        def_id: def_id_for_spirv_type_adt(ty),
741        size,
742        align,
743        field_types: &field_types,
744        field_offsets: &field_offsets,
745        field_names: Some(&field_names),
746    }
747    .def_with_name(cx, span, TyLayoutNameKey::from(ty))
748}
749
750/// Grab a `DefId` from the type if possible to avoid too much deduplication,
751/// which could result in one SPIR-V `OpType*` having many names
752/// (not in itself an issue, but it makes error reporting harder).
753fn def_id_for_spirv_type_adt(layout: TyAndLayout<'_>) -> Option<DefId> {
754    match *layout.ty.kind() {
755        TyKind::Adt(def, _) => Some(def.did()),
756        TyKind::Foreign(def_id) | TyKind::Closure(def_id, _) | TyKind::Coroutine(def_id, ..) => {
757            Some(def_id)
758        }
759        _ => None,
760    }
761}
762
763fn span_for_spirv_type_adt(cx: &CodegenCx<'_>, layout: TyAndLayout<'_>) -> Option<Span> {
764    def_id_for_spirv_type_adt(layout).map(|did| cx.tcx.def_span(did))
765}
766
/// Minimal and cheaply comparable/hashable subset of the information contained
/// in `TyAndLayout` that can be used to generate a name (assuming a nominal type).
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct TyLayoutNameKey<'tcx> {
    /// The Rust type the layout was computed for.
    ty: Ty<'tcx>,
    /// `Some(index)` when the layout is `Variants::Single { index }`, `None` otherwise.
    /// The `Display` impl uses it to append the variant name for enums and coroutines.
    variant: Option<VariantIdx>,
}
774
775impl<'tcx> From<TyAndLayout<'tcx>> for TyLayoutNameKey<'tcx> {
776    fn from(layout: TyAndLayout<'tcx>) -> Self {
777        TyLayoutNameKey {
778            ty: layout.ty,
779            variant: match layout.variants {
780                Variants::Single { index } => Some(index),
781                _ => None,
782            },
783        }
784    }
785}
786
787impl fmt::Display for TyLayoutNameKey<'_> {
788    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
789        write!(f, "{}", self.ty)?;
790        if let (TyKind::Adt(def, _), Some(index)) = (self.ty.kind(), self.variant)
791            && def.is_enum()
792            && !def.variants().is_empty()
793        {
794            write!(f, "::{}", def.variants()[index].name)?;
795        }
796        if let (TyKind::Coroutine(_, _), Some(index)) = (self.ty.kind(), self.variant) {
797            write!(f, "::{}", CoroutineArgs::variant_name(index))?;
798        }
799        Ok(())
800    }
801}
802
803fn trans_intrinsic_type<'tcx>(
804    cx: &CodegenCx<'tcx>,
805    span: Span,
806    ty: TyAndLayout<'tcx>,
807    args: GenericArgsRef<'tcx>,
808    intrinsic_type_attr: IntrinsicType,
809) -> Result<Word, ErrorGuaranteed> {
810    match intrinsic_type_attr {
811        IntrinsicType::GenericImageType => {
812            // see SpirvType::sizeof
813            if ty.size != Size::from_bytes(4) {
814                return Err(cx
815                    .tcx
816                    .dcx()
817                    .err("#[spirv(generic_image)] type must have size 4"));
818            }
819
820            // fn type_from_variant_discriminant<'tcx, P: FromPrimitive>(
821            //     cx: &CodegenCx<'tcx>,
822            //     const_: Const<'tcx>,
823            // ) -> P {
824            //     let adt_def = const_.ty.ty_adt_def().unwrap();
825            //     assert!(adt_def.is_enum());
826            //     let destructured = cx.tcx.destructure_const(TypingEnv::fully_monomorphized().and(const_));
827            //     let idx = destructured.variant.unwrap();
828            //     let value = const_.ty.discriminant_for_variant(cx.tcx, idx).unwrap().val as u64;
829            //     <_>::from_u64(value).unwrap()
830            // }
831
832            let sampled_type = match args.type_at(0).kind() {
833                TyKind::Int(int) => match int {
834                    IntTy::Isize => {
835                        SpirvType::Integer(cx.tcx.data_layout.pointer_size().bits() as u32, true)
836                            .def(span, cx)
837                    }
838                    IntTy::I8 => SpirvType::Integer(8, true).def(span, cx),
839                    IntTy::I16 => SpirvType::Integer(16, true).def(span, cx),
840                    IntTy::I32 => SpirvType::Integer(32, true).def(span, cx),
841                    IntTy::I64 => SpirvType::Integer(64, true).def(span, cx),
842                    IntTy::I128 => SpirvType::Integer(128, true).def(span, cx),
843                },
844                TyKind::Uint(uint) => match uint {
845                    UintTy::Usize => {
846                        SpirvType::Integer(cx.tcx.data_layout.pointer_size().bits() as u32, false)
847                            .def(span, cx)
848                    }
849                    UintTy::U8 => SpirvType::Integer(8, false).def(span, cx),
850                    UintTy::U16 => SpirvType::Integer(16, false).def(span, cx),
851                    UintTy::U32 => SpirvType::Integer(32, false).def(span, cx),
852                    UintTy::U64 => SpirvType::Integer(64, false).def(span, cx),
853                    UintTy::U128 => SpirvType::Integer(128, false).def(span, cx),
854                },
855                TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
856                TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
857                _ => {
858                    return Err(cx
859                        .tcx
860                        .dcx()
861                        .span_err(span, "Invalid sampled type to `Image`."));
862                }
863            };
864
865            // let dim: spirv::Dim = type_from_variant_discriminant(cx, args.const_at(1));
866            // let depth: u32 = type_from_variant_discriminant(cx, args.const_at(2));
867            // let arrayed: u32 = type_from_variant_discriminant(cx, args.const_at(3));
868            // let multisampled: u32 = type_from_variant_discriminant(cx, args.const_at(4));
869            // let sampled: u32 = type_from_variant_discriminant(cx, args.const_at(5));
870            // let image_format: spirv::ImageFormat =
871            //     type_from_variant_discriminant(cx, args.const_at(6));
872
873            trait FromScalarInt: Sized {
874                fn from_scalar_int(n: ScalarInt) -> Option<Self>;
875            }
876
877            impl FromScalarInt for u32 {
878                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
879                    Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
880                }
881            }
882
883            impl FromScalarInt for Dim {
884                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
885                    Dim::from_u32(u32::from_scalar_int(n)?)
886                }
887            }
888
889            impl FromScalarInt for ImageFormat {
890                fn from_scalar_int(n: ScalarInt) -> Option<Self> {
891                    ImageFormat::from_u32(u32::from_scalar_int(n)?)
892                }
893            }
894
895            fn const_int_value<'tcx, P: FromScalarInt>(
896                cx: &CodegenCx<'tcx>,
897                const_: Const<'tcx>,
898            ) -> Result<P, ErrorGuaranteed> {
899                let ty::Value {
900                    ty: const_ty,
901                    valtree: const_val,
902                } = const_.to_value();
903                assert!(const_ty.is_integral());
904                const_val
905                    .try_to_scalar()
906                    .and_then(|scalar| scalar.try_to_scalar_int().ok())
907                    .and_then(P::from_scalar_int)
908                    .ok_or_else(|| {
909                        cx.tcx
910                            .dcx()
911                            .err(format!("invalid value for Image const generic: {const_}"))
912                    })
913            }
914
915            let dim = const_int_value(cx, args.const_at(1))?;
916            let depth = const_int_value(cx, args.const_at(2))?;
917            let arrayed = const_int_value(cx, args.const_at(3))?;
918            let multisampled = const_int_value(cx, args.const_at(4))?;
919            let sampled = const_int_value(cx, args.const_at(5))?;
920            let image_format = const_int_value(cx, args.const_at(6))?;
921
922            let ty = SpirvType::Image {
923                sampled_type,
924                dim,
925                depth,
926                arrayed,
927                multisampled,
928                sampled,
929                image_format,
930            };
931            Ok(ty.def(span, cx))
932        }
933        IntrinsicType::Sampler => {
934            // see SpirvType::sizeof
935            if ty.size != Size::from_bytes(4) {
936                return Err(cx.tcx.dcx().err("#[spirv(sampler)] type must have size 4"));
937            }
938            Ok(SpirvType::Sampler.def(span, cx))
939        }
940        IntrinsicType::AccelerationStructureKhr => {
941            Ok(SpirvType::AccelerationStructureKhr.def(span, cx))
942        }
943        IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)),
944        IntrinsicType::SampledImage => {
945            // see SpirvType::sizeof
946            if ty.size != Size::from_bytes(4) {
947                return Err(cx
948                    .tcx
949                    .dcx()
950                    .err("#[spirv(sampled_image)] type must have size 4"));
951            }
952
953            // We use a generic to indicate the underlying image type of the sampled image.
954            // The spirv type of it will be generated by querying the type of the first generic.
955            if let Some(image_ty) = args.types().next() {
956                // TODO: enforce that the generic param is an image type?
957                let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
958                Ok(SpirvType::SampledImage { image_type }.def(span, cx))
959            } else {
960                Err(cx
961                    .tcx
962                    .dcx()
963                    .err("#[spirv(sampled_image)] type must have a generic image type"))
964            }
965        }
966        IntrinsicType::RuntimeArray => {
967            if ty.size != Size::from_bytes(4) {
968                return Err(cx
969                    .tcx
970                    .dcx()
971                    .err("#[spirv(runtime_array)] type must have size 4"));
972            }
973
974            // We use a generic param to indicate the underlying element type.
975            // The SPIR-V element type will be generated from the first generic param.
976            if let Some(elem_ty) = args.types().next() {
977                Ok(SpirvType::RuntimeArray {
978                    element: cx.layout_of(elem_ty).spirv_type(span, cx),
979                }
980                .def(span, cx))
981            } else {
982                Err(cx
983                    .tcx
984                    .dcx()
985                    .err("#[spirv(runtime_array)] type must have a generic element type"))
986            }
987        }
988        IntrinsicType::TypedBuffer => {
989            if ty.size != Size::from_bytes(4) {
990                return Err(cx
991                    .tcx
992                    .sess
993                    .dcx()
994                    .err("#[spirv(typed_buffer)] type must have size 4"));
995            }
996
997            // We use a generic param to indicate the underlying data type.
998            // The SPIR-V data type will be generated from the first generic param.
999            if let Some(data_ty) = args.types().next() {
1000                // HACK(eddyb) this should be a *pointer* to an "interface block",
1001                // but SPIR-V screwed up and used no explicit indirection for the
1002                // descriptor indexing case, and instead made a `RuntimeArray` of
1003                // `InterfaceBlock`s be an "array of typed buffer resources".
1004                Ok(SpirvType::InterfaceBlock {
1005                    inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
1006                }
1007                .def(span, cx))
1008            } else {
1009                Err(cx
1010                    .tcx
1011                    .sess
1012                    .dcx()
1013                    .err("#[spirv(typed_buffer)] type must have a generic data type"))
1014            }
1015        }
1016        IntrinsicType::Matrix => {
1017            let span = span_for_spirv_type_adt(cx, ty).unwrap();
1018            let err_attr_name = "`#[spirv(matrix)]`";
1019            let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
1020            match cx.lookup_type(element) {
1021                SpirvType::Vector { .. } => (),
1022                ty => {
1023                    return Err(cx
1024                        .tcx
1025                        .dcx()
1026                        .struct_span_err(
1027                            span,
1028                            format!("{err_attr_name} type fields must all be vectors"),
1029                        )
1030                        .with_note(format!("field type is {}", ty.debug(element, cx)))
1031                        .emit());
1032                }
1033            }
1034            Ok(SpirvType::Matrix { element, count }.def(span, cx))
1035        }
1036        IntrinsicType::Vector => {
1037            let span = span_for_spirv_type_adt(cx, ty).unwrap();
1038            let err_attr_name = "`#[spirv(vector)]`";
1039            let (element, count) = trans_glam_like_struct(cx, span, ty, args, err_attr_name)?;
1040            match cx.lookup_type(element) {
1041                SpirvType::Bool | SpirvType::Float { .. } | SpirvType::Integer { .. } => (),
1042                ty => {
1043                    return Err(cx
1044                        .tcx
1045                        .dcx()
1046                        .struct_span_err(
1047                            span,
1048                            format!(
1049                                "{err_attr_name} type fields must all be floats, integers or bools"
1050                            ),
1051                        )
1052                        .with_note(format!("field type is {}", ty.debug(element, cx)))
1053                        .emit());
1054                }
1055            }
1056            Ok(SpirvType::Vector {
1057                element,
1058                count,
1059                size: ty.size,
1060                align: ty.align.abi,
1061            }
1062            .def(span, cx))
1063        }
1064    }
1065}
1066
1067/// A struct with multiple fields of the same kind.
1068/// Used for `#[spirv(vector)]` and `#[spirv(matrix)]`.
1069fn trans_glam_like_struct<'tcx>(
1070    cx: &CodegenCx<'tcx>,
1071    span: Span,
1072    ty: TyAndLayout<'tcx>,
1073    args: GenericArgsRef<'tcx>,
1074    err_attr_name: &str,
1075) -> Result<(Word, u32), ErrorGuaranteed> {
1076    let tcx = cx.tcx;
1077    if let Some(adt) = ty.ty.ty_adt_def()
1078        && adt.is_struct()
1079    {
1080        let (count, element) = adt
1081            .non_enum_variant()
1082            .fields
1083            .iter()
1084            .map(|f| f.ty(tcx, args))
1085            .dedup_with_count()
1086            .exactly_one()
1087            .map_err(|_e| {
1088                tcx.dcx().span_err(
1089                    span,
1090                    format!("{err_attr_name} member types must all be the same"),
1091                )
1092            })?;
1093
1094        let element = cx.layout_of(element);
1095        let element_word = element.spirv_type(span, cx);
1096        let count = u32::try_from(count)
1097            .ok()
1098            .filter(|count| 2 <= *count && *count <= 4)
1099            .ok_or_else(|| {
1100                tcx.dcx()
1101                    .span_err(span, format!("{err_attr_name} must have 2, 3 or 4 members"))
1102            })?;
1103
1104        for i in 0..ty.fields.count() {
1105            let expected = element.size.checked_mul(i as u64, cx).unwrap();
1106            let actual = ty.fields.offset(i);
1107            if actual != expected {
1108                let name: &str = adt
1109                    .non_enum_variant()
1110                    .fields
1111                    .get(FieldIdx::from(i))
1112                    .unwrap()
1113                    .name
1114                    .as_str();
1115                tcx.dcx().span_fatal(
1116                    span,
1117                    format!(
1118                        "Unexpected layout for {err_attr_name} annotated struct: \
1119                    Expected member `{name}` at offset {expected:?}, but was at {actual:?}"
1120                    ),
1121                )
1122            }
1123        }
1124
1125        Ok((element_word, count))
1126    } else {
1127        Err(tcx
1128            .dcx()
1129            .span_err(span, format!("{err_attr_name} type must be a struct")))
1130    }
1131}