1use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
3
4use super::CodegenCx;
5use crate::abi::ConvSpirvType;
6use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind};
7use crate::spirv_type::SpirvType;
8use itertools::Itertools as _;
9use rspirv::spirv::Word;
10use rustc_abi::{self as abi, AddressSpace, Float, HasDataLayout, Integer, Primitive, Size};
11use rustc_codegen_ssa::traits::{ConstCodegenMethods, MiscCodegenMethods, StaticCodegenMethods};
12use rustc_middle::mir::interpret::{AllocError, ConstAllocation, GlobalAlloc, Scalar, alloc_range};
13use rustc_middle::ty::layout::LayoutOf;
14use rustc_span::{DUMMY_SP, Span};
15
16impl<'tcx> CodegenCx<'tcx> {
17 pub fn def_constant(&self, ty: Word, val: SpirvConst<'_, 'tcx>) -> SpirvValue {
18 self.builder.def_constant_cx(ty, val, self)
19 }
20
21 pub fn constant_u8(&self, span: Span, val: u8) -> SpirvValue {
22 self.constant_int_from_native_unsigned(span, val)
23 }
24
25 pub fn constant_i8(&self, span: Span, val: i8) -> SpirvValue {
26 self.constant_int_from_native_signed(span, val)
27 }
28
29 pub fn constant_i16(&self, span: Span, val: i16) -> SpirvValue {
30 self.constant_int_from_native_signed(span, val)
31 }
32
33 pub fn constant_u16(&self, span: Span, val: u16) -> SpirvValue {
34 self.constant_int_from_native_unsigned(span, val)
35 }
36
37 pub fn constant_i32(&self, span: Span, val: i32) -> SpirvValue {
38 self.constant_int_from_native_signed(span, val)
39 }
40
41 pub fn constant_u32(&self, span: Span, val: u32) -> SpirvValue {
42 self.constant_int_from_native_unsigned(span, val)
43 }
44
45 pub fn constant_i64(&self, span: Span, val: i64) -> SpirvValue {
46 self.constant_int_from_native_signed(span, val)
47 }
48
49 pub fn constant_u64(&self, span: Span, val: u64) -> SpirvValue {
50 self.constant_int_from_native_unsigned(span, val)
51 }
52
53 pub fn constant_u128(&self, span: Span, val: u128) -> SpirvValue {
54 self.constant_int_from_native_unsigned(span, val)
55 }
56
57 fn constant_int_from_native_unsigned(&self, span: Span, val: impl Into<u128>) -> SpirvValue {
58 let size = Size::from_bytes(std::mem::size_of_val(&val));
59 let ty = SpirvType::Integer(size.bits() as u32, false).def(span, self);
60 self.constant_int(ty, val.into())
61 }
62
63 fn constant_int_from_native_signed(&self, span: Span, val: impl Into<i128>) -> SpirvValue {
64 let size = Size::from_bytes(std::mem::size_of_val(&val));
65 let ty = SpirvType::Integer(size.bits() as u32, true).def(span, self);
66 self.constant_int(ty, val.into() as u128)
67 }
68
69 pub fn constant_int(&self, ty: Word, val: u128) -> SpirvValue {
70 self.def_constant(ty, SpirvConst::Scalar(val))
71 }
72
73 pub fn constant_f32(&self, span: Span, val: f32) -> SpirvValue {
74 let ty = SpirvType::Float(32).def(span, self);
75 self.def_constant(ty, SpirvConst::Scalar(val.to_bits().into()))
76 }
77
78 pub fn constant_f64(&self, span: Span, val: f64) -> SpirvValue {
79 let ty = SpirvType::Float(64).def(span, self);
80 self.def_constant(ty, SpirvConst::Scalar(val.to_bits().into()))
81 }
82
83 pub fn constant_float(&self, ty: Word, val: f64) -> SpirvValue {
84 match self.lookup_type(ty) {
85 SpirvType::Float(32) => {
87 self.def_constant(ty, SpirvConst::Scalar((val as f32).to_bits().into()))
88 }
89 SpirvType::Float(64) => self.def_constant(ty, SpirvConst::Scalar(val.to_bits().into())),
90 other => self.tcx.dcx().fatal(format!(
91 "constant_float does not support type {}",
92 other.debug(ty, self)
93 )),
94 }
95 }
96
97 pub fn constant_bool(&self, span: Span, val: bool) -> SpirvValue {
98 let ty = SpirvType::Bool.def(span, self);
99 self.def_constant(ty, SpirvConst::Scalar(val as u128))
100 }
101
102 pub fn constant_composite(&self, ty: Word, fields: impl Iterator<Item = Word>) -> SpirvValue {
103 self.def_constant(ty, SpirvConst::Composite(&fields.collect::<Vec<_>>()))
105 }
106
107 pub fn constant_null(&self, ty: Word) -> SpirvValue {
108 self.def_constant(ty, SpirvConst::Null)
109 }
110
111 pub fn undef(&self, ty: Word) -> SpirvValue {
112 self.def_constant(ty, SpirvConst::Undef)
113 }
114}
115
116impl ConstCodegenMethods for CodegenCx<'_> {
117 fn const_null(&self, t: Self::Type) -> Self::Value {
118 self.constant_null(t)
119 }
120 fn const_undef(&self, ty: Self::Type) -> Self::Value {
121 self.undef(ty)
122 }
123 fn const_poison(&self, ty: Self::Type) -> Self::Value {
124 self.const_undef(ty)
126 }
127 fn const_int(&self, t: Self::Type, i: i64) -> Self::Value {
128 self.constant_int(t, i as u128)
129 }
130 fn const_uint(&self, t: Self::Type, i: u64) -> Self::Value {
131 self.constant_int(t, i.into())
132 }
133 fn const_uint_big(&self, t: Self::Type, i: u128) -> Self::Value {
134 self.constant_int(t, i)
135 }
136 fn const_bool(&self, val: bool) -> Self::Value {
137 self.constant_bool(DUMMY_SP, val)
138 }
139 fn const_i8(&self, i: i8) -> Self::Value {
140 self.constant_i8(DUMMY_SP, i)
141 }
142 fn const_i16(&self, i: i16) -> Self::Value {
143 self.constant_i16(DUMMY_SP, i)
144 }
145 fn const_i32(&self, i: i32) -> Self::Value {
146 self.constant_i32(DUMMY_SP, i)
147 }
148 fn const_i64(&self, i: i64) -> Self::Value {
149 self.constant_i64(DUMMY_SP, i)
150 }
151 fn const_u8(&self, i: u8) -> Self::Value {
152 self.constant_u8(DUMMY_SP, i)
153 }
154 fn const_u32(&self, i: u32) -> Self::Value {
155 self.constant_u32(DUMMY_SP, i)
156 }
157 fn const_u64(&self, i: u64) -> Self::Value {
158 self.constant_u64(DUMMY_SP, i)
159 }
160 fn const_u128(&self, i: u128) -> Self::Value {
161 let ty = SpirvType::Integer(128, false).def(DUMMY_SP, self);
162 self.const_uint_big(ty, i)
163 }
164 fn const_usize(&self, i: u64) -> Self::Value {
165 let ptr_size = self.tcx.data_layout.pointer_size().bits() as u32;
166 let t = SpirvType::Integer(ptr_size, false).def(DUMMY_SP, self);
167 self.constant_int(t, i.into())
168 }
169 fn const_real(&self, t: Self::Type, val: f64) -> Self::Value {
170 self.constant_float(t, val)
171 }
172
173 fn const_str(&self, s: &str) -> (Self::Value, Self::Value) {
174 let len = s.len();
175 let str_ty = self
176 .layout_of(self.tcx.types.str_)
177 .spirv_type(DUMMY_SP, self);
178 (
179 self.def_constant(
180 self.type_ptr_to(str_ty),
181 SpirvConst::PtrTo {
182 pointee: self
183 .constant_composite(
184 str_ty,
185 s.bytes().map(|b| self.const_u8(b).def_cx(self)),
186 )
187 .def_cx(self),
188 },
189 ),
190 self.const_usize(len as u64),
191 )
192 }
193 fn const_struct(&self, elts: &[Self::Value], _packed: bool) -> Self::Value {
194 let field_types = elts.iter().map(|f| f.ty).collect::<Vec<_>>();
197 let (field_offsets, size, align) = crate::abi::auto_struct_layout(self, &field_types);
198 let struct_ty = SpirvType::Adt {
199 def_id: None,
200 size,
201 align,
202 field_types: &field_types,
203 field_offsets: &field_offsets,
204 field_names: None,
205 }
206 .def(DUMMY_SP, self);
207 self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)))
208 }
209 fn const_vector(&self, elts: &[Self::Value]) -> Self::Value {
210 let vector_ty = SpirvType::simd_vector(
211 self,
212 DUMMY_SP,
213 self.lookup_type(elts[0].ty),
214 elts.len() as u32,
215 )
216 .def(DUMMY_SP, self);
217 self.constant_composite(vector_ty, elts.iter().map(|elt| elt.def_cx(self)))
218 }
219
220 fn const_to_opt_uint(&self, v: Self::Value) -> Option<u64> {
221 self.builder.lookup_const_scalar(v)?.try_into().ok()
222 }
223 fn const_to_opt_u128(&self, v: Self::Value, _sign_ext: bool) -> Option<u128> {
226 self.builder.lookup_const_scalar(v)
227 }
228
229 fn scalar_to_backend(
230 &self,
231 scalar: Scalar,
232 layout: abi::Scalar,
233 ty: Self::Type,
234 ) -> Self::Value {
235 match scalar {
236 Scalar::Int(int) => {
237 assert_eq!(int.size(), layout.primitive().size(self));
238 let data = int.to_uint(int.size());
239
240 if let Primitive::Pointer(_) = layout.primitive() {
241 if data == 0 {
242 self.constant_null(ty)
243 } else {
244 let result = self.undef(ty);
245 self.zombie_no_span(
246 result.def_cx(self),
247 "pointer has non-null integer address",
248 );
249 result
250 }
251 } else {
252 self.def_constant(ty, SpirvConst::Scalar(data))
253 }
254 }
255 Scalar::Ptr(ptr, _) => {
256 let (prov, offset) = ptr.prov_and_relative_offset();
257 let alloc_id = prov.alloc_id();
258 let (base_addr, _base_addr_space) = match self.tcx.global_alloc(alloc_id) {
259 GlobalAlloc::Memory(alloc) => {
260 match self.lookup_type(ty) {
261 SpirvType::Pointer { .. } => {}
262 other => self.tcx.dcx().fatal(format!(
263 "GlobalAlloc::Memory type not implemented: {}",
264 other.debug(ty, self)
265 )),
266 }
267 let value = self.static_addr_of(alloc, None);
268 (value, AddressSpace::ZERO)
269 }
270 GlobalAlloc::Function { instance } => (
271 self.get_fn_addr(instance),
272 self.data_layout().instruction_address_space,
273 ),
274 GlobalAlloc::VTable(vty, dyn_ty) => {
275 let alloc = self
276 .tcx
277 .global_alloc(self.tcx.vtable_allocation((
278 vty,
279 dyn_ty.principal().map(|principal| {
280 self.tcx.instantiate_bound_regions_with_erased(principal)
281 }),
282 )))
283 .unwrap_memory();
284 match self.lookup_type(ty) {
285 SpirvType::Pointer { .. } => {}
286 other => self.tcx.dcx().fatal(format!(
287 "GlobalAlloc::VTable type not implemented: {}",
288 other.debug(ty, self)
289 )),
290 }
291 let value = self.static_addr_of(alloc, None);
292 (value, AddressSpace::ZERO)
293 }
294 GlobalAlloc::Static(def_id) => {
295 assert!(self.tcx.is_static(def_id));
296 assert!(!self.tcx.is_thread_local_static(def_id));
297 (self.get_static(def_id), AddressSpace::ZERO)
298 }
299 GlobalAlloc::TypeId { .. } => {
300 return if offset.bytes() == 0 {
301 self.constant_null(ty)
302 } else {
303 let result = self.undef(ty);
304 self.zombie_no_span(
305 result.def_cx(self),
306 "pointer has non-null integer address",
307 );
308 result
309 };
310 }
311 };
312 self.const_bitcast(self.const_ptr_byte_offset(base_addr, offset), ty)
313 }
314 }
315 }
316
317 fn const_ptr_byte_offset(&self, val: Self::Value, offset: Size) -> Self::Value {
318 if offset == Size::ZERO {
319 val
320 } else {
321 let result = val;
324 self.zombie_no_span(result.def_cx(self), "const_ptr_byte_offset");
325 result
326 }
327 }
328}
329
330impl<'tcx> CodegenCx<'tcx> {
331 pub(crate) fn const_data_from_alloc(&self, alloc: ConstAllocation<'_>) -> SpirvValue {
336 let alloc = self.tcx.lift(alloc).unwrap();
340
341 let void_type = SpirvType::Void.def(DUMMY_SP, self);
342 self.def_constant(void_type, SpirvConst::ConstDataFromAlloc(alloc))
343 }
344
345 pub fn const_bitcast(&self, val: SpirvValue, ty: Word) -> SpirvValue {
346 if let SpirvValueKind::IllegalConst(_) = val.kind
349 && let Some(SpirvConst::PtrTo { pointee }) = self.builder.lookup_const(val)
350 && let Some(SpirvConst::ConstDataFromAlloc(alloc)) =
351 self.builder.lookup_const_by_id(pointee)
352 && let SpirvType::Pointer { pointee } = self.lookup_type(ty)
353 && let Some(init) = self.try_read_from_const_alloc(alloc, pointee)
354 {
355 return self.def_constant(
356 ty,
357 SpirvConst::PtrTo {
358 pointee: init.def_cx(self),
359 },
360 );
361 }
362
363 if val.ty == ty {
364 val
365 } else {
366 let result = val.def_cx(self).with_type(ty);
369 self.zombie_no_span(result.def_cx(self), "const_bitcast");
370 result
371 }
372 }
373
374 pub fn primitive_to_scalar(&self, value: Primitive) -> abi::Scalar {
377 let bits = value.size(self.data_layout()).bits();
378 assert!(bits <= 128);
379 abi::Scalar::Initialized {
380 value,
381 valid_range: abi::WrappingRange {
382 start: 0,
383 end: (!0 >> (128 - bits)),
384 },
385 }
386 }
387
388 pub fn try_read_from_const_alloc(
393 &self,
394 alloc: ConstAllocation<'tcx>,
395 ty: Word,
396 ) -> Option<SpirvValue> {
397 let (result, read_size) = self.read_from_const_alloc_at(alloc, ty, Size::ZERO);
398 (read_size == alloc.inner().size()).then_some(result)
399 }
400
401 #[tracing::instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), offset))]
406 fn read_from_const_alloc_at(
407 &self,
408 alloc: ConstAllocation<'tcx>,
409 ty: Word,
410 offset: Size,
411 ) -> (SpirvValue, Size) {
412 let ty_def = self.lookup_type(ty);
413 match ty_def {
414 SpirvType::Bool
415 | SpirvType::Integer(..)
416 | SpirvType::Float(_)
417 | SpirvType::Pointer { .. } => {
418 let size = ty_def.sizeof(self).unwrap();
419 let primitive = match ty_def {
420 SpirvType::Bool => Primitive::Int(Integer::fit_unsigned(0), false),
421 SpirvType::Integer(int_size, int_signedness) => Primitive::Int(
422 match int_size {
423 8 => Integer::I8,
424 16 => Integer::I16,
425 32 => Integer::I32,
426 64 => Integer::I64,
427 128 => Integer::I128,
428 other => {
429 self.tcx
430 .dcx()
431 .fatal(format!("invalid size for integer: {other}"));
432 }
433 },
434 int_signedness,
435 ),
436 SpirvType::Float(float_size) => Primitive::Float(match float_size {
437 16 => Float::F16,
438 32 => Float::F32,
439 64 => Float::F64,
440 128 => Float::F128,
441 other => {
442 self.tcx
443 .dcx()
444 .fatal(format!("invalid size for float: {other}"));
445 }
446 }),
447 SpirvType::Pointer { .. } => Primitive::Pointer(AddressSpace::ZERO),
448 _ => unreachable!(),
449 };
450
451 let range = alloc_range(offset, size);
452 let read_provenance = matches!(primitive, Primitive::Pointer(_));
453
454 let mut primitive = primitive;
455 let mut read_result = alloc.inner().read_scalar(self, range, read_provenance);
456
457 if read_result.is_err()
461 && !read_provenance
462 && let read_ptr_result @ Ok(Scalar::Ptr(ptr, _)) = alloc
463 .inner()
464 .read_scalar(self, range, true)
465 {
466 let (prov, _offset) = ptr.prov_and_relative_offset();
467 primitive = Primitive::Pointer(
468 self.tcx.global_alloc(prov.alloc_id()).address_space(self),
469 );
470 read_result = read_ptr_result;
471 }
472
473 let scalar_or_zombie = match read_result {
474 Ok(scalar) => {
475 Ok(self.scalar_to_backend(scalar, self.primitive_to_scalar(primitive), ty))
476 }
477
478 Err(err) => match err {
481 AllocError::InvalidUninitBytes(_) => {
485 let uninit_range = alloc
486 .inner()
487 .init_mask()
488 .is_range_initialized(range)
489 .unwrap_err();
490 let uninit_size = {
491 let [start, end] = [uninit_range.start, uninit_range.end()]
492 .map(|x| x.clamp(range.start, range.end()));
493 end - start
494 };
495 if uninit_size == size {
496 Ok(self.undef(ty))
497 } else {
498 Err(format!(
499 "overlaps {} uninitialized bytes",
500 uninit_size.bytes()
501 ))
502 }
503 }
504 AllocError::ReadPointerAsInt(_) => Err("overlaps pointer bytes".into()),
505 AllocError::ReadPartialPointer(_) => {
506 Err("partially overlaps another pointer".into())
507 }
508
509 AllocError::ScalarSizeMismatch(_) => {
512 Err(format!("unrecognized `AllocError::{err:?}`"))
513 }
514 },
515 };
516 let result = scalar_or_zombie.unwrap_or_else(|reason| {
517 let result = self.undef(ty);
518 self.zombie_no_span(
519 result.def_cx(self),
520 &format!("unsupported `{}` constant: {reason}", self.debug_type(ty),),
521 );
522 result
523 });
524 (result, size)
525 }
526 SpirvType::Adt {
527 field_types,
528 field_offsets,
529 ..
530 } => {
531 let mut tail_read_range = ..Size::ZERO;
534 let result = self.constant_composite(
535 ty,
536 field_types
537 .iter()
538 .zip_eq(field_offsets.iter())
539 .map(|(&f_ty, &f_offset)| {
540 let (f, f_size) =
541 self.read_from_const_alloc_at(alloc, f_ty, offset + f_offset);
542 tail_read_range.end =
543 tail_read_range.end.max(offset + f_offset + f_size);
544 f.def_cx(self)
545 }),
546 );
547
548 let ty_size = ty_def.sizeof(self);
549
550 if let Some(ty_size) = ty_size
552 && let Some(tail_gap) = (ty_size.bytes())
553 .checked_sub(tail_read_range.end.align_to(ty_def.alignof(self)).bytes())
554 && tail_gap > 0
555 {
556 self.zombie_no_span(
557 result.def_cx(self),
558 &format!(
559 "undersized `{}` constant (at least {tail_gap} bytes may be missing)",
560 self.debug_type(ty)
561 ),
562 );
563 }
564
565 (result, ty_size.unwrap_or(tail_read_range.end))
566 }
567 SpirvType::Vector { element, .. }
568 | SpirvType::Matrix { element, .. }
569 | SpirvType::Array { element, .. }
570 | SpirvType::RuntimeArray { element } => {
571 let stride = self.lookup_type(element).sizeof(self).unwrap();
572
573 let count = match ty_def {
574 SpirvType::Vector { count, .. } | SpirvType::Matrix { count, .. } => {
575 u64::from(count)
576 }
577 SpirvType::Array { count, .. } => {
578 u64::try_from(self.builder.lookup_const_scalar(count).unwrap()).unwrap()
579 }
580 SpirvType::RuntimeArray { .. } => {
581 (alloc.inner().size() - offset).bytes() / stride.bytes()
582 }
583 _ => unreachable!(),
584 };
585
586 let result = self.constant_composite(
587 ty,
588 (0..count).map(|i| {
589 let (e, e_size) =
590 self.read_from_const_alloc_at(alloc, element, offset + i * stride);
591 assert_eq!(e_size, stride);
592 e.def_cx(self)
593 }),
594 );
595
596 let read_size = (count * stride).align_to(ty_def.alignof(self));
600
601 if let Some(ty_size) = ty_def.sizeof(self) {
602 assert_eq!(read_size, ty_size);
603 }
604
605 if let SpirvType::RuntimeArray { .. } = ty_def {
606 self.zombie_no_span(
611 result.def_cx(self),
612 &format!("unsupported unsized `{}` constant", self.debug_type(ty)),
613 );
614 }
615
616 (result, read_size)
617 }
618
619 SpirvType::Void
620 | SpirvType::Function { .. }
621 | SpirvType::Image { .. }
622 | SpirvType::Sampler
623 | SpirvType::SampledImage { .. }
624 | SpirvType::InterfaceBlock { .. }
625 | SpirvType::AccelerationStructureKhr
626 | SpirvType::RayQueryKhr => {
627 let result = self.undef(ty);
628 self.zombie_no_span(
629 result.def_cx(self),
630 &format!(
631 "cannot reinterpret Rust constant data as a `{}` value",
632 self.debug_type(ty)
633 ),
634 );
635 (result, ty_def.sizeof(self).unwrap_or(Size::ZERO))
636 }
637 }
638 }
639}