spirt/spv/
lift.rs

1//! SPIR-T to SPIR-V lifting.
2
3use crate::func_at::FuncAt;
4use crate::spv::{self, spec};
5use crate::visit::{InnerVisit, Visitor};
6use crate::{
7    AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeKind,
8    ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, DataInstDef,
9    DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityList, ExportKey, Exportee, Func,
10    FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, Import, Module,
11    ModuleDebugInfo, ModuleDialect, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value,
12    cfg,
13};
14use rustc_hash::FxHashMap;
15use smallvec::SmallVec;
16use std::borrow::Cow;
17use std::collections::{BTreeMap, BTreeSet};
18use std::num::NonZeroU32;
19use std::path::Path;
20use std::{io, iter, mem, slice};
21
22impl spv::Dialect {
23    fn capability_insts(&self) -> impl Iterator<Item = spv::InstWithIds> + '_ {
24        let wk = &spec::Spec::get().well_known;
25        self.capabilities.iter().map(move |&cap| spv::InstWithIds {
26            without_ids: spv::Inst {
27                opcode: wk.OpCapability,
28                imms: iter::once(spv::Imm::Short(wk.Capability, cap)).collect(),
29            },
30            result_type_id: None,
31            result_id: None,
32            ids: [].into_iter().collect(),
33        })
34    }
35
36    pub fn extension_insts(&self) -> impl Iterator<Item = spv::InstWithIds> + '_ {
37        let wk = &spec::Spec::get().well_known;
38        self.extensions.iter().map(move |ext| spv::InstWithIds {
39            without_ids: spv::Inst {
40                opcode: wk.OpExtension,
41                imms: spv::encode_literal_string(ext).collect(),
42            },
43            result_type_id: None,
44            result_id: None,
45            ids: [].into_iter().collect(),
46        })
47    }
48}
49
50impl spv::ModuleDebugInfo {
51    fn source_extension_insts(&self) -> impl Iterator<Item = spv::InstWithIds> + '_ {
52        let wk = &spec::Spec::get().well_known;
53        self.source_extensions.iter().map(move |ext| spv::InstWithIds {
54            without_ids: spv::Inst {
55                opcode: wk.OpSourceExtension,
56                imms: spv::encode_literal_string(ext).collect(),
57            },
58            result_type_id: None,
59            result_id: None,
60            ids: [].into_iter().collect(),
61        })
62    }
63
64    fn module_processed_insts(&self) -> impl Iterator<Item = spv::InstWithIds> + '_ {
65        let wk = &spec::Spec::get().well_known;
66        self.module_processes.iter().map(move |proc| spv::InstWithIds {
67            without_ids: spv::Inst {
68                opcode: wk.OpModuleProcessed,
69                imms: spv::encode_literal_string(proc).collect(),
70            },
71            result_type_id: None,
72            result_id: None,
73            ids: [].into_iter().collect(),
74        })
75    }
76}
77
78impl FuncDecl {
79    fn spv_func_type(&self, cx: &Context) -> Type {
80        let wk = &spec::Spec::get().well_known;
81
82        cx.intern(TypeDef {
83            attrs: AttrSet::default(),
84            kind: TypeKind::SpvInst {
85                spv_inst: wk.OpTypeFunction.into(),
86                type_and_const_inputs: iter::once(self.ret_type)
87                    .chain(self.params.iter().map(|param| param.ty))
88                    .map(TypeOrConst::Type)
89                    .collect(),
90            },
91        })
92    }
93}
94
/// Module visitor that gathers everything which will need a SPIR-V ID
/// (or an entry in one of the module-level instruction sections) when lifting.
///
/// Consumed by `alloc_ids`, which turns the collected sets into `AllocatedIds`.
struct NeedsIdsCollector<'a> {
    cx: &'a Context,
    module: &'a Module,

    // Extended instruction set names referenced by `DataInstKind::SpvExtInst`
    // (presumably each becomes an `OpExtInstImport` - confirm in the emitter).
    ext_inst_imports: BTreeSet<&'a str>,
    // Strings that have to be emitted as `OpString`s (e.g. file paths used by
    // `OpSource` and debug line attributes, or ext-inst string literal operands).
    debug_strings: BTreeSet<&'a str>,

    // Types and constants, each inserted only after its definition was visited.
    globals: FxIndexSet<Global>,
    // Dedup-only sets (discarded by `alloc_ids`): they just prevent revisiting.
    data_inst_forms_seen: FxIndexSet<DataInstForm>,
    global_vars_seen: FxIndexSet<GlobalVar>,
    // Every function reachable from the visited module, each of which gets IDs.
    funcs: FxIndexSet<Func>,
}
107
/// A module-level entity (type or constant) that is assigned its own SPIR-V ID,
/// used as the key of the `globals` set/map.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
enum Global {
    Type(Type),
    Const(Const),
}
113
/// Collects every entity needing a SPIR-V ID, deduplicating through the
/// collector's sets so that each definition is only visited once.
impl Visitor<'_> for NeedsIdsCollector<'_> {
    // Nothing is recorded for the `AttrSet` itself, only its contents are visited.
    fn visit_attr_set_use(&mut self, attrs: AttrSet) {
        self.visit_attr_set_def(&self.cx[attrs]);
    }
    fn visit_type_use(&mut self, ty: Type) {
        let global = Global::Type(ty);
        // Already collected (its definition was visited when it was inserted).
        if self.globals.contains(&global) {
            return;
        }
        let ty_def = &self.cx[ty];
        match ty_def.kind {
            // FIXME(eddyb) this should be a proper `Result`-based error instead,
            // and/or `spv::lift` should mutate the module for legalization.
            TypeKind::QPtr => {
                unreachable!("`TypeKind::QPtr` should be legalized away before lifting");
            }

            TypeKind::SpvInst { .. } => {}
            TypeKind::SpvStringLiteralForExtInst => {
                unreachable!(
                    "`TypeKind::SpvStringLiteralForExtInst` should not be used \
                     as a type outside of `ConstKind::SpvStringLiteralForExtInst`"
                );
            }
        }
        // NOTE: visiting the definition *before* inserting means anything the
        // definition uses lands in `globals` first, i.e. dependencies precede
        // their users in `globals`' insertion order.
        self.visit_type_def(ty_def);
        self.globals.insert(global);
    }
    fn visit_const_use(&mut self, ct: Const) {
        let global = Global::Const(ct);
        if self.globals.contains(&global) {
            return;
        }
        let ct_def = &self.cx[ct];
        match ct_def.kind {
            // Same visit-then-insert ordering as in `visit_type_use`.
            ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => {
                self.visit_const_def(ct_def);
                self.globals.insert(global);
            }

            // HACK(eddyb) because this is an `OpString` and needs to go earlier
            // in the module than any `OpConstant*`, it needs to be special-cased,
            // without visiting its type, or an entry in `self.globals`.
            ConstKind::SpvStringLiteralForExtInst(s) => {
                let ConstDef { attrs, ty, kind: _ } = ct_def;

                assert!(*attrs == AttrSet::default());
                assert!(
                    self.cx[*ty]
                        == TypeDef {
                            attrs: AttrSet::default(),
                            kind: TypeKind::SpvStringLiteralForExtInst,
                        }
                );

                self.debug_strings.insert(&self.cx[s]);
            }
        }
    }
    fn visit_data_inst_form_use(&mut self, data_inst_form: DataInstForm) {
        // `insert` only returns `true` the first time, so each form is visited once.
        if self.data_inst_forms_seen.insert(data_inst_form) {
            self.visit_data_inst_form_def(&self.cx[data_inst_form]);
        }
    }

    fn visit_global_var_use(&mut self, gv: GlobalVar) {
        // Global variables are only deduplicated here, not assigned IDs by
        // this collector (`global_vars_seen` is discarded by `alloc_ids`).
        if self.global_vars_seen.insert(gv) {
            self.visit_global_var_decl(&self.module.global_vars[gv]);
        }
    }
    fn visit_func_use(&mut self, func: Func) {
        if self.funcs.contains(&func) {
            return;
        }
        // NOTE(eddyb) inserting first results in a different function ordering
        // in the resulting module, but the order doesn't matter, and we need
        // to avoid infinite recursion for recursive functions.
        self.funcs.insert(func);

        let func_decl = &self.module.funcs[func];
        // The `OpTypeFunction` type of the signature also needs an ID.
        // FIXME(eddyb) should this be cached in `self.funcs`?
        self.visit_type_use(func_decl.spv_func_type(self.cx));
        self.visit_func_decl(func_decl);
    }

    fn visit_spv_module_debug_info(&mut self, debug_info: &spv::ModuleDebugInfo) {
        for sources in debug_info.source_languages.values() {
            // The file operand of `OpSource` has to point to an `OpString`.
            // (only the file names, i.e. the map keys, are collected here)
            self.debug_strings.extend(sources.file_contents.keys().copied().map(|s| &self.cx[s]));
        }
    }
    fn visit_attr(&mut self, attr: &Attr) {
        match *attr {
            Attr::Diagnostics(_)
            | Attr::QPtr(_)
            | Attr::SpvAnnotation { .. }
            | Attr::SpvBitflagsOperand(_) => {}
            // The debug line's file path is also collected as a debug string.
            Attr::SpvDebugLine { file_path, .. } => {
                self.debug_strings.insert(&self.cx[file_path.0]);
            }
        }
        // Continue into any uses nested within the attribute itself.
        attr.inner_visit_with(self);
    }

