1use super::layout::*;
5
6use crate::func_at::FuncAtMut;
7use crate::qptr::{QPtrAttr, QPtrOp, shapes};
8use crate::transform::{InnerInPlaceTransform, Transformed, Transformer};
9use crate::{
10 AddrSpace, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, ControlNode,
11 ControlNodeKind, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, Diag,
12 FuncDecl, GlobalVarDecl, OrdAssertEq, Type, TypeKind, TypeOrConst, Value, spv,
13};
14use smallvec::SmallVec;
15use std::cell::Cell;
16use std::num::NonZeroU32;
17use std::rc::Rc;
18
/// Lowering failure, wrapping the diagnostic that ends up attached (via
/// `push_diag`) to the global variable or data instruction being lowered.
struct LowerError(Diag);
20
/// Lowers SPIR-V pointers (`OpTypePointer` types, and instructions such as
/// `OpVariable`, `OpLoad`, `OpStore`, `OpArrayLength`, access chains and
/// pointer-to-pointer `OpBitcast`) to `qptr` types and ops.
pub struct LowerFromSpvPtrs<'a> {
    /// Context used to look up and intern types, consts and instruction forms.
    cx: Rc<Context>,
    /// Well-known SPIR-V spec items (opcodes, storage classes, ...).
    wk: &'static spv::spec::WellKnown,
    /// Cache used (by `layout_of`) to compute type layouts.
    layout_cache: LayoutCache<'a>,

    /// Lazily-interned `TypeKind::QPtr` type (see `qptr_type`).
    cached_qptr_type: Cell<Option<Type>>,
}
31
32impl<'a> LowerFromSpvPtrs<'a> {
33 pub fn new(cx: Rc<Context>, layout_config: &'a LayoutConfig) -> Self {
34 Self {
35 cx: cx.clone(),
36 wk: &spv::spec::Spec::get().well_known,
37 layout_cache: LayoutCache::new(cx, layout_config),
38 cached_qptr_type: Default::default(),
39 }
40 }
41
42 pub fn lower_global_var(&self, global_var_decl: &mut GlobalVarDecl) {
43 let wk = self.wk;
44
45 let (_, pointee_type) = self.as_spv_ptr_type(global_var_decl.type_of_ptr_to).unwrap();
46 let handle_layout_to_handle = |handle_layout: HandleLayout| match handle_layout {
47 shapes::Handle::Opaque(ty) => shapes::Handle::Opaque(ty),
48 shapes::Handle::Buffer(addr_space, buf) => {
49 shapes::Handle::Buffer(addr_space, buf.mem_layout)
50 }
51 };
52 let mut shape_result = self.layout_of(pointee_type).and_then(|layout| {
53 Ok(match layout {
54 TypeLayout::Handle(handle) => shapes::GlobalVarShape::Handles {
55 handle: handle_layout_to_handle(handle),
56 fixed_count: Some(NonZeroU32::new(1).unwrap()),
57 },
58 TypeLayout::HandleArray(handle, fixed_count) => shapes::GlobalVarShape::Handles {
59 handle: handle_layout_to_handle(handle),
60 fixed_count,
61 },
62 TypeLayout::Concrete(concrete) => {
63 if concrete.mem_layout.dyn_unit_stride.is_some() {
64 return Err(LowerError(Diag::err([
65 "global variable cannot have dynamically sized type `".into(),
66 pointee_type.into(),
67 "`".into(),
68 ])));
69 }
70 match global_var_decl.addr_space {
71 AddrSpace::SpvStorageClass(sc)
75 if [
76 wk.Input,
77 wk.Output,
78 wk.IncomingRayPayloadKHR,
79 wk.IncomingCallableDataKHR,
80 wk.HitAttributeKHR,
81 wk.RayPayloadKHR,
82 wk.CallableDataKHR,
83 ]
84 .contains(&sc) =>
85 {
86 shapes::GlobalVarShape::TypedInterface(pointee_type)
87 }
88
89 _ => shapes::GlobalVarShape::UntypedData(concrete.mem_layout.fixed_base),
90 }
91 }
92 })
93 });
94 if let Ok(shapes::GlobalVarShape::Handles { handle, .. }) = &mut shape_result {
95 match handle {
96 shapes::Handle::Opaque(_) => {
97 if global_var_decl.addr_space != AddrSpace::SpvStorageClass(wk.UniformConstant)
98 {
99 shape_result = Err(LowerError(Diag::bug([
100 "opaque Handles require UniformConstant".into(),
101 ])));
102 }
103 }
104 shapes::Handle::Buffer(buf_addr_space, _) => {
116 assert!(*buf_addr_space == AddrSpace::Handles);
118 *buf_addr_space = global_var_decl.addr_space;
119 }
120 }
121 if shape_result.is_ok() {
122 global_var_decl.addr_space = AddrSpace::Handles;
123 }
124 }
125 match shape_result {
126 Ok(shape) => {
127 global_var_decl.shape = Some(shape);
128
129 EraseSpvPtrs { lowerer: self }.in_place_transform_global_var_decl(global_var_decl);
132 }
133 Err(LowerError(e)) => {
134 global_var_decl.attrs.push_diag(&self.cx, e);
135 }
136 }
137 }
138
139 pub fn lower_func(&self, func_decl: &mut FuncDecl) {
140 LowerFromSpvPtrInstsInFunc { lowerer: self }.in_place_transform_func_decl(func_decl);
145 EraseSpvPtrs { lowerer: self }.in_place_transform_func_decl(func_decl);
146 }
147
148 fn as_spv_ptr_type(&self, ty: Type) -> Option<(AddrSpace, Type)> {
156 match &self.cx[ty].kind {
157 TypeKind::SpvInst { spv_inst, type_and_const_inputs }
158 if spv_inst.opcode == self.wk.OpTypePointer =>
159 {
160 let sc = match spv_inst.imms[..] {
161 [spv::Imm::Short(_, sc)] => sc,
162 _ => unreachable!(),
163 };
164 let pointee = match type_and_const_inputs[..] {
165 [TypeOrConst::Type(elem_type)] => elem_type,
166 _ => unreachable!(),
167 };
168 Some((AddrSpace::SpvStorageClass(sc), pointee))
169 }
170 _ => None,
171 }
172 }
173
174 fn const_as_u32(&self, ct: Const) -> Option<u32> {
176 if let ConstKind::SpvInst { spv_inst_and_const_inputs } = &self.cx[ct].kind {
177 let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs;
178 if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 {
179 match spv_inst.imms[..] {
180 [spv::Imm::Short(_, x)] => return Some(x),
181 _ => unreachable!(),
182 }
183 }
184 }
185 None
186 }
187
188 fn qptr_type(&self) -> Type {
190 if let Some(cached) = self.cached_qptr_type.get() {
191 return cached;
192 }
193 let ty = self.cx.intern(TypeKind::QPtr);
194 self.cached_qptr_type.set(Some(ty));
195 ty
196 }
197
198 fn layout_of(&self, ty: Type) -> Result<TypeLayout, LowerError> {
200 self.layout_cache.layout_of(ty).map_err(|LayoutError(err)| LowerError(err))
201 }
202}
203
/// Transformer that replaces SPIR-V pointer types with `qptr`. This includes
/// the type of `PtrToGlobalVar` constants, and types used in data instruction
/// forms.
struct EraseSpvPtrs<'a> {
    lowerer: &'a LowerFromSpvPtrs<'a>,
}
207
208impl Transformer for EraseSpvPtrs<'_> {
209 fn transform_type_use(&mut self, ty: Type) -> Transformed<Type> {
212 if self.lowerer.as_spv_ptr_type(ty).is_some() {
214 Transformed::Changed(self.lowerer.qptr_type())
215 } else {
216 Transformed::Unchanged
217 }
218 }
219
220 fn transform_const_use(&mut self, ct: Const) -> Transformed<Const> {
223 let ct_def = &self.lowerer.cx[ct];
225 if let ConstKind::PtrToGlobalVar(_) = ct_def.kind {
226 Transformed::Changed(self.lowerer.cx.intern(ConstDef {
227 attrs: ct_def.attrs,
228 ty: self.lowerer.qptr_type(),
229 kind: ct_def.kind.clone(),
230 }))
231 } else {
232 Transformed::Unchanged
233 }
234 }
235
236 fn transform_data_inst_form_use(
241 &mut self,
242 data_inst_form: DataInstForm,
243 ) -> Transformed<DataInstForm> {
244 self.transform_data_inst_form_def(&self.lowerer.cx[data_inst_form])
246 .map(|data_inst_form_def| self.lowerer.cx.intern(data_inst_form_def))
247 }
248}
249
/// Transformer lowering the SPIR-V pointer instructions in a function's blocks
/// to `qptr` ops (used by `LowerFromSpvPtrs::lower_func`).
struct LowerFromSpvPtrInstsInFunc<'a> {
    lowerer: &'a LowerFromSpvPtrs<'a>,
}
253
/// One step of a lowered access chain. It later becomes a `qptr` instruction
/// whose first input is the pointer produced by the previous step.
struct QPtrChainStep {
    op: QPtrOp,

