1use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
3
4use super::CodegenCx;
5use crate::abi::ConvSpirvType;
6use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind};
7use crate::spirv_type::SpirvType;
8use itertools::Itertools as _;
9use rspirv::spirv::Word;
10use rustc_abi::{self as abi, AddressSpace, Float, HasDataLayout, Integer, Primitive, Size};
11use rustc_codegen_ssa::traits::{ConstCodegenMethods, MiscCodegenMethods, StaticCodegenMethods};
12use rustc_middle::mir::interpret::{AllocError, ConstAllocation, GlobalAlloc, Scalar, alloc_range};
13use rustc_middle::ty::layout::LayoutOf;
14use rustc_span::{DUMMY_SP, Span};
15
16impl<'tcx> CodegenCx<'tcx> {
17 pub fn def_constant(&self, ty: Word, val: SpirvConst<'_, 'tcx>) -> SpirvValue {
18 self.builder.def_constant_cx(ty, val, self)
19 }
20
21 pub fn constant_u8(&self, span: Span, val: u8) -> SpirvValue {
22 self.constant_int_from_native_unsigned(span, val)
23 }
24
25 pub fn constant_i8(&self, span: Span, val: i8) -> SpirvValue {
26 self.constant_int_from_native_signed(span, val)
27 }
28
29 pub fn constant_i16(&self, span: Span, val: i16) -> SpirvValue {
30 self.constant_int_from_native_signed(span, val)
31 }
32
33 pub fn constant_u16(&self, span: Span, val: u16) -> SpirvValue {
34 self.constant_int_from_native_unsigned(span, val)
35 }
36
37 pub fn constant_i32(&self, span: Span, val: i32) -> SpirvValue {
38 self.constant_int_from_native_signed(span, val)
39 }
40
41 pub fn constant_u32(&self, span: Span, val: u32) -> SpirvValue {
42 self.constant_int_from_native_unsigned(span, val)
43 }
44
45 pub fn constant_i64(&self, span: Span, val: i64) -> SpirvValue {
46 self.constant_int_from_native_signed(span, val)
47 }
48
49 pub fn constant_u64(&self, span: Span, val: u64) -> SpirvValue {
50 self.constant_int_from_native_unsigned(span, val)
51 }
52
53 fn constant_int_from_native_unsigned(&self, span: Span, val: impl Into<u128>) -> SpirvValue {
54 let size = Size::from_bytes(std::mem::size_of_val(&val));
55 let ty = SpirvType::Integer(size.bits() as u32, false).def(span, self);
56 self.constant_int(ty, val.into())
57 }
58
59 fn constant_int_from_native_signed(&self, span: Span, val: impl Into<i128>) -> SpirvValue {
60 let size = Size::from_bytes(std::mem::size_of_val(&val));
61 let ty = SpirvType::Integer(size.bits() as u32, true).def(span, self);
62 self.constant_int(ty, val.into() as u128)
63 }
64
65 pub fn constant_int(&self, ty: Word, val: u128) -> SpirvValue {
66 self.def_constant(ty, SpirvConst::Scalar(val))
67 }
68
69 pub fn constant_f32(&self, span: Span, val: f32) -> SpirvValue {
70 let ty = SpirvType::Float(32).def(span, self);
71 self.def_constant(ty, SpirvConst::Scalar(val.to_bits().into()))
72 }
73
74 pub fn constant_f64(&self, span: Span, val: f64) -> SpirvValue {
75 let ty = SpirvType::Float(64).def(span, self);
76 self.def_constant(ty, SpirvConst::Scalar(val.to_bits().into()))
77 }
78
79 pub fn constant_float(&self, ty: Word, val: f64) -> SpirvValue {
80 match self.lookup_type(ty) {
81 SpirvType::Float(32) => {
83 self.def_constant(ty, SpirvConst::Scalar((val as f32).to_bits().into()))
84 }
85 SpirvType::Float(64) => self.def_constant(ty, SpirvConst::Scalar(val.to_bits().into())),
86 other => self.tcx.dcx().fatal(format!(
87 "constant_float does not support type {}",
88 other.debug(ty, self)
89 )),
90 }
91 }
92
93 pub fn constant_bool(&self, span: Span, val: bool) -> SpirvValue {
94 let ty = SpirvType::Bool.def(span, self);
95 self.def_constant(ty, SpirvConst::Scalar(val as u128))
96 }
97
98 pub fn constant_composite(&self, ty: Word, fields: impl Iterator<Item = Word>) -> SpirvValue {
99 self.def_constant(ty, SpirvConst::Composite(&fields.collect::<Vec<_>>()))
101 }
102
103 pub fn constant_null(&self, ty: Word) -> SpirvValue {
104 self.def_constant(ty, SpirvConst::Null)
105 }
106
107 pub fn undef(&self, ty: Word) -> SpirvValue {
108 self.def_constant(ty, SpirvConst::Undef)
109 }
110}
111
112impl ConstCodegenMethods for CodegenCx<'_> {
113 fn const_null(&self, t: Self::Type) -> Self::Value {
114 self.constant_null(t)
115 }
116 fn const_undef(&self, ty: Self::Type) -> Self::Value {
117 self.undef(ty)
118 }
119 fn const_poison(&self, ty: Self::Type) -> Self::Value {
120 self.const_undef(ty)
122 }
123 fn const_int(&self, t: Self::Type, i: i64) -> Self::Value {
124 self.constant_int(t, i as u128)
125 }
126 fn const_uint(&self, t: Self::Type, i: u64) -> Self::Value {
127 self.constant_int(t, i.into())
128 }
129 fn const_uint_big(&self, t: Self::Type, i: u128) -> Self::Value {
130 self.constant_int(t, i)
131 }
132 fn const_bool(&self, val: bool) -> Self::Value {
133 self.constant_bool(DUMMY_SP, val)
134 }
135 fn const_i8(&self, i: i8) -> Self::Value {
136 self.constant_i8(DUMMY_SP, i)
137 }
138 fn const_i16(&self, i: i16) -> Self::Value {
139 self.constant_i16(DUMMY_SP, i)
140 }
141 fn const_i32(&self, i: i32) -> Self::Value {
142 self.constant_i32(DUMMY_SP, i)
143 }
144 fn const_u32(&self, i: u32) -> Self::Value {
145 self.constant_u32(DUMMY_SP, i)
146 }
147 fn const_u64(&self, i: u64) -> Self::Value {
148 self.constant_u64(DUMMY_SP, i)
149 }
150 fn const_u128(&self, i: u128) -> Self::Value {
151 let ty = SpirvType::Integer(128, false).def(DUMMY_SP, self);
152 self.const_uint_big(ty, i)
153 }
154 fn const_usize(&self, i: u64) -> Self::Value {
155 let ptr_size = self.tcx.data_layout.pointer_size().bits() as u32;
156 let t = SpirvType::Integer(ptr_size, false).def(DUMMY_SP, self);
157 self.constant_int(t, i.into())
158 }
159 fn const_u8(&self, i: u8) -> Self::Value {
160 self.constant_u8(DUMMY_SP, i)
161 }
162 fn const_real(&self, t: Self::Type, val: f64) -> Self::Value {
163 self.constant_float(t, val)
164 }
165
166 fn const_str(&self, s: &str) -> (Self::Value, Self::Value) {
167 let len = s.len();
168 let str_ty = self
169 .layout_of(self.tcx.types.str_)
170 .spirv_type(DUMMY_SP, self);
171 (
172 self.def_constant(
173 self.type_ptr_to(str_ty),
174 SpirvConst::PtrTo {
175 pointee: self
176 .constant_composite(
177 str_ty,
178 s.bytes().map(|b| self.const_u8(b).def_cx(self)),
179 )
180 .def_cx(self),
181 },
182 ),
183 self.const_usize(len as u64),
184 )
185 }
186 fn const_struct(&self, elts: &[Self::Value], _packed: bool) -> Self::Value {
187 let field_types = elts.iter().map(|f| f.ty).collect::<Vec<_>>();
190 let (field_offsets, size, align) = crate::abi::auto_struct_layout(self, &field_types);
191 let struct_ty = SpirvType::Adt {
192 def_id: None,
193 size,
194 align,
195 field_types: &field_types,
196 field_offsets: &field_offsets,
197 field_names: None,
198 }
199 .def(DUMMY_SP, self);
200 self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)))
201 }
202 fn const_vector(&self, elts: &[Self::Value]) -> Self::Value {
203 let vector_ty = SpirvType::simd_vector(
204 self,
205 DUMMY_SP,
206 self.lookup_type(elts[0].ty),
207 elts.len() as u32,
208 )
209 .def(DUMMY_SP, self);
210 self.constant_composite(vector_ty, elts.iter().map(|elt| elt.def_cx(self)))
211 }
212
213 fn const_to_opt_uint(&self, v: Self::Value) -> Option<u64> {
214 self.builder.lookup_const_scalar(v)?.try_into().ok()
215 }
216 fn const_to_opt_u128(&self, v: Self::Value, _sign_ext: bool) -> Option<u128> {
219 self.builder.lookup_const_scalar(v)
220 }
221
222 fn scalar_to_backend(
223 &self,
224 scalar: Scalar,
225 layout: abi::Scalar,
226 ty: Self::Type,
227 ) -> Self::Value {
228 match scalar {
229 Scalar::Int(int) => {
230 assert_eq!(int.size(), layout.primitive().size(self));
231 let data = int.to_uint(int.size());
232
233 if let Primitive::Pointer(_) = layout.primitive() {
234 if data == 0 {
235 self.constant_null(ty)
236 } else {
237 let result = self.undef(ty);
238 self.zombie_no_span(
239 result.def_cx(self),
240 "pointer has non-null integer address",
241 );
242 result
243 }
244 } else {
245 self.def_constant(ty, SpirvConst::Scalar(data))
246 }
247 }
248 Scalar::Ptr(ptr, _) => {
249 let (prov, offset) = ptr.prov_and_relative_offset();
250 let alloc_id = prov.alloc_id();
251 let (base_addr, _base_addr_space) = match self.tcx.global_alloc(alloc_id) {
252 GlobalAlloc::Memory(alloc) => {
253 let pointee = match self.lookup_type(ty) {
254 SpirvType::Pointer { pointee } => pointee,
255 other => self.tcx.dcx().fatal(format!(
256 "GlobalAlloc::Memory type not implemented: {}",
257 other.debug(ty, self)
258 )),
259 };
260 let init = self
263 .try_read_from_const_alloc(alloc, pointee)
264 .unwrap_or_else(|| self.const_data_from_alloc(alloc));
265 let value = self.static_addr_of(init, alloc.inner().align, None);
266 (value, AddressSpace::ZERO)
267 }
268 GlobalAlloc::Function { instance } => (
269 self.get_fn_addr(instance),
270 self.data_layout().instruction_address_space,
271 ),
272 GlobalAlloc::VTable(vty, dyn_ty) => {
273 let alloc = self
274 .tcx
275 .global_alloc(self.tcx.vtable_allocation((
276 vty,
277 dyn_ty.principal().map(|principal| {
278 self.tcx.instantiate_bound_regions_with_erased(principal)
279 }),
280 )))
281 .unwrap_memory();
282 let pointee = match self.lookup_type(ty) {
283 SpirvType::Pointer { pointee } => pointee,
284 other => self.tcx.dcx().fatal(format!(
285 "GlobalAlloc::VTable type not implemented: {}",
286 other.debug(ty, self)
287 )),
288 };
289 let init = self
292 .try_read_from_const_alloc(alloc, pointee)
293 .unwrap_or_else(|| self.const_data_from_alloc(alloc));
294 let value = self.static_addr_of(init, alloc.inner().align, None);
295 (value, AddressSpace::ZERO)
296 }
297 GlobalAlloc::Static(def_id) => {
298 assert!(self.tcx.is_static(def_id));
299 assert!(!self.tcx.is_thread_local_static(def_id));
300 (self.get_static(def_id), AddressSpace::ZERO)
301 }
302 GlobalAlloc::TypeId { .. } => {
303 return if offset.bytes() == 0 {
304 self.constant_null(ty)
305 } else {
306 let result = self.undef(ty);
307 self.zombie_no_span(
308 result.def_cx(self),
309 "pointer has non-null integer address",
310 );
311 result
312 };
313 }
314 };
315 self.const_bitcast(self.const_ptr_byte_offset(base_addr, offset), ty)
316 }
317 }
318 }
319
320 fn const_data_from_alloc(&self, alloc: ConstAllocation<'_>) -> Self::Value {
325 let alloc = self.tcx.lift(alloc).unwrap();
329
330 let void_type = SpirvType::Void.def(DUMMY_SP, self);
331 self.def_constant(void_type, SpirvConst::ConstDataFromAlloc(alloc))
332 }
333
334 fn const_ptr_byte_offset(&self, val: Self::Value, offset: Size) -> Self::Value {
335 if offset == Size::ZERO {
336 val
337 } else {
338 let result = val;
341 self.zombie_no_span(result.def_cx(self), "const_ptr_byte_offset");
342 result
343 }
344 }
345}
346
347impl<'tcx> CodegenCx<'tcx> {
348 pub fn const_bitcast(&self, val: SpirvValue, ty: Word) -> SpirvValue {
349 if let SpirvValueKind::IllegalConst(_) = val.kind
352 && let Some(SpirvConst::PtrTo { pointee }) = self.builder.lookup_const(val)
353 && let Some(SpirvConst::ConstDataFromAlloc(alloc)) =
354 self.builder.lookup_const_by_id(pointee)
355 && let SpirvType::Pointer { pointee } = self.lookup_type(ty)
356 && let Some(init) = self.try_read_from_const_alloc(alloc, pointee)
357 {
358 return self.static_addr_of(init, alloc.inner().align, None);
359 }
360
361 if val.ty == ty {
362 val
363 } else {
364 let result = val.def_cx(self).with_type(ty);
367 self.zombie_no_span(result.def_cx(self), "const_bitcast");
368 result
369 }
370 }
371
372 pub fn primitive_to_scalar(&self, value: Primitive) -> abi::Scalar {
375 let bits = value.size(self.data_layout()).bits();
376 assert!(bits <= 128);
377 abi::Scalar::Initialized {
378 value,
379 valid_range: abi::WrappingRange {
380 start: 0,
381 end: (!0 >> (128 - bits)),
382 },
383 }
384 }
385
386 pub fn try_read_from_const_alloc(
391 &self,
392 alloc: ConstAllocation<'tcx>,
393 ty: Word,
394 ) -> Option<SpirvValue> {
395 let (result, read_size) = self.read_from_const_alloc_at(alloc, ty, Size::ZERO);
396 (read_size == alloc.inner().size()).then_some(result)
397 }
398
399 #[tracing::instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), offset))]
404 fn read_from_const_alloc_at(
405 &self,
406 alloc: ConstAllocation<'tcx>,
407 ty: Word,
408 offset: Size,
409 ) -> (SpirvValue, Size) {
410 let ty_def = self.lookup_type(ty);
411 match ty_def {
412 SpirvType::Bool
413 | SpirvType::Integer(..)
414 | SpirvType::Float(_)
415 | SpirvType::Pointer { .. } => {
416 let size = ty_def.sizeof(self).unwrap();
417 let primitive = match ty_def {
418 SpirvType::Bool => Primitive::Int(Integer::fit_unsigned(0), false),
419 SpirvType::Integer(int_size, int_signedness) => Primitive::Int(
420 match int_size {
421 8 => Integer::I8,
422 16 => Integer::I16,
423 32 => Integer::I32,
424 64 => Integer::I64,
425 128 => Integer::I128,
426 other => {
427 self.tcx
428 .dcx()
429 .fatal(format!("invalid size for integer: {other}"));
430 }
431 },
432 int_signedness,
433 ),
434 SpirvType::Float(float_size) => Primitive::Float(match float_size {
435 16 => Float::F16,
436 32 => Float::F32,
437 64 => Float::F64,
438 128 => Float::F128,
439 other => {
440 self.tcx
441 .dcx()
442 .fatal(format!("invalid size for float: {other}"));
443 }
444 }),
445 SpirvType::Pointer { .. } => Primitive::Pointer(AddressSpace::ZERO),
446 _ => unreachable!(),
447 };
448
449 let range = alloc_range(offset, size);
450 let read_provenance = matches!(primitive, Primitive::Pointer(_));
451
452 let mut primitive = primitive;
453 let mut read_result = alloc.inner().read_scalar(self, range, read_provenance);
454
455 if read_result.is_err()
459 && !read_provenance
460 && let read_ptr_result @ Ok(Scalar::Ptr(ptr, _)) = alloc
461 .inner()
462 .read_scalar(self, range, true)
463 {
464 let (prov, _offset) = ptr.prov_and_relative_offset();
465 primitive = Primitive::Pointer(
466 self.tcx.global_alloc(prov.alloc_id()).address_space(self),
467 );
468 read_result = read_ptr_result;
469 }
470
471 let scalar_or_zombie = match read_result {
472 Ok(scalar) => {
473 Ok(self.scalar_to_backend(scalar, self.primitive_to_scalar(primitive), ty))
474 }
475
476 Err(err) => match err {
479 AllocError::InvalidUninitBytes(_) => {
483 let uninit_range = alloc
484 .inner()
485 .init_mask()
486 .is_range_initialized(range)
487 .unwrap_err();
488 let uninit_size = {
489 let [start, end] = [uninit_range.start, uninit_range.end()]
490 .map(|x| x.clamp(range.start, range.end()));
491 end - start
492 };
493 if uninit_size == size {
494 Ok(self.undef(ty))
495 } else {
496 Err(format!(
497 "overlaps {} uninitialized bytes",
498 uninit_size.bytes()
499 ))
500 }
501 }
502 AllocError::ReadPointerAsInt(_) => Err("overlaps pointer bytes".into()),
503 AllocError::ReadPartialPointer(_) => {
504 Err("partially overlaps another pointer".into())
505 }
506
507 AllocError::ScalarSizeMismatch(_) => {
510 Err(format!("unrecognized `AllocError::{err:?}`"))
511 }
512 },
513 };
514 let result = scalar_or_zombie.unwrap_or_else(|reason| {
515 let result = self.undef(ty);
516 self.zombie_no_span(
517 result.def_cx(self),
518 &format!("unsupported `{}` constant: {reason}", self.debug_type(ty),),
519 );
520 result
521 });
522 (result, size)
523 }
524 SpirvType::Adt {
525 field_types,
526 field_offsets,
527 ..
528 } => {
529 let mut tail_read_range = ..Size::ZERO;
532 let result = self.constant_composite(
533 ty,
534 field_types
535 .iter()
536 .zip_eq(field_offsets.iter())
537 .map(|(&f_ty, &f_offset)| {
538 let (f, f_size) =
539 self.read_from_const_alloc_at(alloc, f_ty, offset + f_offset);
540 tail_read_range.end =
541 tail_read_range.end.max(offset + f_offset + f_size);
542 f.def_cx(self)
543 }),
544 );
545
546 let ty_size = ty_def.sizeof(self);
547
548 if let Some(ty_size) = ty_size
550 && let Some(tail_gap) = (ty_size.bytes())
551 .checked_sub(tail_read_range.end.align_to(ty_def.alignof(self)).bytes())
552 && tail_gap > 0
553 {
554 self.zombie_no_span(
555 result.def_cx(self),
556 &format!(
557 "undersized `{}` constant (at least {tail_gap} bytes may be missing)",
558 self.debug_type(ty)
559 ),
560 );
561 }
562
563 (result, ty_size.unwrap_or(tail_read_range.end))
564 }
565 SpirvType::Vector { element, .. }
566 | SpirvType::Matrix { element, .. }
567 | SpirvType::Array { element, .. }
568 | SpirvType::RuntimeArray { element } => {
569 let stride = self.lookup_type(element).sizeof(self).unwrap();
570
571 let count = match ty_def {
572 SpirvType::Vector { count, .. } | SpirvType::Matrix { count, .. } => {
573 u64::from(count)
574 }
575 SpirvType::Array { count, .. } => {
576 u64::try_from(self.builder.lookup_const_scalar(count).unwrap()).unwrap()
577 }
578 SpirvType::RuntimeArray { .. } => {
579 (alloc.inner().size() - offset).bytes() / stride.bytes()
580 }
581 _ => unreachable!(),
582 };
583
584 let result = self.constant_composite(
585 ty,
586 (0..count).map(|i| {
587 let (e, e_size) =
588 self.read_from_const_alloc_at(alloc, element, offset + i * stride);
589 assert_eq!(e_size, stride);
590 e.def_cx(self)
591 }),
592 );
593
594 let read_size = (count * stride).align_to(ty_def.alignof(self));
598
599 if let Some(ty_size) = ty_def.sizeof(self) {
600 assert_eq!(read_size, ty_size);
601 }
602
603 if let SpirvType::RuntimeArray { .. } = ty_def {
604 self.zombie_no_span(
609 result.def_cx(self),
610 &format!("unsupported unsized `{}` constant", self.debug_type(ty)),
611 );
612 }
613
614 (result, read_size)
615 }
616
617 SpirvType::Void
618 | SpirvType::Function { .. }
619 | SpirvType::Image { .. }
620 | SpirvType::Sampler
621 | SpirvType::SampledImage { .. }
622 | SpirvType::InterfaceBlock { .. }
623 | SpirvType::AccelerationStructureKhr
624 | SpirvType::RayQueryKhr => {
625 let result = self.undef(ty);
626 self.zombie_no_span(
627 result.def_cx(self),
628 &format!(
629 "cannot reinterpret Rust constant data as a `{}` value",
630 self.debug_type(ty)
631 ),
632 );
633 (result, ty_def.sizeof(self).unwrap_or(Size::ZERO))
634 }
635 }
636 }
637}