    fn visit_data_inst_form_def(&mut self, data_inst_form_def: &DataInstFormDef) {
        #[allow(clippy::match_same_arms)]
        match data_inst_form_def.kind {
            // FIXME(eddyb) this should be a proper `Result`-based error instead,
            // and/or `spv::lift` should mutate the module for legalization.
            DataInstKind::QPtr(_) => {
                unreachable!("`DataInstKind::QPtr` should be legalized away before lifting");
            }

            DataInstKind::FuncCall(_) => {}

            DataInstKind::SpvInst(_) => {}
            // Record the extended instruction set's name for `ext_inst_imports`.
            DataInstKind::SpvExtInst { ext_set, .. } => {
                self.ext_inst_imports.insert(&self.cx[ext_set]);
            }
        }
        data_inst_form_def.inner_visit_with(self);
    }
}
237
/// The result of `NeedsIdsCollector::alloc_ids`: a SPIR-V ID for every
/// collected entity (and, for functions, their full `FuncLifting` state).
struct AllocatedIds<'a> {
    // `BTreeMap`s, so iteration is in sorted order of the strings.
    ext_inst_imports: BTreeMap<&'a str, spv::Id>,
    debug_strings: BTreeMap<&'a str, spv::Id>,

    // FIXME(eddyb) use `EntityOrientedDenseMap` here.
    globals: FxIndexMap<Global, spv::Id>,
    // FIXME(eddyb) use `EntityOrientedDenseMap` here.
    funcs: FxIndexMap<Func, FuncLifting<'a>>,
}
247
// FIXME(eddyb) should this use ID ranges instead of `SmallVec<[spv::Id; 4]>`?
/// Per-function lifting state: IDs for the function itself and its parameters,
/// plus (only for function definitions, imports leave these empty) the
/// information needed to emit its body as SPIR-V blocks.
struct FuncLifting<'a> {
    func_id: spv::Id,
    // One ID per parameter, in declaration order.
    param_ids: SmallVec<[spv::Id; 4]>,

    // FIXME(eddyb) use `EntityOrientedDenseMap` here.
    region_inputs_source: FxHashMap<ControlRegion, RegionInputsSource>,
    // FIXME(eddyb) use `EntityOrientedDenseMap` here.
    data_inst_output_ids: FxHashMap<DataInst, spv::Id>,

    // Presumably the `OpLabel` ID for each block in `blocks` - TODO confirm.
    label_ids: FxHashMap<CfgPoint, spv::Id>,
    // One SPIR-V block per CFG point needing one (insertion-ordered).
    blocks: FxIndexMap<CfgPoint, BlockLifting<'a>>,
}
261
/// What determines the values for [`Value::ControlRegionInput`]s, for a specific
/// region (effectively the subset of "region parents" that support inputs).
///
/// Note that this is not used when a [`cfg::ControlInst`] has `target_inputs`,
/// and the target [`ControlRegion`] itself has phis for its `inputs`.
enum RegionInputsSource {
    /// The region is the function body, whose inputs are the function parameters.
    FuncParams,
    /// The region is the body of this `Loop` node, whose inputs are phis placed
    /// in the loop header (the `Loop`'s entry block, where the backedge points).
    LoopHeaderPhis(ControlNode),
}
271
/// Any of the possible points in structured or unstructured SPIR-T control-flow,
/// that may require a separate SPIR-V basic block.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
enum CfgPoint {
    /// Before the region's first child node (or its exit, if it has none).
    RegionEntry(ControlRegion),
    /// After the region's last child node.
    RegionExit(ControlRegion),

    /// Before the node (for `Loop`s, this is the loop header / backedge target).
    ControlNodeEntry(ControlNode),
    /// After the node (for `Select`/`Loop`, this is where its merge block is).
    ControlNodeExit(ControlNode),
}
282
/// Contents of one SPIR-V basic block to be emitted.
struct BlockLifting<'a> {
    // Phis at the start of the block (region inputs / node outputs / loop state).
    phis: SmallVec<[Phi; 2]>,
    // Instruction lists to emit, in order (from `ControlNodeKind::Block`s).
    insts: SmallVec<[EntityList<DataInst>; 1]>,
    terminator: Terminator<'a>,
}
288
/// An `OpPhi` to be emitted at the start of a block.
struct Phi {
    attrs: AttrSet,
    ty: Type,

    result_id: spv::Id,
    // Incoming value for each predecessor block.
    cases: FxIndexMap<CfgPoint, Value>,

    // HACK(eddyb) used for `Loop` `initial_inputs`, to indicate that any edge
    // to the `Loop` (other than the backedge, which is already in `cases`)
    // should automatically get an entry into `cases`, with this value.
    default_value: Option<Value>,
}
301
/// Similar to [`cfg::ControlInst`], except:
/// * `targets` use [`CfgPoint`]s instead of [`ControlRegion`]s, to be able to
///   reach any of the SPIR-V blocks being created during lifting
/// * φ ("phi") values can be provided for targets regardless of "which side" of
///   the structured control-flow they are for ("region input" vs "node output")
/// * optional `merge` (for `OpSelectionMerge`/`OpLoopMerge`)
/// * existing data is borrowed (from the [`FuncDefBody`](crate::FuncDefBody)),
///   wherever possible
struct Terminator<'a> {
    attrs: AttrSet,

    kind: Cow<'a, cfg::ControlInstKind>,

    // FIXME(eddyb) use `Cow` or something, but ideally the "owned" case always
    // has at most one input, so allocating a whole `Vec` for that seems unwise.
    inputs: SmallVec<[Value; 2]>,

    // FIXME(eddyb) change the inline size of this to fit most instructions.
    targets: SmallVec<[CfgPoint; 4]>,

    // Values flowing into the phis of each target (keyed by target point).
    target_phi_values: FxIndexMap<CfgPoint, &'a [Value]>,

    merge: Option<Merge<CfgPoint>>,
}
326
/// Structured control-flow merge information to emit before a terminator,
/// generic over the label type `L` (e.g. [`CfgPoint`] during lifting).
#[derive(Copy, Clone)]
enum Merge<L> {
    /// `OpSelectionMerge`, with the label where all selection cases reconverge.
    Selection(L),