    /// Dynamic index input: present for `HandleArrayIndex` steps, and for
    /// `DynOffset` steps whose index wasn't folded into a constant `Offset`.
    dyn_idx: Option<Value>,
}
264
265impl QPtrChainStep {
266 fn into_data_inst_kind_and_inputs(
267 self,
268 in_qptr: Value,
269 ) -> (DataInstKind, SmallVec<[Value; 2]>) {
270 let Self { op, dyn_idx } = self;
271 (op.into(), [in_qptr].into_iter().chain(dyn_idx).collect())
272 }
273}
274
impl LowerFromSpvPtrInstsInFunc<'_> {
    /// Lowers the `indices` of a SPIR-V access chain into a sequence of `qptr`
    /// steps, starting from a pointee with layout `layout`.
    ///
    /// Steps are simplified as they are produced:
    /// - `DynOffset` steps with a constant index (within `index_bounds`, when
    ///   known) become constant `Offset`s, unless `stride * index` overflows
    /// - new `Offset(0)` steps are dropped
    /// - consecutive `Offset` steps are merged
    ///
    /// Errors (as `Diag::bug`) on: indexing into opaque handles or scalars,
    /// non-constant or out-of-bounds struct field indices, field offsets that
    /// don't fit in `i32`, and offset overflow when merging.
    fn try_lower_access_chain(
        &self,
        mut layout: TypeLayout,
        indices: &[Value],
    ) -> Result<SmallVec<[QPtrChainStep; 4]>, LowerError> {
        // NOTE(review): hard-coded, so fixed-length arrays always get
        // `index_bounds` (see the `Components::Elements` case below).
        let is_logical_addressing = true;

        // Constant indices are read as `u32` and reinterpreted (via `as`) as
        // `i32`, so e.g. `0xffff_ffff` becomes `-1`.
        let const_idx_as_i32 = |idx| match idx {
            Value::Const(idx) => self.lowerer.const_as_u32(idx).map(|idx_u32| idx_u32 as i32),
            _ => None,
        };

        let mut steps: SmallVec<[QPtrChainStep; 4]> = SmallVec::new();
        let mut indices = indices.iter().copied();
        while indices.len() > 0 {
            // Pick the op that descends one level into `layout`, along with the
            // layout of the component being descended into.
            let (mut op, component_layout) = match layout {
                TypeLayout::Handle(shapes::Handle::Opaque(_)) => {
                    return Err(LowerError(Diag::bug([
                        "opaque handles have no sub-components".into()
                    ])));
                }
                // Entering a buffer's data doesn't consume any index.
                TypeLayout::Handle(shapes::Handle::Buffer(_, buffer_data_layout)) => {
                    (QPtrOp::BufferData, TypeLayout::Concrete(buffer_data_layout))
                }
                TypeLayout::HandleArray(handle, _) => {
                    (QPtrOp::HandleArrayIndex, TypeLayout::Handle(handle))
                }
                TypeLayout::Concrete(concrete) => match &concrete.components {
                    Components::Scalar => {
                        return Err(LowerError(Diag::bug([
                            "scalars have no sub-components".into()
                        ])));
                    }
                    Components::Elements { stride, elem, fixed_len } => (
                        QPtrOp::DynOffset {
                            stride: *stride,
                            // Bounds are only known for fixed-length arrays, and
                            // only kept when the length is representable.
                            index_bounds: fixed_len
                                .filter(|_| is_logical_addressing)
                                .and_then(|len| Some(0..len.get().try_into().ok()?)),
                        },
                        TypeLayout::Concrete(elem.clone()),
                    ),
                    Components::Fields { offsets, layouts } => {
                        // Struct fields must be selected by a constant index. It is
                        // consumed here, so the `dyn_idx` logic below doesn't
                        // apply to `Offset`.
                        let field_idx =
                            const_idx_as_i32(indices.next().unwrap()).ok_or_else(|| {
                                LowerError(Diag::bug(["non-constant field index".into()]))
                            })?;
                        let (field_offset, field_layout) = usize::try_from(field_idx)
                            .ok()
                            .and_then(|field_idx| {
                                Some((*offsets.get(field_idx)?, layouts.get(field_idx)?.clone()))
                            })
                            .ok_or_else(|| {
                                LowerError(Diag::bug([format!(
                                    "field {field_idx} out of bounds (expected 0..{})",
                                    offsets.len()
                                )
                                .into()]))
                            })?;
                        (
                            QPtrOp::Offset(i32::try_from(field_offset).ok().ok_or_else(|| {
                                LowerError(Diag::bug([format!(
                                    "{field_offset} not representable as a positive s32"
                                )
                                .into()]))
                            })?),
                            TypeLayout::Concrete(field_layout),
                        )
                    }
                },
            };
            layout = component_layout;

            // Dynamically-indexing ops take the next index as an extra input.
            let mut dyn_idx = match op {
                QPtrOp::HandleArrayIndex | QPtrOp::DynOffset { .. } => {
                    Some(indices.next().unwrap())
                }
                _ => None,
            };

            if let QPtrOp::DynOffset { stride, index_bounds } = &op {
                // Fold a constant index (checked against `index_bounds`, if any)
                // into a constant offset, as long as `stride * index` fits in `i32`.
                let const_offset = const_idx_as_i32(dyn_idx.unwrap())
                    .filter(|const_idx| {
                        index_bounds.as_ref().map_or(true, |bounds| bounds.contains(const_idx))
                    })
                    .and_then(|const_idx| i32::try_from(stride.get()).ok()?.checked_mul(const_idx));
                if let Some(const_offset) = const_offset {
                    op = QPtrOp::Offset(const_offset);
                    dyn_idx = None;
                }
            }

            match (steps.last_mut().map(|last_step| &mut last_step.op), &op) {
                // `Offset(0)` is a no-op, so it's never emitted.
                (_, QPtrOp::Offset(0)) => {}

                // Consecutive constant offsets are merged into a single step.
                (Some(QPtrOp::Offset(last_offset)), &QPtrOp::Offset(new_offset)) => {
                    *last_offset = last_offset.checked_add(new_offset).ok_or_else(|| {
                        LowerError(Diag::bug([format!(
                            "offset overflow ({last_offset}+{new_offset})"
                        )
                        .into()]))
                    })?;
                }

                _ => steps.push(QPtrChainStep { op, dyn_idx }),
            }
        }
        Ok(steps)
    }

    /// Tries to lower one data instruction (contained in the block
    /// `parent_block`) from SPIR-V pointer ops to `qptr` ops.
    ///
    /// Returns:
    /// - `Ok(Transformed::Changed(def))` with the replacement definition
    /// - `Ok(Transformed::Unchanged)` when the instruction isn't handled
    /// - `Err(_)` on unexpected types or layouts
    ///
    /// The instructions handled are:
    /// - `OpVariable` (fixed-size concrete types only)
    /// - `OpLoad`/`OpStore` (without immediates)
    /// - `OpArrayLength`
    /// - the four access chain opcodes
    /// - pointer-to-pointer `OpBitcast`
    ///
    /// For access chains, every step except the last is defined as a new
    /// instruction and inserted into `parent_block` right before this one.
    fn try_lower_data_inst_def(
        &self,
        mut func_at_data_inst: FuncAtMut<'_, DataInst>,
        parent_block: ControlNode,
    ) -> Result<Transformed<DataInstDef>, LowerError> {
        let cx = &self.lowerer.cx;
        let wk = self.lowerer.wk;

        let func_at_data_inst_frozen = func_at_data_inst.reborrow().freeze();
        let data_inst = func_at_data_inst_frozen.position;
        let data_inst_def = func_at_data_inst_frozen.def();

        let func = func_at_data_inst_frozen.at(());

        let mut attrs = data_inst_def.attrs;
        let DataInstFormDef { ref kind, output_type } = cx[data_inst_def.form];

        // Only SPIR-V instructions are candidates for lowering.
        let spv_inst = match kind {
            DataInstKind::SpvInst(spv_inst) => spv_inst,
            _ => return Ok(Transformed::Unchanged),
        };

        let replacement_kind_and_inputs = if spv_inst.opcode == wk.OpVariable {
            // At most one input (per SPIR-V, the optional initializer).
            assert!(data_inst_def.inputs.len() <= 1);
            let (_, var_data_type) =
                self.lowerer.as_spv_ptr_type(output_type.unwrap()).ok_or_else(|| {
                    LowerError(Diag::bug(["output type not an `OpTypePointer`".into()]))
                })?;
            match self.lowerer.layout_of(var_data_type)? {
                // Only fixed-size concrete (in-memory) types are lowered here.
                TypeLayout::Concrete(concrete) if concrete.mem_layout.dyn_unit_stride.is_none() => {
                    (
                        QPtrOp::FuncLocalVar(concrete.mem_layout.fixed_base).into(),
                        data_inst_def.inputs.clone(),
                    )
                }
                _ => return Ok(Transformed::Unchanged),
            }
        } else if spv_inst.opcode == wk.OpLoad {
            // Immediates (presumably memory operands) aren't supported yet.
            if !spv_inst.imms.is_empty() {
                return Ok(Transformed::Unchanged);
            }
            assert_eq!(data_inst_def.inputs.len(), 1);
            (QPtrOp::Load.into(), data_inst_def.inputs.clone())
        } else if spv_inst.opcode == wk.OpStore {
            // Immediates (presumably memory operands) aren't supported yet.
            if !spv_inst.imms.is_empty() {
                return Ok(Transformed::Unchanged);
            }
            assert_eq!(data_inst_def.inputs.len(), 2);
            (QPtrOp::Store.into(), data_inst_def.inputs.clone())
        } else if spv_inst.opcode == wk.OpArrayLength {
            // The immediate selects the (runtime array) field of the buffer struct.
            let field_idx = match spv_inst.imms[..] {
                [spv::Imm::Short(_, field_idx)] => field_idx,
                _ => unreachable!(),
            };
            assert_eq!(data_inst_def.inputs.len(), 1);
            let ptr = data_inst_def.inputs[0];
            let (_, pointee_type) =
                self.lowerer.as_spv_ptr_type(func.at(ptr).type_of(cx)).ok_or_else(|| {
                    LowerError(Diag::bug(["pointer input not an `OpTypePointer`".into()]))
                })?;

            let buf_data_layout = match self.lowerer.layout_of(pointee_type)? {
                TypeLayout::Handle(shapes::Handle::Buffer(_, buf)) => buf,
                _ => return Err(LowerError(Diag::bug(["non-Buffer pointee".into()]))),
            };

            let (field_offset, field_layout) = match &buf_data_layout.components {
                Components::Fields { offsets, layouts } => usize::try_from(field_idx)
                    .ok()
                    .and_then(|field_idx| {
                        Some((*offsets.get(field_idx)?, layouts.get(field_idx)?.clone()))
                    })
                    .ok_or_else(|| {
                        LowerError(Diag::bug([format!(
                            "field {field_idx} out of bounds (expected 0..{})",
                            offsets.len()
                        )
                        .into()]))
                    })?,

                _ => {
                    return Err(LowerError(Diag::bug(
                        ["buffer data not an `OpTypeStruct`".into()],
                    )));
                }
            };
            let array_stride = match field_layout.components {
                Components::Elements { stride, fixed_len: None, .. } => stride,

                _ => {
                    return Err(LowerError(Diag::bug([format!(
                        "buffer data field #{field_idx} not an `OpTypeRuntimeArray`"
                    )
                    .into()])));
                }
            };

            // The runtime array must have no fixed part and must end the buffer
            // data. The buffer data's fixed size is then exactly the array's offset.
            assert_eq!(field_layout.mem_layout.fixed_base.size, 0);
            assert_eq!(field_layout.mem_layout.dyn_unit_stride, Some(array_stride));
            assert_eq!(buf_data_layout.mem_layout.fixed_base.size, field_offset);
            assert_eq!(buf_data_layout.mem_layout.dyn_unit_stride, Some(array_stride));

            (
                QPtrOp::BufferDynLen {
                    fixed_base_size: field_offset,
                    dyn_unit_stride: array_stride,
                }
                .into(),
                data_inst_def.inputs.clone(),
            )
        } else if [
            wk.OpAccessChain,
            wk.OpInBoundsAccessChain,
            wk.OpPtrAccessChain,
            wk.OpInBoundsPtrAccessChain,
        ]
        .contains(&spv_inst.opcode)
        {
            let base_ptr = data_inst_def.inputs[0];
            let (_, base_pointee_type) =
                self.lowerer.as_spv_ptr_type(func.at(base_ptr).type_of(cx)).ok_or_else(|| {
                    LowerError(Diag::bug(["pointer input not an `OpTypePointer`".into()]))
                })?;

            // The first index of `Op*PtrAccessChain` indexes the base pointer
            // itself. This is modeled by treating the pointee as the element
            // type of an `OpTypeRuntimeArray`.
            let access_chain_base_layout =
                if [wk.OpPtrAccessChain, wk.OpInBoundsPtrAccessChain].contains(&spv_inst.opcode) {
                    self.lowerer.layout_of(cx.intern(TypeKind::SpvInst {
                        spv_inst: wk.OpTypeRuntimeArray.into(),
                        type_and_const_inputs:
                            [TypeOrConst::Type(base_pointee_type)].into_iter().collect(),
                    }))?
                } else {
                    self.lowerer.layout_of(base_pointee_type)?
                };

            let mut steps =
                self.try_lower_access_chain(access_chain_base_layout, &data_inst_def.inputs[1..])?;
            // The last step replaces this instruction; an empty chain becomes a
            // no-op `Offset(0)` of the base pointer.
            let final_step =
                steps.pop().unwrap_or(QPtrChainStep { op: QPtrOp::Offset(0), dyn_idx: None });

            // Every earlier step becomes its own instruction, inserted before this one.
            let mut ptr = base_ptr;
            for step in steps {
                let (kind, inputs) = step.into_data_inst_kind_and_inputs(ptr);
                let step_data_inst = func_at_data_inst.reborrow().data_insts.define(
                    cx,
                    DataInstDef {
                        attrs: Default::default(),
                        form: cx.intern(DataInstFormDef {
                            kind,
                            output_type: Some(self.lowerer.qptr_type()),
                        }),
                        inputs,
                    }
                    .into(),
                );

                let func = func_at_data_inst.reborrow().at(());
                match &mut func.control_nodes[parent_block].kind {
                    ControlNodeKind::Block { insts } => {
                        insts.insert_before(step_data_inst, data_inst, func.data_insts);
                    }
                    _ => unreachable!(),
                }

                ptr = Value::DataInstOutput(step_data_inst);
            }
            final_step.into_data_inst_kind_and_inputs(ptr)
        } else if spv_inst.opcode == wk.OpBitcast {
            let input = data_inst_def.inputs[0];
            // Only pointer-to-pointer bitcasts are lowered, as a no-op `Offset(0)`.
            if self.lowerer.as_spv_ptr_type(func.at(input).type_of(cx)).is_some()
                && self.lowerer.as_spv_ptr_type(output_type.unwrap()).is_some()
            {
                let noop_step = QPtrChainStep { op: QPtrOp::Offset(0), dyn_idx: None };

                // The original instruction's attributes are discarded here.
                attrs = AttrSet::default();

                noop_step.into_data_inst_kind_and_inputs(input)
            } else {
                return Ok(Transformed::Unchanged);
            }
        } else {
            return Ok(Transformed::Unchanged);
        };
        let (new_kind, new_inputs) = replacement_kind_and_inputs;
        // `output_type` is kept as-is here (`EraseSpvPtrs` turns it into `qptr` later).
        Ok(Transformed::Changed(DataInstDef {
            attrs,
            form: cx.intern(DataInstFormDef { kind: new_kind, output_type }),
            inputs: new_inputs,
        }))
    }

    /// For SPIR-V instructions left unlowered, records the pointer information
    /// found in their types. This happens before `EraseSpvPtrs` replaces those
    /// types with `qptr`. Recorded are:
    /// - a `QPtrAttr::ToSpvPtrInput` for each `OpTypePointer`-typed input
    /// - a `QPtrAttr::FromSpvPtrOutput` for an `OpTypePointer` output type
    /// - `extra_error`, if any, as a diagnostic
    ///
    /// Function calls and `qptr` instructions are skipped entirely.
    fn add_fallback_attrs_to_data_inst_def(
        &self,
        mut func_at_data_inst: FuncAtMut<'_, DataInst>,
        extra_error: Option<LowerError>,
    ) {
        let cx = &self.lowerer.cx;

        let func_at_data_inst_frozen = func_at_data_inst.reborrow().freeze();
        let data_inst_def = func_at_data_inst_frozen.def();
        let data_inst_form_def = &cx[data_inst_def.form];

        let func = func_at_data_inst_frozen.at(());

        match data_inst_form_def.kind {
            DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return,

            DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {}
        }

        // The existing attributes are only copied (and later re-interned) once
        // something actually needs to be added.
        let mut old_and_new_attrs = None;
        let get_old_attrs = || AttrSetDef { attrs: cx[data_inst_def.attrs].attrs.clone() };

        for (input_idx, &v) in data_inst_def.inputs.iter().enumerate() {
            if let Some((_, pointee)) = self.lowerer.as_spv_ptr_type(func.at(v).type_of(cx)) {
                old_and_new_attrs.get_or_insert_with(get_old_attrs).attrs.insert(
                    QPtrAttr::ToSpvPtrInput {
                        input_idx: input_idx.try_into().unwrap(),
                        pointee: OrdAssertEq(pointee),
                    }
                    .into(),
                );
            }
        }
        if let Some(output_type) = data_inst_form_def.output_type {
            if let Some((addr_space, pointee)) = self.lowerer.as_spv_ptr_type(output_type) {
                old_and_new_attrs.get_or_insert_with(get_old_attrs).attrs.insert(
                    QPtrAttr::FromSpvPtrOutput {
                        addr_space: OrdAssertEq(addr_space),
                        pointee: OrdAssertEq(pointee),
                    }
                    .into(),
                );
            }
        }

        if let Some(LowerError(e)) = extra_error {
            old_and_new_attrs.get_or_insert_with(get_old_attrs).push_diag(e);
        }

        if let Some(attrs) = old_and_new_attrs {
            func_at_data_inst.def().attrs = cx.intern(attrs);
        }
    }
}
667
impl Transformer for LowerFromSpvPtrInstsInFunc<'_> {
    /// First applies the inner in-place transform to `func_at_control_node`.
    /// Then, if it's a `Block`, tries to lower each of its data instructions
    /// in order:
    /// - successfully lowered instructions get their new definition
    /// - the rest (unchanged or failed) get fallback attributes and diagnostics
    fn in_place_transform_control_node_def(
        &mut self,
        mut func_at_control_node: FuncAtMut<'_, ControlNode>,
    ) {
        func_at_control_node.reborrow().inner_in_place_transform_with(self);

        let control_node = func_at_control_node.position;
        if let ControlNodeKind::Block { insts } = func_at_control_node.reborrow().def().kind {
            // NOTE: lowering an access chain may insert new instructions into
            // this same block, right before the instruction being lowered.
            let mut func_at_inst_iter = func_at_control_node.reborrow().at(insts).into_iter();
            while let Some(mut func_at_inst) = func_at_inst_iter.next() {
                match self.try_lower_data_inst_def(func_at_inst.reborrow(), control_node) {
                    Ok(Transformed::Changed(new_def)) => {
                        *func_at_inst.def() = new_def;
                    }
                    result @ (Ok(Transformed::Unchanged) | Err(_)) => {
                        self.add_fallback_attrs_to_data_inst_def(func_at_inst, result.err());
                    }
                }
            }
        }
    }
}