rustc_codegen_spirv/linker/spirt_passes/
controlflow.rs

1//! SPIR-T passes related to control-flow.
2
3use crate::custom_insts::{self, CustomInst, CustomOp};
4use smallvec::SmallVec;
5use spirt::func_at::FuncAt;
6use spirt::{
7    Attr, AttrSet, ConstDef, ConstKind, ControlNodeKind, DataInstFormDef, DataInstKind, DeclDef,
8    EntityDefs, ExportKey, Exportee, Module, Type, TypeDef, TypeKind, TypeOrConst, Value, cfg, spv,
9};
10use std::fmt::Write as _;
11
12/// Replace our custom extended instruction `Abort`s with standard `OpReturn`s,
13/// but only in entry-points (and only before CFG structurization).
14//
15// FIXME(eddyb) no longer relying on structurization, try porting this
16// to replace custom aborts in `Block`s and inject `ExitInvocation`s
17// after them (truncating the `Block` and/or parent region if necessary).
18pub fn convert_custom_aborts_to_unstructured_returns_in_entry_points(
19    linker_options: &crate::linker::Options,
20    module: &mut Module,
21) {
22    // HACK(eddyb) this shouldn't be the place to parse `abort_strategy`.
23    enum Strategy {
24        Unreachable,
25        DebugPrintf { inputs: bool, backtrace: bool },
26    }
27    let abort_strategy = linker_options.abort_strategy.as_ref().map(|s| {
28        if s == "unreachable" {
29            return Strategy::Unreachable;
30        }
31        if let Some(s) = s.strip_prefix("debug-printf") {
32            let (inputs, s) = s.strip_prefix("+inputs").map_or((false, s), |s| (true, s));
33            let (backtrace, s) = s
34                .strip_prefix("+backtrace")
35                .map_or((false, s), |s| (true, s));
36            if s.is_empty() {
37                return Strategy::DebugPrintf { inputs, backtrace };
38            }
39        }
40        panic!("unknown `--abort-strategy={s}");
41    });
42
43    let cx = &module.cx();
44    let wk = &super::SpvSpecWithExtras::get().well_known;
45
46    // HACK(eddyb) deduplicate with `diagnostics`.
47    let name_from_attrs = |attrs: AttrSet| {
48        cx[attrs].attrs.iter().find_map(|attr| match attr {
49            Attr::SpvAnnotation(spv_inst) if spv_inst.opcode == wk.OpName => Some(
50                super::diagnostics::decode_spv_lit_str_with(&spv_inst.imms, |name| {
51                    name.to_string()
52                }),
53            ),
54            _ => None,
55        })
56    };
57
58    let custom_ext_inst_set = cx.intern(&custom_insts::CUSTOM_EXT_INST_SET[..]);
59
60    for (export_key, exportee) in &module.exports {
61        let (entry_point_imms, interface_global_vars, func) = match (export_key, exportee) {
62            (
63                ExportKey::SpvEntryPoint {
64                    imms,
65                    interface_global_vars,
66                },
67                &Exportee::Func(func),
68            ) => (imms, interface_global_vars, func),
69            _ => continue,
70        };
71
72        let func_decl = &mut module.funcs[func];
73        assert!(match &cx[func_decl.ret_type].kind {
74            TypeKind::SpvInst { spv_inst, .. } => spv_inst.opcode == wk.OpTypeVoid,
75            _ => false,
76        });
77
78        let func_def_body = match &mut func_decl.def {
79            DeclDef::Present(def) => def,
80            DeclDef::Imported(_) => continue,
81        };
82
83        let debug_printf_context_fmt_str;
84        let mut debug_printf_context_inputs = SmallVec::<[_; 4]>::new();
85        if let Some(Strategy::DebugPrintf { inputs, .. }) = abort_strategy {
86            let mut fmt = String::new();
87
88            match entry_point_imms[..] {
89                [spv::Imm::Short(em_kind, _), ref name_imms @ ..] => {
90                    assert_eq!(em_kind, wk.ExecutionModel);
91                    super::diagnostics::decode_spv_lit_str_with(name_imms, |name| {
92                        fmt += &name.replace('%', "%%");
93                    });
94                }
95                _ => unreachable!(),
96            }
97            fmt += "(";
98
99            // Collect entry-point inputs `OpLoad`ed by the entry block.
100            // HACK(eddyb) this relies on Rust-GPU always eagerly loading inputs.
101            let loaded_inputs = func_def_body
102                .at(func_def_body
103                    .at_body()
104                    .at_children()
105                    .into_iter()
106                    .next()
107                    .and_then(|func_at_first_node| match func_at_first_node.def().kind {
108                        ControlNodeKind::Block { insts } => Some(insts),
109                        _ => None,
110                    })
111                    .unwrap_or_default())
112                .into_iter()
113                .filter_map(|func_at_inst| {
114                    let data_inst_def = func_at_inst.def();
115                    let data_inst_form_def = &cx[data_inst_def.form];
116                    if let DataInstKind::SpvInst(spv_inst) = &data_inst_form_def.kind
117                        && spv_inst.opcode == wk.OpLoad
118                        && let Value::Const(ct) = data_inst_def.inputs[0]
119                        && let ConstKind::PtrToGlobalVar(gv) = cx[ct].kind
120                        && interface_global_vars.contains(&gv)
121                    {
122                        return Some((
123                            gv,
124                            data_inst_form_def.output_type.unwrap(),
125                            Value::DataInstOutput(func_at_inst.position),
126                        ));
127                    }
128                    None
129                });
130            if inputs {
131                let mut first_input = true;
132                for (gv, ty, value) in loaded_inputs {
133                    let scalar_type = |ty: Type| match &cx[ty].kind {
134                        TypeKind::SpvInst { spv_inst, .. } => match spv_inst.imms[..] {
135                            [spv::Imm::Short(_, 32), spv::Imm::Short(_, signedness)]
136                                if spv_inst.opcode == wk.OpTypeInt =>
137                            {
138                                Some(if signedness != 0 { "i" } else { "u" })
139                            }
140                            [spv::Imm::Short(_, 32)] if spv_inst.opcode == wk.OpTypeFloat => {
141                                Some("f")
142                            }
143                            _ => None,
144                        },
145                        _ => None,
146                    };
147                    let vector_or_scalar_type = |ty: Type| {
148                        let ty_def = &cx[ty];
149                        match &ty_def.kind {
150                            TypeKind::SpvInst {
151                                spv_inst,
152                                type_and_const_inputs,
153                            } if spv_inst.opcode == wk.OpTypeVector => {
154                                match (&type_and_const_inputs[..], &spv_inst.imms[..]) {
155                                    (
156                                        &[TypeOrConst::Type(elem)],
157                                        &[spv::Imm::Short(_, vlen @ 2..=4)],
158                                    ) => Some((scalar_type(elem)?, Some(vlen))),
159                                    _ => None,
160                                }
161                            }
162                            _ => Some((scalar_type(ty)?, None)),
163                        }
164                    };
165                    if let Some((scalar_fmt, vlen)) = vector_or_scalar_type(ty) {
166                        if !first_input {
167                            fmt += ", ";
168                        }
169                        first_input = false;
170
171                        if let Some(name) = name_from_attrs(module.global_vars[gv].attrs) {
172                            fmt += &name.replace('%', "%%");
173                            fmt += " = ";
174                        }
175                        match vlen {
176                            Some(vlen) => write!(fmt, "vec{vlen}(%v{vlen}{scalar_fmt})").unwrap(),
177                            None => write!(fmt, "%{scalar_fmt}").unwrap(),
178                        }
179                        debug_printf_context_inputs.push(value);
180                    }
181                }
182            }
183
184            fmt += ")";
185
186            debug_printf_context_fmt_str = fmt;
187        } else {
188            debug_printf_context_fmt_str = String::new();
189        }
190
191        let rpo_regions = func_def_body
192            .unstructured_cfg
193            .as_ref()
194            .expect("Abort->OpReturn can only be done on unstructured CFGs")
195            .rev_post_order(func_def_body);
196        for region in rpo_regions {
197            let region_def = &func_def_body.control_regions[region];
198            let control_node_def = match region_def.children.iter().last {
199                Some(last_node) => &mut func_def_body.control_nodes[last_node],
200                _ => continue,
201            };
202            let block_insts = match &mut control_node_def.kind {
203                ControlNodeKind::Block { insts } => insts,
204                _ => continue,
205            };
206
207            let terminator = &mut func_def_body
208                .unstructured_cfg
209                .as_mut()
210                .unwrap()
211                .control_inst_on_exit_from[region];
212            match terminator.kind {
213                cfg::ControlInstKind::Unreachable => {}
214                _ => continue,
215            }
216
217            // HACK(eddyb) this allows accessing the `DataInst` iterator while
218            // mutably borrowing other parts of `FuncDefBody`.
219            let func_at_block_insts = FuncAt {
220                control_nodes: &EntityDefs::new(),
221                control_regions: &EntityDefs::new(),
222                data_insts: &func_def_body.data_insts,
223
224                position: *block_insts,
225            };
226            let block_insts_maybe_custom = func_at_block_insts.into_iter().map(|func_at_inst| {
227                let data_inst_def = func_at_inst.def();
228                (
229                    func_at_inst,
230                    match cx[data_inst_def.form].kind {
231                        DataInstKind::SpvExtInst { ext_set, inst }
232                            if ext_set == custom_ext_inst_set =>
233                        {
234                            Some(CustomOp::decode(inst).with_operands(&data_inst_def.inputs))
235                        }
236                        _ => None,
237                    },
238                )
239            });
240            let custom_terminator_inst = block_insts_maybe_custom
241                .clone()
242                .rev()
243                .take_while(|(_, custom)| custom.is_some())
244                .map(|(func_at_inst, custom)| (func_at_inst, custom.unwrap()))
245                .find(|(_, custom)| !custom.op().is_debuginfo())
246                .filter(|(_, custom)| custom.op().is_terminator());
247            if let Some((
248                func_at_abort_inst,
249                CustomInst::Abort {
250                    kind: abort_kind,
251                    message_debug_printf,
252                },
253            )) = custom_terminator_inst
254            {
255                let abort_inst = func_at_abort_inst.position;
256                terminator.kind = cfg::ControlInstKind::ExitInvocation(
257                    cfg::ExitInvocationKind::SpvInst(wk.OpReturn.into()),
258                );
259
260                match abort_strategy {
261                    Some(Strategy::Unreachable) => {
262                        terminator.kind = cfg::ControlInstKind::Unreachable;
263                    }
264                    Some(Strategy::DebugPrintf {
265                        inputs: _,
266                        backtrace,
267                    }) => {
268                        let const_kind = |v: Value| match v {
269                            Value::Const(ct) => &cx[ct].kind,
270                            _ => unreachable!(),
271                        };
272                        let const_str = |v: Value| match const_kind(v) {
273                            &ConstKind::SpvStringLiteralForExtInst(s) => s,
274                            _ => unreachable!(),
275                        };
276                        let const_u32 = |v: Value| match const_kind(v) {
277                            ConstKind::SpvInst {
278                                spv_inst_and_const_inputs,
279                            } => {
280                                let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs;
281                                assert!(spv_inst.opcode == wk.OpConstant);
282                                match spv_inst.imms[..] {
283                                    [spv::Imm::Short(_, x)] => x,
284                                    _ => unreachable!(),
285                                }
286                            }
287                            _ => unreachable!(),
288                        };
289                        let mk_const_str = |s| {
290                            cx.intern(ConstDef {
291                                attrs: Default::default(),
292                                ty: cx.intern(TypeDef {
293                                    attrs: Default::default(),
294                                    kind: TypeKind::SpvStringLiteralForExtInst,
295                                }),
296                                kind: ConstKind::SpvStringLiteralForExtInst(s),
297                            })
298                        };
299
300                        let mut current_debug_src_loc = None;
301                        let mut call_stack = SmallVec::<[_; 8]>::new();
302                        let block_insts_custom = block_insts_maybe_custom
303                            .filter_map(|(func_at_inst, custom)| Some((func_at_inst, custom?)));
304                        for (func_at_inst, custom) in block_insts_custom {
305                            // Stop at the abort, that we don't undo its debug context.
306                            if func_at_inst.position == abort_inst {
307                                break;
308                            }
309
310                            match custom {
311                                CustomInst::SetDebugSrcLoc {
312                                    file,
313                                    line_start,
314                                    line_end: _,
315                                    col_start,
316                                    col_end: _,
317                                } => {
318                                    current_debug_src_loc = Some((
319                                        &cx[const_str(file)],
320                                        const_u32(line_start),
321                                        const_u32(col_start),
322                                    ));
323                                }
324                                CustomInst::ClearDebugSrcLoc => current_debug_src_loc = None,
325                                CustomInst::PushInlinedCallFrame { callee_name } => {
326                                    if backtrace {
327                                        call_stack.push((
328                                            current_debug_src_loc.take(),
329                                            const_str(callee_name),
330                                        ));
331                                    }
332                                }
333                                CustomInst::PopInlinedCallFrame => {
334                                    if let Some((callsite_debug_src_loc, _)) = call_stack.pop() {
335                                        current_debug_src_loc = callsite_debug_src_loc;
336                                    }
337                                }
338                                CustomInst::Abort { .. } => {}
339                            }
340                        }
341
342                        let mut fmt = String::new();
343
344                        let (message_debug_printf_fmt_str, message_debug_printf_args) =
345                            message_debug_printf
346                                .split_first()
347                                .map(|(&fmt_str, args)| (&cx[const_str(fmt_str)], args))
348                                .unwrap_or_default();
349
350                        let fmt_dbg_src_loc = |(file, line, col)| {
351                            // FIXME(eddyb) figure out what is going on with
352                            // these column number conventions, below is a
353                            // related comment from `spirt::print`:
354                            // > // HACK(eddyb) Rust-GPU's column numbers seem
355                            // > // off-by-one wrt what e.g. VSCode expects
356                            // > // for `:line:col` syntax, but it's hard to
357                            // > // tell from the spec and `glslang` doesn't
358                            // > // even emit column numbers at all!
359                            let col = col + 1;
360                            format!("{file}:{line}:{col}").replace('%', "%%")
361                        };
362
363                        // HACK(eddyb) this improves readability w/ very verbose Vulkan loggers.
364                        fmt += "\n";
365
366                        fmt += "[Rust ";
367
368                        // HACK(eddyb) turn "panic" into "panicked", while the
369                        // general case looks like "abort" -> "aborted".
370                        match &cx[const_str(abort_kind)] {
371                            "panic" => fmt += "panicked",
372                            verb => {
373                                fmt += verb;
374                                fmt += "en";
375                            }
376                        };
377
378                        if let Some(loc) = current_debug_src_loc.take() {
379                            fmt += " at ";
380                            fmt += &fmt_dbg_src_loc(loc);
381                        }
382
383                        fmt += "]\n ";
384                        fmt += &message_debug_printf_fmt_str.replace('\n', "\n ");
385
386                        let mut innermost = true;
387                        let mut append_call = |callsite_debug_src_loc, callee: &str| {
388                            if innermost {
389                                innermost = false;
390                                fmt += "\n      in ";
391                            } else if current_debug_src_loc.is_some() {
392                                fmt += "\n      by ";
393                            } else {
394                                // HACK(eddyb) previous call didn't have a `called at` line.
395                                fmt += "\n      called by ";
396                            }
397                            fmt += callee;
398                            if let Some(loc) = callsite_debug_src_loc {
399                                fmt += "\n        called at ";
400                                fmt += &fmt_dbg_src_loc(loc);
401                            }
402                            current_debug_src_loc = callsite_debug_src_loc;
403                        };
404                        while let Some((callsite_debug_src_loc, callee)) = call_stack.pop() {
405                            append_call(callsite_debug_src_loc, &cx[callee].replace('%', "%%"));
406                        }
407                        append_call(None, &debug_printf_context_fmt_str);
408
409                        fmt += "\n";
410
411                        let abort_inst_def = &mut func_def_body.data_insts[abort_inst];
412                        abort_inst_def.form = cx.intern(DataInstFormDef {
413                            kind: DataInstKind::SpvExtInst {
414                                ext_set: cx.intern("NonSemantic.DebugPrintf"),
415                                inst: 1,
416                            },
417                            output_type: cx[abort_inst_def.form].output_type,
418                        });
419                        abort_inst_def.inputs = [Value::Const(mk_const_str(cx.intern(fmt)))]
420                            .into_iter()
421                            .chain(message_debug_printf_args.iter().copied())
422                            .chain(debug_printf_context_inputs.iter().copied())
423                            .collect();
424
425                        // Avoid removing the instruction we just replaced.
426                        continue;
427                    }
428                    None => {}
429                }
430                block_insts.remove(abort_inst, &mut func_def_body.data_insts);
431            }
432        }
433    }
434}