    /// `OpLoopMerge`.
    Loop {
        /// The label just after the whole loop, i.e. the `break` target.
        loop_merge: L,

        /// A label that the back-edge block post-dominates, i.e. some point in
        /// the loop body where looping around is inevitable (modulo `break`ing
        /// out of the loop through a `do`-`while`-style conditional back-edge).
        ///
        /// SPIR-V calls this "the `continue` target", but unlike other aspects
        /// of SPIR-V "structured control-flow", there can be multiple valid
        /// choices (any that fit the post-dominator/"inevitability" definition).
        //
        // FIXME(eddyb) https://github.com/EmbarkStudios/spirt/pull/10 tried to
        // set this to the loop body entry, but that may not be valid if the loop
        // body actually diverges, because then the loop body exit will still be
        // post-dominating the back-edge *but* the loop body itself wouldn't have
        // any relationship between its entry and its *unreachable* exit.
        loop_continue: L,
    },
}
351
352impl<'a> NeedsIdsCollector<'a> {
353    fn alloc_ids<E>(
354        self,
355        mut alloc_id: impl FnMut() -> Result<spv::Id, E>,
356    ) -> Result<AllocatedIds<'a>, E> {
357        let Self {
358            cx,
359            module,
360            ext_inst_imports,
361            debug_strings,
362            globals,
363            data_inst_forms_seen: _,
364            global_vars_seen: _,
365            funcs,
366        } = self;
367
368        Ok(AllocatedIds {
369            ext_inst_imports: ext_inst_imports
370                .into_iter()
371                .map(|name| Ok((name, alloc_id()?)))
372                .collect::<Result<_, _>>()?,
373            debug_strings: debug_strings
374                .into_iter()
375                .map(|s| Ok((s, alloc_id()?)))
376                .collect::<Result<_, _>>()?,
377            globals: globals.into_iter().map(|g| Ok((g, alloc_id()?))).collect::<Result<_, _>>()?,
378            funcs: funcs
379                .into_iter()
380                .map(|func| {
381                    Ok((func, FuncLifting::from_func_decl(cx, &module.funcs[func], &mut alloc_id)?))
382                })
383                .collect::<Result<_, _>>()?,
384        })
385    }
386}
387
/// Helper type for deep traversal of the CFG (as a graph of [`CfgPoint`]s), which
/// tracks the necessary context for navigating a [`ControlRegion`]/[`ControlNode`].
#[derive(Copy, Clone)]
struct CfgCursor<'p, P = CfgPoint> {
    point: P,
    // The chain of enclosing regions/nodes (borrowed from the traversal's stack),
    // or `None` when `point` has no structured parent.
    parent: Option<&'p CfgCursor<'p, ControlParent>>,
}
395
/// An entity that structurally contains other CFG points: a `ControlRegion`
/// (containing `ControlNode`s) or a `ControlNode` (containing child regions).
enum ControlParent {
    Region(ControlRegion),
    ControlNode(ControlNode),
}
400
401impl<'p> FuncAt<'_, CfgCursor<'p>> {
402    /// Return the next [`CfgPoint`] (wrapped in [`CfgCursor`]) in a linear
403    /// chain within structured control-flow (i.e. no branching to child regions).
404    fn unique_successor(self) -> Option<CfgCursor<'p>> {
405        let cursor = self.position;
406        match cursor.point {
407            // Entering a `ControlRegion` enters its first `ControlNode` child,
408            // or exits the region right away (if it has no children).
409            CfgPoint::RegionEntry(region) => Some(CfgCursor {
410                point: match self.at(region).def().children.iter().first {
411                    Some(first_child) => CfgPoint::ControlNodeEntry(first_child),
412                    None => CfgPoint::RegionExit(region),
413                },
414                parent: cursor.parent,
415            }),
416
417            // Exiting a `ControlRegion` exits its parent `ControlNode`.
418            CfgPoint::RegionExit(_) => cursor.parent.map(|parent| match parent.point {
419                ControlParent::Region(_) => unreachable!(),
420                ControlParent::ControlNode(parent_control_node) => CfgCursor {
421                    point: CfgPoint::ControlNodeExit(parent_control_node),
422                    parent: parent.parent,
423                },
424            }),
425
426            // Entering a `ControlNode` depends entirely on the `ControlNodeKind`.
427            CfgPoint::ControlNodeEntry(control_node) => match self.at(control_node).def().kind {
428                ControlNodeKind::Block { .. } => Some(CfgCursor {
429                    point: CfgPoint::ControlNodeExit(control_node),
430                    parent: cursor.parent,
431                }),
432
433                ControlNodeKind::Select { .. }
434                | ControlNodeKind::Loop { .. }
435                | ControlNodeKind::ExitInvocation { .. } => None,
436            },
437
438            // Exiting a `ControlNode` chains to a sibling/parent.
439            CfgPoint::ControlNodeExit(control_node) => {
440                Some(match self.control_nodes[control_node].next_in_list() {
441                    // Enter the next sibling in the `ControlRegion`, if one exists.
442                    Some(next_control_node) => CfgCursor {
443                        point: CfgPoint::ControlNodeEntry(next_control_node),
444                        parent: cursor.parent,
445                    },
446
447                    // Exit the parent `ControlRegion`.
448                    None => {
449                        let parent = cursor.parent.unwrap();
450                        match cursor.parent.unwrap().point {
451                            ControlParent::Region(parent_region) => CfgCursor {
452                                point: CfgPoint::RegionExit(parent_region),
453                                parent: parent.parent,
454                            },
455                            ControlParent::ControlNode(_) => unreachable!(),
456                        }
457                    }
458                })
459            }
460        }
461    }
462}
463
464impl FuncAt<'_, ControlRegion> {
465    /// Traverse every [`CfgPoint`] (deeply) contained in this [`ControlRegion`],
466    /// in reverse post-order (RPO), with `f` receiving each [`CfgPoint`]
467    /// in turn (wrapped in [`CfgCursor`], for further traversal flexibility),
468    /// and being able to stop iteration by returning `Err`.
469    ///
470    /// RPO iteration over a CFG provides certain guarantees, most importantly
471    /// that dominators are visited before the entire subgraph they dominate.
472    fn rev_post_order_try_for_each<E>(
473        self,
474        mut f: impl FnMut(CfgCursor<'_>) -> Result<(), E>,
475    ) -> Result<(), E> {
476        self.rev_post_order_try_for_each_inner(&mut f, None)
477    }
478
479    fn rev_post_order_try_for_each_inner<E>(
480        self,
481        f: &mut impl FnMut(CfgCursor<'_>) -> Result<(), E>,
482        parent: Option<&CfgCursor<'_, ControlParent>>,
483    ) -> Result<(), E> {
484        let region = self.position;
485        f(CfgCursor { point: CfgPoint::RegionEntry(region), parent })?;
486        for func_at_control_node in self.at_children() {
487            func_at_control_node.rev_post_order_try_for_each_inner(f, &CfgCursor {
488                point: ControlParent::Region(region),
489                parent,
490            })?;
491        }
492        f(CfgCursor { point: CfgPoint::RegionExit(region), parent })
493    }
494}
495
496impl FuncAt<'_, ControlNode> {
497    fn rev_post_order_try_for_each_inner<E>(
498        self,
499        f: &mut impl FnMut(CfgCursor<'_>) -> Result<(), E>,
500        parent: &CfgCursor<'_, ControlParent>,
501    ) -> Result<(), E> {
502        let child_regions: &[_] = match &self.def().kind {
503            ControlNodeKind::Block { .. } | ControlNodeKind::ExitInvocation { .. } => &[],
504            ControlNodeKind::Select { cases, .. } => cases,
505            ControlNodeKind::Loop { body, .. } => slice::from_ref(body),
506        };
507
508        let control_node = self.position;
509        let parent = Some(parent);
510        f(CfgCursor { point: CfgPoint::ControlNodeEntry(control_node), parent })?;
511        for &region in child_regions {
512            self.at(region).rev_post_order_try_for_each_inner(
513                f,
514                Some(&CfgCursor { point: ControlParent::ControlNode(control_node), parent }),
515            )?;
516        }
517        f(CfgCursor { point: CfgPoint::ControlNodeExit(control_node), parent })
518    }
519}
520
521impl<'a> FuncLifting<'a> {
522    fn from_func_decl<E>(
523        cx: &Context,
524        func_decl: &'a FuncDecl,
525        mut alloc_id: impl FnMut() -> Result<spv::Id, E>,
526    ) -> Result<Self, E> {
527        let wk = &spec::Spec::get().well_known;
528
529        let func_id = alloc_id()?;
530        let param_ids = func_decl.params.iter().map(|_| alloc_id()).collect::<Result<_, _>>()?;
531
532        let func_def_body = match &func_decl.def {
533            DeclDef::Imported(_) => {
534                return Ok(Self {
535                    func_id,
536                    param_ids,
537                    region_inputs_source: Default::default(),
538                    data_inst_output_ids: Default::default(),
539                    label_ids: Default::default(),
540                    blocks: Default::default(),
541                });
542            }
543            DeclDef::Present(def) => def,
544        };
545
546        let mut region_inputs_source = FxHashMap::default();
547        region_inputs_source.insert(func_def_body.body, RegionInputsSource::FuncParams);
548
549        // Create a SPIR-V block for every CFG point needing one.
550        let mut blocks = FxIndexMap::default();
551        let mut visit_cfg_point = |point_cursor: CfgCursor<'_>| {
552            let point = point_cursor.point;
553
554            let phis = match point {
555                CfgPoint::RegionEntry(region) => {
556                    if region_inputs_source.contains_key(&region) {
557                        // Region inputs handled by the parent of the region.
558                        SmallVec::new()
559                    } else {
560                        func_def_body
561                            .at(region)
562                            .def()
563                            .inputs
564                            .iter()
565                            .map(|&ControlRegionInputDecl { attrs, ty }| {
566                                Ok(Phi {
567                                    attrs,
568                                    ty,
569
570                                    result_id: alloc_id()?,
571                                    cases: FxIndexMap::default(),
572                                    default_value: None,
573                                })
574                            })
575                            .collect::<Result<_, _>>()?
576                    }
577                }
578                CfgPoint::RegionExit(_) => SmallVec::new(),
579
580                CfgPoint::ControlNodeEntry(control_node) => {
581                    match &func_def_body.at(control_node).def().kind {
582                        // The backedge of a SPIR-V structured loop points to
583                        // the "loop header", i.e. the `Entry` of the `Loop`,
584                        // so that's where `body` `inputs` phis have to go.
585                        ControlNodeKind::Loop { initial_inputs, body, .. } => {
586                            let loop_body_def = func_def_body.at(*body).def();
587                            let loop_body_inputs = &loop_body_def.inputs;
588
589                            if !loop_body_inputs.is_empty() {
590                                region_inputs_source.insert(
591                                    *body,
592                                    RegionInputsSource::LoopHeaderPhis(control_node),
593                                );
594                            }
595
596                            loop_body_inputs
597                                .iter()
598                                .enumerate()
599                                .map(|(i, &ControlRegionInputDecl { attrs, ty })| {
600                                    Ok(Phi {
601                                        attrs,
602                                        ty,
603
604                                        result_id: alloc_id()?,
605                                        cases: FxIndexMap::default(),
606                                        default_value: Some(initial_inputs[i]),
607                                    })
608                                })
609                                .collect::<Result<_, _>>()?
610                        }
611                        _ => SmallVec::new(),
612                    }
613                }
614                CfgPoint::ControlNodeExit(control_node) => func_def_body
615                    .at(control_node)
616                    .def()
617                    .outputs
618                    .iter()
619                    .map(|&ControlNodeOutputDecl { attrs, ty }| {
620                        Ok(Phi {
621                            attrs,
622                            ty,
623
624                            result_id: alloc_id()?,
625                            cases: FxIndexMap::default(),
626                            default_value: None,
627                        })
628                    })
629                    .collect::<Result<_, _>>()?,
630            };
631
632            let insts = match point {
633                CfgPoint::ControlNodeEntry(control_node) => {
634                    match func_def_body.at(control_node).def().kind {
635                        ControlNodeKind::Block { insts } => [insts].into_iter().collect(),
636                        _ => SmallVec::new(),
637                    }
638                }
639                _ => SmallVec::new(),
640            };
641
642            // Get the terminator, or reconstruct it from structured control-flow.
643            let terminator = match (point, func_def_body.at(point_cursor).unique_successor()) {
644                // Exiting a `ControlRegion` w/o a structured parent.
645                (CfgPoint::RegionExit(region), None) => {
646                    let unstructured_terminator = func_def_body
647                        .unstructured_cfg
648                        .as_ref()
649                        .and_then(|cfg| cfg.control_inst_on_exit_from.get(region));
650                    if let Some(terminator) = unstructured_terminator {
651                        let cfg::ControlInst { attrs, kind, inputs, targets, target_inputs } =
652                            terminator;
653                        Terminator {
654                            attrs: *attrs,
655                            kind: Cow::Borrowed(kind),
656                            // FIXME(eddyb) borrow these whenever possible.
657                            inputs: inputs.clone(),
658                            targets: targets
659                                .iter()
660                                .map(|&target| CfgPoint::RegionEntry(target))
661                                .collect(),
662                            target_phi_values: target_inputs
663                                .iter()
664                                .map(|(&target, target_inputs)| {
665                                    (CfgPoint::RegionEntry(target), &target_inputs[..])
666                                })
667                                .collect(),
668                            merge: None,
669                        }
670                    } else {
671                        // Structured return out of the function body.
672                        assert!(region == func_def_body.body);
673                        Terminator {
674                            attrs: AttrSet::default(),
675                            kind: Cow::Owned(cfg::ControlInstKind::Return),
676                            inputs: func_def_body.at_body().def().outputs.clone(),
677                            targets: [].into_iter().collect(),
678                            target_phi_values: FxIndexMap::default(),
679                            merge: None,
680                        }
681                    }
682                }
683
684                // Entering a `ControlNode` with child `ControlRegion`s (or diverging).
685                (CfgPoint::ControlNodeEntry(control_node), None) => {
686                    let control_node_def = func_def_body.at(control_node).def();
687                    match &control_node_def.kind {
688                        ControlNodeKind::Block { .. } => {
689                            unreachable!()
690                        }
691
692                        ControlNodeKind::Select { kind, scrutinee, cases } => Terminator {
693                            attrs: AttrSet::default(),
694                            kind: Cow::Owned(cfg::ControlInstKind::SelectBranch(kind.clone())),
695                            inputs: [*scrutinee].into_iter().collect(),
696                            targets: cases
697                                .iter()
698                                .map(|&case| CfgPoint::RegionEntry(case))
699                                .collect(),
700                            target_phi_values: FxIndexMap::default(),
701                            merge: Some(Merge::Selection(CfgPoint::ControlNodeExit(control_node))),
702                        },
703
704                        ControlNodeKind::Loop { initial_inputs: _, body, repeat_condition: _ } => {
705                            Terminator {
706                                attrs: AttrSet::default(),
707                                kind: Cow::Owned(cfg::ControlInstKind::Branch),
708                                inputs: [].into_iter().collect(),
709                                targets: [CfgPoint::RegionEntry(*body)].into_iter().collect(),
710                                target_phi_values: FxIndexMap::default(),
711                                merge: Some(Merge::Loop {
712                                    loop_merge: CfgPoint::ControlNodeExit(control_node),
713                                    // NOTE(eddyb) see the note on `Merge::Loop`'s
714                                    // `loop_continue` field - in particular, for
715                                    // SPIR-T loops, we *could* pick any point
716                                    // before/after/between `body`'s `children`
717                                    // and it should be valid *but* that had to be
718                                    // reverted because it's only true in the absence
719                                    // of divergence within the loop body itself!
720                                    loop_continue: CfgPoint::RegionExit(*body),
721                                }),
722                            }
723                        }
724
725                        ControlNodeKind::ExitInvocation { kind, inputs } => Terminator {
726                            attrs: AttrSet::default(),
727                            kind: Cow::Owned(cfg::ControlInstKind::ExitInvocation(kind.clone())),
728                            inputs: inputs.clone(),
729                            targets: [].into_iter().collect(),
730                            target_phi_values: FxIndexMap::default(),
731                            merge: None,
732                        },
733                    }
734                }
735
736                // Exiting a `ControlRegion` to the parent `ControlNode`.
737                (CfgPoint::RegionExit(region), Some(parent_exit_cursor)) => {
738                    let region_outputs = Some(&func_def_body.at(region).def().outputs[..])
739                        .filter(|outputs| !outputs.is_empty());
740
741                    let parent_exit = parent_exit_cursor.point;
742                    let parent_node = match parent_exit {
743                        CfgPoint::ControlNodeExit(parent_node) => parent_node,
744                        _ => unreachable!(),
745                    };
746
747                    match func_def_body.at(parent_node).def().kind {
748                        ControlNodeKind::Block { .. } | ControlNodeKind::ExitInvocation { .. } => {
749                            unreachable!()
750                        }
751
752                        ControlNodeKind::Select { .. } => Terminator {
753                            attrs: AttrSet::default(),
754                            kind: Cow::Owned(cfg::ControlInstKind::Branch),
755                            inputs: [].into_iter().collect(),
756                            targets: [parent_exit].into_iter().collect(),
757                            target_phi_values: region_outputs
758                                .map(|outputs| (parent_exit, outputs))
759                                .into_iter()
760                                .collect(),
761                            merge: None,
762                        },
763
764                        ControlNodeKind::Loop { initial_inputs: _, body: _, repeat_condition } => {
765                            let backedge = CfgPoint::ControlNodeEntry(parent_node);
766                            let target_phi_values = region_outputs
767                                .map(|outputs| (backedge, outputs))
768                                .into_iter()
769                                .collect();
770
771                            let is_infinite_loop = match repeat_condition {
772                                Value::Const(cond) => match &cx[cond].kind {
773                                    ConstKind::SpvInst { spv_inst_and_const_inputs } => {
774                                        let (spv_inst, _const_inputs) =
775                                            &**spv_inst_and_const_inputs;
776                                        spv_inst.opcode == wk.OpConstantTrue
777                                    }
778                                    _ => false,
779                                },
780
781                                _ => false,
782                            };
783                            if is_infinite_loop {
784                                Terminator {
785                                    attrs: AttrSet::default(),
786                                    kind: Cow::Owned(cfg::ControlInstKind::Branch),
787                                    inputs: [].into_iter().collect(),
788                                    targets: [backedge].into_iter().collect(),
789                                    target_phi_values,
790                                    merge: None,
791                                }
792                            } else {
793                                Terminator {
794                                    attrs: AttrSet::default(),
795                                    kind: Cow::Owned(cfg::ControlInstKind::SelectBranch(
796                                        SelectionKind::BoolCond,
797                                    )),
798                                    inputs: [repeat_condition].into_iter().collect(),
799                                    targets: [backedge, parent_exit].into_iter().collect(),
800                                    target_phi_values,
801                                    merge: None,
802                                }
803                            }
804                        }
805                    }
806                }
807
808                // Siblings in the same `ControlRegion` (including the
809                // implied edge from a `Block`'s `Entry` to its `Exit`).
810                (_, Some(succ_cursor)) => Terminator {
811                    attrs: AttrSet::default(),
812                    kind: Cow::Owned(cfg::ControlInstKind::Branch),
813                    inputs: [].into_iter().collect(),
814                    targets: [succ_cursor.point].into_iter().collect(),
815                    target_phi_values: FxIndexMap::default(),
816                    merge: None,
817                },
818
819                // Impossible cases, they always return `(_, Some(_))`.
820                (CfgPoint::RegionEntry(_) | CfgPoint::ControlNodeExit(_), None) => {
821                    unreachable!()
822                }
823            };
824
825            blocks.insert(point, BlockLifting { phis, insts, terminator });
826
827            Ok(())
828        };
829        match &func_def_body.unstructured_cfg {
830            None => func_def_body.at_body().rev_post_order_try_for_each(visit_cfg_point)?,
831            Some(cfg) => {
832                for region in cfg.rev_post_order(func_def_body) {
833                    func_def_body.at(region).rev_post_order_try_for_each(&mut visit_cfg_point)?;
834                }
835            }
836        }
837
838        // Count the number of "uses" of each block (each incoming edge, plus
839        // `1` for the entry block), to help determine which blocks are part
840        // of a linear branch chain (and potentially fusable), later on.
841        //
842        // FIXME(eddyb) use `EntityOrientedDenseMap` here.
843        let mut use_counts = FxHashMap::<CfgPoint, usize>::default();
844        use_counts.reserve(blocks.len());
845        let all_edges = blocks.first().map(|(&entry_point, _)| entry_point).into_iter().chain(
846            blocks.values().flat_map(|block| {
847                block
848                    .terminator
849                    .merge
850                    .iter()
851                    .flat_map(|merge| {
852                        let (a, b) = match merge {
853                            Merge::Selection(a) => (a, None),
854                            Merge::Loop { loop_merge: a, loop_continue: b } => (a, Some(b)),
855                        };
856                        [a].into_iter().chain(b)
857                    })
858                    .chain(&block.terminator.targets)
859                    .copied()
860            }),
861        );
862        for target in all_edges {
863            *use_counts.entry(target).or_default() += 1;
864        }
865
866        // Fuse chains of linear branches, when there is no information being
867        // lost by the fusion. This is done in reverse order, so that in e.g.
868        // `a -> b -> c`, `b -> c` is fused first, then when the iteration
869        // reaches `a`, it sees `a -> bc` and can further fuse that into one
870        // `abc` block, without knowing about `b` and `c` themselves
871        // (this is possible because RPO will always output `[a, b, c]`, when
872        // `b` and `c` only have one predecessor each).
873        //
874        // FIXME(eddyb) while this could theoretically fuse certain kinds of
875        // merge blocks (mostly loop bodies) into their unique precedessor, that
876        // would require adjusting the `Merge` that points to them.
877        //
878        // HACK(eddyb) this takes advantage of `blocks` being an `IndexMap`,
879        // to iterate at the same time as mutating other entries.
880        for block_idx in (0..blocks.len()).rev() {
881            let BlockLifting { terminator: original_terminator, .. } = &blocks[block_idx];
882
883            let is_trivial_branch = {
884                let Terminator { attrs, kind, inputs, targets, target_phi_values, merge } =
885                    original_terminator;
886
887                *attrs == AttrSet::default()
888                    && matches!(**kind, cfg::ControlInstKind::Branch)
889                    && inputs.is_empty()
890                    && targets.len() == 1
891                    && target_phi_values.is_empty()
892                    && merge.is_none()
893            };
894
895            if is_trivial_branch {
896                let target = original_terminator.targets[0];
897                let target_use_count = use_counts.get_mut(&target).unwrap();
898
899                if *target_use_count == 1 {
900                    let BlockLifting {
901                        phis: ref target_phis,
902                        insts: ref mut extra_insts,
903                        terminator: ref mut new_terminator,
904                    } = blocks[&target];
905
906                    // FIXME(eddyb) check for block-level attributes, once/if
907                    // they start being tracked.
908                    if target_phis.is_empty() {
909                        let extra_insts = mem::take(extra_insts);
910                        let new_terminator = mem::replace(new_terminator, Terminator {
911                            attrs: Default::default(),
912                            kind: Cow::Owned(cfg::ControlInstKind::Unreachable),
913                            inputs: Default::default(),
914                            targets: Default::default(),
915                            target_phi_values: Default::default(),
916                            merge: None,
917                        });
918                        *target_use_count = 0;
919
920                        let combined_block = &mut blocks[block_idx];
921                        combined_block.insts.extend(extra_insts);
922                        combined_block.terminator = new_terminator;
923                    }
924                }
925            }
926        }
927
928        // Remove now-unused blocks.
929        blocks.retain(|point, _| use_counts.get(point).is_some_and(|&count| count > 0));
930
931        // Collect `OpPhi`s from other blocks' edges into each block.
932        //
933        // HACK(eddyb) this takes advantage of `blocks` being an `IndexMap`,
934        // to iterate at the same time as mutating other entries.
935        for source_block_idx in 0..blocks.len() {
936            let (&source_point, source_block) = blocks.get_index(source_block_idx).unwrap();
937            let targets = source_block.terminator.targets.clone();
938
939            for target in targets {
940                let source_values = {
941                    let (_, source_block) = blocks.get_index(source_block_idx).unwrap();
942                    source_block.terminator.target_phi_values.get(&target).copied()
943                };
944                let target_block = blocks.get_mut(&target).unwrap();
945                for (i, target_phi) in target_block.phis.iter_mut().enumerate() {
946                    use indexmap::map::Entry;
947
948                    let source_value =
949                        source_values.map(|values| values[i]).or(target_phi.default_value).unwrap();
950                    match target_phi.cases.entry(source_point) {
951                        Entry::Vacant(entry) => {
952                            entry.insert(source_value);
953                        }
954
955                        // NOTE(eddyb) the only reason duplicates are allowed,
956                        // is that `targets` may itself contain the same target
957                        // multiple times (which would result in the same value).
958                        Entry::Occupied(entry) => {
959                            assert!(*entry.get() == source_value);
960                        }
961                    }
962                }
963            }
964        }
965
966        let all_insts_with_output = blocks
967            .values()
968            .flat_map(|block| block.insts.iter().copied())
969            .flat_map(|insts| func_def_body.at(insts))
970            .filter(|&func_at_inst| cx[func_at_inst.def().form].output_type.is_some())
971            .map(|func_at_inst| func_at_inst.position);
972
973        Ok(Self {
974            func_id,
975            param_ids,
976            region_inputs_source,
977            data_inst_output_ids: all_insts_with_output
978                .map(|inst| Ok((inst, alloc_id()?)))
979                .collect::<Result<_, _>>()?,
980            label_ids: blocks
981                .keys()
982                .map(|&point| Ok((point, alloc_id()?)))
983                .collect::<Result<_, _>>()?,
984            blocks,
985        })
986    }
987}
988
989/// "Maybe-decorated "lazy" SPIR-V instruction, allowing separately emitting
990/// decorations from attributes, and the instruction itself, without eagerly
991/// allocating all the instructions.
992#[derive(Copy, Clone)]
993enum LazyInst<'a, 'b> {
994    Global(Global),
995    OpFunction {
996        func_id: spv::Id,
997        func_decl: &'a FuncDecl,
998    },
999    OpFunctionParameter {
1000        param_id: spv::Id,
1001        param: &'a FuncParam,
1002    },
1003    OpLabel {
1004        label_id: spv::Id,
1005    },
1006    OpPhi {
1007        parent_func: &'b FuncLifting<'a>,
1008        phi: &'b Phi,
1009    },
1010    DataInst {
1011        parent_func: &'b FuncLifting<'a>,
1012        result_id: Option<spv::Id>,
1013        data_inst_def: &'a DataInstDef,
1014    },
1015    Merge(Merge<spv::Id>),
1016    Terminator {
1017        parent_func: &'b FuncLifting<'a>,
1018        terminator: &'b Terminator<'a>,
1019    },
1020    OpFunctionEnd,
1021}
1022
impl LazyInst<'_, '_> {
    /// Returns three things for this instruction:
    /// - its SPIR-V result ID, if it has one;
    /// - the attributes attached to it, which callers emit separately from
    ///   the instruction itself;
    /// - the `Import` it was declared with, if any.
    ///
    /// Only global variables (via `ConstKind::PtrToGlobalVar`) and
    /// `OpFunction`s can report an `Import`. Every other case returns `None`
    /// for it.
    fn result_id_attrs_and_import(
        self,
        module: &Module,
        ids: &AllocatedIds<'_>,
    ) -> (Option<spv::Id>, AttrSet, Option<Import>) {
        let cx = module.cx_ref();

        // Several arms produce identical tuples (e.g. `Merge`/`OpFunctionEnd`),
        // but are kept separate to mirror the variant list.
        #[allow(clippy::match_same_arms)]
        match self {
            Self::Global(global) => {
                let (attrs, import) = match global {
                    Global::Type(ty) => (cx[ty].attrs, None),
                    Global::Const(ct) => {
                        let ct_def = &cx[ct];
                        match ct_def.kind {
                            // The pointer constant is lifted as the `OpVariable`
                            // itself (see `to_inst_and_attrs`). It therefore takes
                            // the global variable's attributes and import, not
                            // the constant's.
                            ConstKind::PtrToGlobalVar(gv) => {
                                let gv_decl = &module.global_vars[gv];
                                let import = match gv_decl.def {
                                    DeclDef::Imported(import) => Some(import),
                                    DeclDef::Present(_) => None,
                                };
                                (gv_decl.attrs, import)
                            }
                            ConstKind::SpvInst { .. } => (ct_def.attrs, None),

                            // Not inserted into `globals` while visiting.
                            ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(),
                        }
                    }
                };
                (Some(ids.globals[&global]), attrs, import)
            }
            Self::OpFunction { func_id, func_decl } => {
                let import = match func_decl.def {
                    DeclDef::Imported(import) => Some(import),
                    DeclDef::Present(_) => None,
                };
                (Some(func_id), func_decl.attrs, import)
            }
            Self::OpFunctionParameter { param_id, param } => (Some(param_id), param.attrs, None),
            Self::OpLabel { label_id } => (Some(label_id), AttrSet::default(), None),
            Self::OpPhi { parent_func: _, phi } => (Some(phi.result_id), phi.attrs, None),
            Self::DataInst { parent_func: _, result_id, data_inst_def } => {
                (result_id, data_inst_def.attrs, None)
            }
            Self::Merge(_) => (None, AttrSet::default(), None),
            Self::Terminator { parent_func: _, terminator } => (None, terminator.attrs, None),
            Self::OpFunctionEnd => (None, AttrSet::default(), None),
        }
    }

    /// Builds the complete SPIR-V instruction, with opcode, immediates,
    /// result type/ID and ID operands all resolved through `ids` and, for
    /// function-local instructions, the parent `FuncLifting`.
    ///
    /// Returns it together with the attributes from
    /// `result_id_attrs_and_import`. The `Import` part is ignored here.
    fn to_inst_and_attrs(
        self,
        module: &Module,
        ids: &AllocatedIds<'_>,
    ) -> (spv::InstWithIds, AttrSet) {
        let wk = &spec::Spec::get().well_known;
        let cx = module.cx_ref();

        // Maps a SPIR-T `Value`, used inside `parent_func`, to its SPIR-V ID.
        let value_to_id = |parent_func: &FuncLifting<'_>, v| match v {
            // String literals for ext. insts. use the `debug_strings` IDs;
            // all other constants are globals.
            Value::Const(ct) => match cx[ct].kind {
                ConstKind::SpvStringLiteralForExtInst(s) => ids.debug_strings[&cx[s]],

                _ => ids.globals[&Global::Const(ct)],
            },
            // A region input resolves to one of three things, depending on
            // `region_inputs_source`:
            // - a function parameter;
            // - an `OpPhi` in the loop header block;
            // - by default, an `OpPhi` in the region's own entry block.
            Value::ControlRegionInput { region, input_idx } => {
                let input_idx = usize::try_from(input_idx).unwrap();
                match parent_func.region_inputs_source.get(&region) {
                    Some(RegionInputsSource::FuncParams) => parent_func.param_ids[input_idx],
                    Some(&RegionInputsSource::LoopHeaderPhis(loop_node)) => {
                        parent_func.blocks[&CfgPoint::ControlNodeEntry(loop_node)].phis[input_idx]
                            .result_id
                    }
                    None => {
                        parent_func.blocks[&CfgPoint::RegionEntry(region)].phis[input_idx].result_id
                    }
                }
            }
            // Control node outputs are `OpPhi`s in the node's exit block.
            Value::ControlNodeOutput { control_node, output_idx } => {
                parent_func.blocks[&CfgPoint::ControlNodeExit(control_node)].phis
                    [usize::try_from(output_idx).unwrap()]
                .result_id
            }
            Value::DataInstOutput(inst) => parent_func.data_inst_output_ids[&inst],
        };

        let (result_id, attrs, _) = self.result_id_attrs_and_import(module, ids);
        let inst = match self {
            Self::Global(global) => match global {
                // Types have no result type; their ID operands are other globals.
                Global::Type(ty) => match &cx[ty].kind {
                    TypeKind::SpvInst { spv_inst, type_and_const_inputs } => spv::InstWithIds {
                        without_ids: spv_inst.clone(),
                        result_type_id: None,
                        result_id,
                        ids: type_and_const_inputs
                            .iter()
                            .map(|&ty_or_ct| {
                                ids.globals[&match ty_or_ct {
                                    TypeOrConst::Type(ty) => Global::Type(ty),
                                    TypeOrConst::Const(ct) => Global::Const(ct),
                                }]
                            })
                            .collect(),
                    },

                    // Not inserted into `globals` while visiting.
                    TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => unreachable!(),
                },
                Global::Const(ct) => {
                    let ct_def = &cx[ct];
                    match &ct_def.kind {
                        // Lifted as an `OpVariable`. The storage class is an
                        // immediate, and the initializer (if present) is the
                        // only ID operand.
                        &ConstKind::PtrToGlobalVar(gv) => {
                            assert!(ct_def.attrs == AttrSet::default());

                            let gv_decl = &module.global_vars[gv];

                            assert!(ct_def.ty == gv_decl.type_of_ptr_to);

                            let storage_class = match gv_decl.addr_space {
                                AddrSpace::Handles => {
                                    unreachable!(
                                        "`AddrSpace::Handles` should be legalized away before lifting"
                                    );
                                }
                                AddrSpace::SpvStorageClass(sc) => {
                                    spv::Imm::Short(wk.StorageClass, sc)
                                }
                            };
                            let initializer = match gv_decl.def {
                                DeclDef::Imported(_) => None,
                                DeclDef::Present(GlobalVarDefBody { initializer }) => initializer
                                    .map(|initializer| ids.globals[&Global::Const(initializer)]),
                            };
                            spv::InstWithIds {
                                without_ids: spv::Inst {
                                    opcode: wk.OpVariable,
                                    imms: iter::once(storage_class).collect(),
                                },
                                result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]),
                                result_id,
                                ids: initializer.into_iter().collect(),
                            }
                        }

                        ConstKind::SpvInst { spv_inst_and_const_inputs } => {
                            let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
                            spv::InstWithIds {
                                without_ids: spv_inst.clone(),
                                result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]),
                                result_id,
                                ids: const_inputs
                                    .iter()
                                    .map(|&ct| ids.globals[&Global::Const(ct)])
                                    .collect(),
                            }
                        }

                        // Not inserted into `globals` while visiting.
                        ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(),
                    }
                }
            },
            Self::OpFunction { func_id: _, func_decl } => {
                // The `FunctionControl` mask is kept as an attribute on the
                // function; 0 is used if no such attribute is present.
                //
                // FIXME(eddyb) make this less of a search and more of a
                // lookup by splitting attrs into key and value parts.
                let func_ctrl = cx[attrs]
                    .attrs
                    .iter()
                    .find_map(|attr| match *attr {
                        Attr::SpvBitflagsOperand(spv::Imm::Short(kind, word))
                            if kind == wk.FunctionControl =>
                        {
                            Some(word)
                        }
                        _ => None,
                    })
                    .unwrap_or(0);

                spv::InstWithIds {
                    without_ids: spv::Inst {
                        opcode: wk.OpFunction,
                        imms: iter::once(spv::Imm::Short(wk.FunctionControl, func_ctrl)).collect(),
                    },
                    result_type_id: Some(ids.globals[&Global::Type(func_decl.ret_type)]),
                    result_id,
                    ids: iter::once(ids.globals[&Global::Type(func_decl.spv_func_type(cx))])
                        .collect(),
                }
            }
            Self::OpFunctionParameter { param_id: _, param } => spv::InstWithIds {
                without_ids: wk.OpFunctionParameter.into(),
                result_type_id: Some(ids.globals[&Global::Type(param.ty)]),
                result_id,
                ids: [].into_iter().collect(),
            },
            Self::OpLabel { label_id: _ } => spv::InstWithIds {
                without_ids: wk.OpLabel.into(),
                result_type_id: None,
                result_id,
                ids: [].into_iter().collect(),
            },
            // `OpPhi` operands are `(value, predecessor label)` pairs, one
            // pair per entry in `phi.cases`.
            Self::OpPhi { parent_func, phi } => spv::InstWithIds {
                without_ids: wk.OpPhi.into(),
                result_type_id: Some(ids.globals[&Global::Type(phi.ty)]),
                result_id: Some(phi.result_id),
                ids: phi
                    .cases
                    .iter()
                    .flat_map(|(&source_point, &v)| {
                        [value_to_id(parent_func, v), parent_func.label_ids[&source_point]]
                    })
                    .collect(),
            },
            Self::DataInst { parent_func, result_id: _, data_inst_def } => {
                let DataInstFormDef { kind, output_type } = &cx[data_inst_def.form];
                // Some kinds need an extra ID operand in front of the regular
                // inputs:
                // - the callee, for `OpFunctionCall`;
                // - the ext. inst. set import, for `OpExtInst`.
                let (inst, extra_initial_id_operand) = match kind {
                    // Disallowed while visiting.
                    DataInstKind::QPtr(_) => unreachable!(),

                    &DataInstKind::FuncCall(callee) => {
                        (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id))
                    }
                    DataInstKind::SpvInst(inst) => (inst.clone(), None),
                    &DataInstKind::SpvExtInst { ext_set, inst } => (
                        spv::Inst {
                            opcode: wk.OpExtInst,
                            imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst))
                                .collect(),
                        },
                        Some(ids.ext_inst_imports[&cx[ext_set]]),
                    ),
                };
                spv::InstWithIds {
                    without_ids: inst,
                    result_type_id: output_type.map(|ty| ids.globals[&Global::Type(ty)]),
                    result_id,
                    ids: extra_initial_id_operand
                        .into_iter()
                        .chain(data_inst_def.inputs.iter().map(|&v| value_to_id(parent_func, v)))
                        .collect(),
                }
            }
            // Merge instructions always use a control mask of 0 (no hints).
            Self::Merge(Merge::Selection(merge_label_id)) => spv::InstWithIds {
                without_ids: spv::Inst {
                    opcode: wk.OpSelectionMerge,
                    imms: [spv::Imm::Short(wk.SelectionControl, 0)].into_iter().collect(),
                },
                result_type_id: None,
                result_id: None,
                ids: [merge_label_id].into_iter().collect(),
            },
            Self::Merge(Merge::Loop {
                loop_merge: merge_label_id,
                loop_continue: continue_label_id,
            }) => spv::InstWithIds {
                without_ids: spv::Inst {
                    opcode: wk.OpLoopMerge,
                    imms: [spv::Imm::Short(wk.LoopControl, 0)].into_iter().collect(),
                },
                result_type_id: None,
                result_id: None,
                ids: [merge_label_id, continue_label_id].into_iter().collect(),
            },
            Self::Terminator { parent_func, terminator } => {
                let inst = match &*terminator.kind {
                    cfg::ControlInstKind::Unreachable => wk.OpUnreachable.into(),
                    // A `Return` with no inputs becomes `OpReturn`; one with a
                    // value becomes `OpReturnValue`.
                    cfg::ControlInstKind::Return => {
                        if terminator.inputs.is_empty() {
                            wk.OpReturn.into()
                        } else {
                            wk.OpReturnValue.into()
                        }
                    }
                    cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(
                        inst,
                    )) => inst.clone(),

                    cfg::ControlInstKind::Branch => wk.OpBranch.into(),

                    cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond) => {
                        wk.OpBranchConditional.into()
                    }
                    cfg::ControlInstKind::SelectBranch(SelectionKind::SpvInst(inst)) => {
                        inst.clone()
                    }
                };
                // ID operands: the terminator's inputs first, then the labels
                // of its targets, in order.
                spv::InstWithIds {
                    without_ids: inst,
                    result_type_id: None,
                    result_id: None,
                    ids: terminator
                        .inputs
                        .iter()
                        .map(|&v| value_to_id(parent_func, v))
                        .chain(
                            terminator.targets.iter().map(|&target| parent_func.label_ids[&target]),
                        )
                        .collect(),
                }
            }
            Self::OpFunctionEnd => spv::InstWithIds {
                without_ids: wk.OpFunctionEnd.into(),
                result_type_id: None,
                result_id: None,
                ids: [].into_iter().collect(),
            },
        };
        (inst, attrs)
    }
}
1334
1335impl Module {
1336    pub fn lift_to_spv_file(&self, path: impl AsRef<Path>) -> io::Result<()> {
1337        self.lift_to_spv_module_emitter()?.write_to_spv_file(path)
1338    }
1339
1340    pub fn lift_to_spv_module_emitter(&self) -> io::Result<spv::write::ModuleEmitter> {
1341        let spv_spec = spec::Spec::get();
1342        let wk = &spv_spec.well_known;
1343
1344        let cx = self.cx();
1345        let (dialect, debug_info) = match (&self.dialect, &self.debug_info) {
1346            (ModuleDialect::Spv(dialect), ModuleDebugInfo::Spv(debug_info)) => {
1347                (dialect, debug_info)
1348            }
1349
1350            // FIXME(eddyb) support by computing some valid "minimum viable"
1351            // `spv::Dialect`, or by taking it as additional input.
1352            #[allow(unreachable_patterns)]
1353            _ => {
1354                return Err(io::Error::new(io::ErrorKind::InvalidData, "not a SPIR-V module"));
1355            }
1356        };
1357
1358        // Collect uses scattered throughout the module, that require def IDs.
1359        let mut needs_ids_collector = NeedsIdsCollector {
1360            cx: &cx,
1361            module: self,
1362            ext_inst_imports: BTreeSet::new(),
1363            debug_strings: BTreeSet::new(),
1364            globals: FxIndexSet::default(),
1365            data_inst_forms_seen: FxIndexSet::default(),
1366            global_vars_seen: FxIndexSet::default(),
1367            funcs: FxIndexSet::default(),
1368        };
1369        needs_ids_collector.visit_module(self);
1370
1371        // Because `GlobalVar`s are given IDs by the `Const`s that point to them
1372        // (i.e. `ConstKind::PtrToGlobalVar`), any `GlobalVar`s in other positions
1373        // require extra care to ensure the ID-giving `Const` is visited.
1374        let global_var_to_id_giving_global = |gv| {
1375            let type_of_ptr_to_global_var = self.global_vars[gv].type_of_ptr_to;
1376            let ptr_to_global_var = cx.intern(ConstDef {
1377                attrs: AttrSet::default(),
1378                ty: type_of_ptr_to_global_var,
1379                kind: ConstKind::PtrToGlobalVar(gv),
1380            });
1381            Global::Const(ptr_to_global_var)
1382        };
1383        for &gv in &needs_ids_collector.global_vars_seen {
1384            needs_ids_collector.globals.insert(global_var_to_id_giving_global(gv));
1385        }
1386
1387        // IDs can be allocated once we have the full sets needing them, whether
1388        // sorted by contents, or ordered by the first occurence in the module.
1389        let mut id_bound = NonZeroU32::MIN;
1390        let ids = needs_ids_collector.alloc_ids(|| {
1391            let id = id_bound;
1392
1393            match id_bound.checked_add(1) {
1394                Some(new_bound) => {
1395                    id_bound = new_bound;
1396                    Ok(id)
1397                }
1398                None => Err(io::Error::new(
1399                    io::ErrorKind::InvalidData,
1400                    "ID bound of SPIR-V module doesn't fit in 32 bits",
1401                )),
1402            }
1403        })?;
1404
1405        // HACK(eddyb) allow `move` closures below to reference `cx` or `ids`
1406        // without causing unwanted moves out of them.
1407        let (cx, ids) = (&*cx, &ids);
1408
1409        let global_and_func_insts = ids.globals.keys().copied().map(LazyInst::Global).chain(
1410            ids.funcs.iter().flat_map(|(&func, func_lifting)| {
1411                let func_decl = &self.funcs[func];
1412                let func_def_body = match &func_decl.def {
1413                    DeclDef::Imported(_) => None,
1414                    DeclDef::Present(def) => Some(def),
1415                };
1416
1417                iter::once(LazyInst::OpFunction { func_id: func_lifting.func_id, func_decl })
1418                    .chain(func_lifting.param_ids.iter().zip(&func_decl.params).map(
1419                        |(&param_id, param)| LazyInst::OpFunctionParameter { param_id, param },
1420                    ))
1421                    .chain(func_lifting.blocks.iter().flat_map(move |(point, block)| {
1422                        let BlockLifting { phis, insts, terminator } = block;
1423
1424                        iter::once(LazyInst::OpLabel { label_id: func_lifting.label_ids[point] })
1425                            .chain(
1426                                phis.iter()
1427                                    .map(|phi| LazyInst::OpPhi { parent_func: func_lifting, phi }),
1428                            )
1429                            .chain(
1430                                insts
1431                                    .iter()
1432                                    .copied()
1433                                    .flat_map(move |insts| func_def_body.unwrap().at(insts))
1434                                    .map(move |func_at_inst| {
1435                                        let data_inst_def = func_at_inst.def();
1436                                        LazyInst::DataInst {
1437                                            parent_func: func_lifting,
1438                                            result_id: cx[data_inst_def.form].output_type.map(
1439                                                |_| {
1440                                                    func_lifting.data_inst_output_ids
1441                                                        [&func_at_inst.position]
1442                                                },
1443                                            ),
1444                                            data_inst_def,
1445                                        }
1446                                    }),
1447                            )
1448                            .chain(terminator.merge.map(|merge| {
1449                                LazyInst::Merge(match merge {
1450                                    Merge::Selection(merge) => {
1451                                        Merge::Selection(func_lifting.label_ids[&merge])
1452                                    }
1453                                    Merge::Loop { loop_merge, loop_continue } => Merge::Loop {
1454                                        loop_merge: func_lifting.label_ids[&loop_merge],
1455                                        loop_continue: func_lifting.label_ids[&loop_continue],
1456                                    },
1457                                })
1458                            }))
1459                            .chain([LazyInst::Terminator { parent_func: func_lifting, terminator }])
1460                    }))
1461                    .chain([LazyInst::OpFunctionEnd])
1462            }),
1463        );
1464
1465        let reserved_inst_schema = 0;
1466        let header = [
1467            spv_spec.magic,
1468            (u32::from(dialect.version_major) << 16) | (u32::from(dialect.version_minor) << 8),
1469            debug_info.original_generator_magic.map_or(0, |x| x.get()),
1470            id_bound.get(),
1471            reserved_inst_schema,
1472        ];
1473
1474        let mut emitter = spv::write::ModuleEmitter::with_header(header);
1475
1476        for cap_inst in dialect.capability_insts() {
1477            emitter.push_inst(&cap_inst)?;
1478        }
1479        for ext_inst in dialect.extension_insts() {
1480            emitter.push_inst(&ext_inst)?;
1481        }
1482        for (&name, &id) in &ids.ext_inst_imports {
1483            emitter.push_inst(&spv::InstWithIds {
1484                without_ids: spv::Inst {
1485                    opcode: wk.OpExtInstImport,
1486                    imms: spv::encode_literal_string(name).collect(),
1487                },
1488                result_type_id: None,
1489                result_id: Some(id),
1490                ids: [].into_iter().collect(),
1491            })?;
1492        }
1493        emitter.push_inst(&spv::InstWithIds {
1494            without_ids: spv::Inst {
1495                opcode: wk.OpMemoryModel,
1496                imms: [
1497                    spv::Imm::Short(wk.AddressingModel, dialect.addressing_model),
1498                    spv::Imm::Short(wk.MemoryModel, dialect.memory_model),
1499                ]
1500                .into_iter()
1501                .collect(),
1502            },
1503            result_type_id: None,
1504            result_id: None,
1505            ids: [].into_iter().collect(),
1506        })?;
1507
1508        // Collect the various sources of attributes.
1509        let mut entry_point_insts = vec![];
1510        let mut execution_mode_insts = vec![];
1511        let mut debug_name_insts = vec![];
1512        let mut decoration_insts = vec![];
1513
1514        for lazy_inst in global_and_func_insts.clone() {
1515            let (result_id, attrs, import) = lazy_inst.result_id_attrs_and_import(self, ids);
1516
1517            for attr in cx[attrs].attrs.iter() {
1518                match attr {
1519                    Attr::Diagnostics(_)
1520                    | Attr::QPtr(_)
1521                    | Attr::SpvDebugLine { .. }
1522                    | Attr::SpvBitflagsOperand(_) => {}
1523                    Attr::SpvAnnotation(inst @ spv::Inst { opcode, .. }) => {
1524                        let target_id = result_id.expect(
1525                            "FIXME: it shouldn't be possible to attach \
1526                                 attributes to instructions without an output",
1527                        );
1528
1529                        let inst = spv::InstWithIds {
1530                            without_ids: inst.clone(),
1531                            result_type_id: None,
1532                            result_id: None,
1533                            ids: iter::once(target_id).collect(),
1534                        };
1535
1536                        if [wk.OpExecutionMode, wk.OpExecutionModeId].contains(opcode) {
1537                            execution_mode_insts.push(inst);
1538                        } else if [wk.OpName, wk.OpMemberName].contains(opcode) {
1539                            debug_name_insts.push(inst);
1540                        } else {
1541                            decoration_insts.push(inst);
1542                        }
1543                    }
1544                }
1545
1546                if let Some(import) = import {
1547                    let target_id = result_id.unwrap();
1548                    match import {
1549                        Import::LinkName(name) => {
1550                            decoration_insts.push(spv::InstWithIds {
1551                                without_ids: spv::Inst {
1552                                    opcode: wk.OpDecorate,
1553                                    imms: iter::once(spv::Imm::Short(
1554                                        wk.Decoration,
1555                                        wk.LinkageAttributes,
1556                                    ))
1557                                    .chain(spv::encode_literal_string(&cx[name]))
1558                                    .chain([spv::Imm::Short(wk.LinkageType, wk.Import)])
1559                                    .collect(),
1560                                },
1561                                result_type_id: None,
1562                                result_id: None,
1563                                ids: iter::once(target_id).collect(),
1564                            });
1565                        }
1566                    }
1567                }
1568            }
1569        }
1570
1571        for (export_key, &exportee) in &self.exports {
1572            let target_id = match exportee {
1573                Exportee::GlobalVar(gv) => ids.globals[&global_var_to_id_giving_global(gv)],
1574                Exportee::Func(func) => ids.funcs[&func].func_id,
1575            };
1576            match export_key {
1577                &ExportKey::LinkName(name) => {
1578                    decoration_insts.push(spv::InstWithIds {
1579                        without_ids: spv::Inst {
1580                            opcode: wk.OpDecorate,
1581                            imms: iter::once(spv::Imm::Short(wk.Decoration, wk.LinkageAttributes))
1582                                .chain(spv::encode_literal_string(&cx[name]))
1583                                .chain([spv::Imm::Short(wk.LinkageType, wk.Export)])
1584                                .collect(),
1585                        },
1586                        result_type_id: None,
1587                        result_id: None,
1588                        ids: iter::once(target_id).collect(),
1589                    });
1590                }
1591                ExportKey::SpvEntryPoint { imms, interface_global_vars } => {
1592                    entry_point_insts.push(spv::InstWithIds {
1593                        without_ids: spv::Inst {
1594                            opcode: wk.OpEntryPoint,
1595                            imms: imms.iter().copied().collect(),
1596                        },
1597                        result_type_id: None,
1598                        result_id: None,
1599                        ids: iter::once(target_id)
1600                            .chain(
1601                                interface_global_vars
1602                                    .iter()
1603                                    .map(|&gv| ids.globals[&global_var_to_id_giving_global(gv)]),
1604                            )
1605                            .collect(),
1606                    });
1607                }
1608            }
1609        }
1610
1611        // FIXME(eddyb) maybe make a helper for `push_inst` with an iterator?
1612        for entry_point_inst in entry_point_insts {
1613            emitter.push_inst(&entry_point_inst)?;
1614        }
1615        for execution_mode_inst in execution_mode_insts {
1616            emitter.push_inst(&execution_mode_inst)?;
1617        }
1618
1619        for (&s, &id) in &ids.debug_strings {
1620            emitter.push_inst(&spv::InstWithIds {
1621                without_ids: spv::Inst {
1622                    opcode: wk.OpString,
1623                    imms: spv::encode_literal_string(s).collect(),
1624                },
1625                result_type_id: None,
1626                result_id: Some(id),
1627                ids: [].into_iter().collect(),
1628            })?;
1629        }
1630        for (lang, sources) in &debug_info.source_languages {
1631            let lang_imms = || {
1632                [
1633                    spv::Imm::Short(wk.SourceLanguage, lang.lang),
1634                    spv::Imm::Short(wk.LiteralInteger, lang.version),
1635                ]
1636                .into_iter()
1637            };
1638            if sources.file_contents.is_empty() {
1639                emitter.push_inst(&spv::InstWithIds {
1640                    without_ids: spv::Inst { opcode: wk.OpSource, imms: lang_imms().collect() },
1641                    result_type_id: None,
1642                    result_id: None,
1643                    ids: [].into_iter().collect(),
1644                })?;
1645            } else {
1646                for (&file, contents) in &sources.file_contents {
1647                    // The maximum word count is `2**16 - 1`, the first word is
1648                    // taken up by the opcode & word count, and one extra byte is
1649                    // taken up by the nil byte at the end of the LiteralString.
1650                    const MAX_OP_SOURCE_CONT_CONTENTS_LEN: usize = (0xffff - 1) * 4 - 1;
1651
1652                    // `OpSource` has 3 more operands than `OpSourceContinued`,
1653                    // and each of them take up exactly one word.
1654                    const MAX_OP_SOURCE_CONTENTS_LEN: usize =
1655                        MAX_OP_SOURCE_CONT_CONTENTS_LEN - 3 * 4;
1656
1657                    let (contents_initial, mut contents_rest) =
1658                        contents.split_at(contents.len().min(MAX_OP_SOURCE_CONTENTS_LEN));
1659
1660                    emitter.push_inst(&spv::InstWithIds {
1661                        without_ids: spv::Inst {
1662                            opcode: wk.OpSource,
1663                            imms: lang_imms()
1664                                .chain(spv::encode_literal_string(contents_initial))
1665                                .collect(),
1666                        },
1667                        result_type_id: None,
1668                        result_id: None,
1669                        ids: iter::once(ids.debug_strings[&cx[file]]).collect(),
1670                    })?;
1671
1672                    while !contents_rest.is_empty() {
1673                        // FIXME(eddyb) test with UTF-8! this `split_at` should
1674                        // actually take *less* than the full possible size, to
1675                        // avoid cutting a UTF-8 sequence.
1676                        let (cont_chunk, rest) = contents_rest
1677                            .split_at(contents_rest.len().min(MAX_OP_SOURCE_CONT_CONTENTS_LEN));
1678                        contents_rest = rest;
1679
1680                        emitter.push_inst(&spv::InstWithIds {
1681                            without_ids: spv::Inst {
1682                                opcode: wk.OpSourceContinued,
1683                                imms: spv::encode_literal_string(cont_chunk).collect(),
1684                            },
1685                            result_type_id: None,
1686                            result_id: None,
1687                            ids: [].into_iter().collect(),
1688                        })?;
1689                    }
1690                }
1691            }
1692        }
1693        for ext_inst in debug_info.source_extension_insts() {
1694            emitter.push_inst(&ext_inst)?;
1695        }
1696        for debug_name_inst in debug_name_insts {
1697            emitter.push_inst(&debug_name_inst)?;
1698        }
1699        for mod_proc_inst in debug_info.module_processed_insts() {
1700            emitter.push_inst(&mod_proc_inst)?;
1701        }
1702
1703        for decoration_inst in decoration_insts {
1704            emitter.push_inst(&decoration_inst)?;
1705        }
1706
1707        let mut current_debug_line = None;
1708        let mut current_block_id = None; // HACK(eddyb) for `current_debug_line` resets.
1709        for lazy_inst in global_and_func_insts {
1710            let (inst, attrs) = lazy_inst.to_inst_and_attrs(self, ids);
1711
1712            // Reset line debuginfo when crossing/leaving blocks.
1713            let new_block_id = if inst.opcode == wk.OpLabel {
1714                Some(inst.result_id.unwrap())
1715            } else if inst.opcode == wk.OpFunctionEnd {
1716                None
1717            } else {
1718                current_block_id
1719            };
1720            if current_block_id != new_block_id {
1721                current_debug_line = None;
1722            }
1723            current_block_id = new_block_id;
1724
1725            // Determine whether to emit `OpLine`/`OpNoLine` before `inst`,
1726            // in order to end up with the expected line debuginfo.
1727            // FIXME(eddyb) make this less of a search and more of a
1728            // lookup by splitting attrs into key and value parts.
1729            let new_debug_line = cx[attrs].attrs.iter().find_map(|attr| match *attr {
1730                Attr::SpvDebugLine { file_path, line, col } => {
1731                    Some((ids.debug_strings[&cx[file_path.0]], line, col))
1732                }
1733                _ => None,
1734            });
1735            if current_debug_line != new_debug_line {
1736                let (opcode, imms, ids) = match new_debug_line {
1737                    Some((file_path_id, line, col)) => (
1738                        wk.OpLine,
1739                        [
1740                            spv::Imm::Short(wk.LiteralInteger, line),
1741                            spv::Imm::Short(wk.LiteralInteger, col),
1742                        ]
1743                        .into_iter()
1744                        .collect(),
1745                        iter::once(file_path_id).collect(),
1746                    ),
1747                    None => (wk.OpNoLine, [].into_iter().collect(), [].into_iter().collect()),
1748                };
1749                emitter.push_inst(&spv::InstWithIds {
1750                    without_ids: spv::Inst { opcode, imms },
1751                    result_type_id: None,
1752                    result_id: None,
1753                    ids,
1754                })?;
1755            }
1756            current_debug_line = new_debug_line;
1757
1758            emitter.push_inst(&inst)?;
1759        }
1760
1761        Ok(emitter)
1762    }
1763}