1use crate::custom_insts::{self, CustomInst, CustomOp};
4use smallvec::SmallVec;
5use spirt::func_at::FuncAt;
6use spirt::{
7 Attr, AttrSet, ConstDef, ConstKind, ControlNodeKind, DataInstFormDef, DataInstKind, DeclDef,
8 EntityDefs, ExportKey, Exportee, Module, Type, TypeDef, TypeKind, TypeOrConst, Value, cfg, spv,
9};
10use std::fmt::Write as _;
11
12pub fn convert_custom_aborts_to_unstructured_returns_in_entry_points(
19 linker_options: &crate::linker::Options,
20 module: &mut Module,
21) {
22 enum Strategy {
24 Unreachable,
25 DebugPrintf { inputs: bool, backtrace: bool },
26 }
27 let abort_strategy = linker_options.abort_strategy.as_ref().map(|s| {
28 if s == "unreachable" {
29 return Strategy::Unreachable;
30 }
31 if let Some(s) = s.strip_prefix("debug-printf") {
32 let (inputs, s) = s.strip_prefix("+inputs").map_or((false, s), |s| (true, s));
33 let (backtrace, s) = s
34 .strip_prefix("+backtrace")
35 .map_or((false, s), |s| (true, s));
36 if s.is_empty() {
37 return Strategy::DebugPrintf { inputs, backtrace };
38 }
39 }
40 panic!("unknown `--abort-strategy={s}");
41 });
42
43 let cx = &module.cx();
44 let wk = &super::SpvSpecWithExtras::get().well_known;
45
46 let name_from_attrs = |attrs: AttrSet| {
48 cx[attrs].attrs.iter().find_map(|attr| match attr {
49 Attr::SpvAnnotation(spv_inst) if spv_inst.opcode == wk.OpName => Some(
50 super::diagnostics::decode_spv_lit_str_with(&spv_inst.imms, |name| {
51 name.to_string()
52 }),
53 ),
54 _ => None,
55 })
56 };
57
58 let custom_ext_inst_set = cx.intern(&custom_insts::CUSTOM_EXT_INST_SET[..]);
59
60 for (export_key, exportee) in &module.exports {
61 let (entry_point_imms, interface_global_vars, func) = match (export_key, exportee) {
62 (
63 ExportKey::SpvEntryPoint {
64 imms,
65 interface_global_vars,
66 },
67 &Exportee::Func(func),
68 ) => (imms, interface_global_vars, func),
69 _ => continue,
70 };
71
72 let func_decl = &mut module.funcs[func];
73 assert!(match &cx[func_decl.ret_type].kind {
74 TypeKind::SpvInst { spv_inst, .. } => spv_inst.opcode == wk.OpTypeVoid,
75 _ => false,
76 });
77
78 let func_def_body = match &mut func_decl.def {
79 DeclDef::Present(def) => def,
80 DeclDef::Imported(_) => continue,
81 };
82
83 let debug_printf_context_fmt_str;
84 let mut debug_printf_context_inputs = SmallVec::<[_; 4]>::new();
85 if let Some(Strategy::DebugPrintf { inputs, .. }) = abort_strategy {
86 let mut fmt = String::new();
87
88 match entry_point_imms[..] {
89 [spv::Imm::Short(em_kind, _), ref name_imms @ ..] => {
90 assert_eq!(em_kind, wk.ExecutionModel);
91 super::diagnostics::decode_spv_lit_str_with(name_imms, |name| {
92 fmt += &name.replace('%', "%%");
93 });
94 }
95 _ => unreachable!(),
96 }
97 fmt += "(";
98
99 let loaded_inputs = func_def_body
102 .at(func_def_body
103 .at_body()
104 .at_children()
105 .into_iter()
106 .next()
107 .and_then(|func_at_first_node| match func_at_first_node.def().kind {
108 ControlNodeKind::Block { insts } => Some(insts),
109 _ => None,
110 })
111 .unwrap_or_default())
112 .into_iter()
113 .filter_map(|func_at_inst| {
114 let data_inst_def = func_at_inst.def();
115 let data_inst_form_def = &cx[data_inst_def.form];
116 if let DataInstKind::SpvInst(spv_inst) = &data_inst_form_def.kind
117 && spv_inst.opcode == wk.OpLoad
118 && let Value::Const(ct) = data_inst_def.inputs[0]
119 && let ConstKind::PtrToGlobalVar(gv) = cx[ct].kind
120 && interface_global_vars.contains(&gv)
121 {
122 return Some((
123 gv,
124 data_inst_form_def.output_type.unwrap(),
125 Value::DataInstOutput(func_at_inst.position),
126 ));
127 }
128 None
129 });
130 if inputs {
131 let mut first_input = true;
132 for (gv, ty, value) in loaded_inputs {
133 let scalar_type = |ty: Type| match &cx[ty].kind {
134 TypeKind::SpvInst { spv_inst, .. } => match spv_inst.imms[..] {
135 [spv::Imm::Short(_, 32), spv::Imm::Short(_, signedness)]
136 if spv_inst.opcode == wk.OpTypeInt =>
137 {
138 Some(if signedness != 0 { "i" } else { "u" })
139 }
140 [spv::Imm::Short(_, 32)] if spv_inst.opcode == wk.OpTypeFloat => {
141 Some("f")
142 }
143 _ => None,
144 },
145 _ => None,
146 };
147 let vector_or_scalar_type = |ty: Type| {
148 let ty_def = &cx[ty];
149 match &ty_def.kind {
150 TypeKind::SpvInst {
151 spv_inst,
152 type_and_const_inputs,
153 } if spv_inst.opcode == wk.OpTypeVector => {
154 match (&type_and_const_inputs[..], &spv_inst.imms[..]) {
155 (
156 &[TypeOrConst::Type(elem)],
157 &[spv::Imm::Short(_, vlen @ 2..=4)],
158 ) => Some((scalar_type(elem)?, Some(vlen))),
159 _ => None,
160 }
161 }
162 _ => Some((scalar_type(ty)?, None)),
163 }
164 };
165 if let Some((scalar_fmt, vlen)) = vector_or_scalar_type(ty) {
166 if !first_input {
167 fmt += ", ";
168 }
169 first_input = false;
170
171 if let Some(name) = name_from_attrs(module.global_vars[gv].attrs) {
172 fmt += &name.replace('%', "%%");
173 fmt += " = ";
174 }
175 match vlen {
176 Some(vlen) => write!(fmt, "vec{vlen}(%v{vlen}{scalar_fmt})").unwrap(),
177 None => write!(fmt, "%{scalar_fmt}").unwrap(),
178 }
179 debug_printf_context_inputs.push(value);
180 }
181 }
182 }
183
184 fmt += ")";
185
186 debug_printf_context_fmt_str = fmt;
187 } else {
188 debug_printf_context_fmt_str = String::new();
189 }
190
191 let rpo_regions = func_def_body
192 .unstructured_cfg
193 .as_ref()
194 .expect("Abort->OpReturn can only be done on unstructured CFGs")
195 .rev_post_order(func_def_body);
196 for region in rpo_regions {
197 let region_def = &func_def_body.control_regions[region];
198 let control_node_def = match region_def.children.iter().last {
199 Some(last_node) => &mut func_def_body.control_nodes[last_node],
200 _ => continue,
201 };
202 let block_insts = match &mut control_node_def.kind {
203 ControlNodeKind::Block { insts } => insts,
204 _ => continue,
205 };
206
207 let terminator = &mut func_def_body
208 .unstructured_cfg
209 .as_mut()
210 .unwrap()
211 .control_inst_on_exit_from[region];
212 match terminator.kind {
213 cfg::ControlInstKind::Unreachable => {}
214 _ => continue,
215 }
216
217 let func_at_block_insts = FuncAt {
220 control_nodes: &EntityDefs::new(),
221 control_regions: &EntityDefs::new(),
222 data_insts: &func_def_body.data_insts,
223
224 position: *block_insts,
225 };
226 let block_insts_maybe_custom = func_at_block_insts.into_iter().map(|func_at_inst| {
227 let data_inst_def = func_at_inst.def();
228 (
229 func_at_inst,
230 match cx[data_inst_def.form].kind {
231 DataInstKind::SpvExtInst { ext_set, inst }
232 if ext_set == custom_ext_inst_set =>
233 {
234 Some(CustomOp::decode(inst).with_operands(&data_inst_def.inputs))
235 }
236 _ => None,
237 },
238 )
239 });
240 let custom_terminator_inst = block_insts_maybe_custom
241 .clone()
242 .rev()
243 .take_while(|(_, custom)| custom.is_some())
244 .map(|(func_at_inst, custom)| (func_at_inst, custom.unwrap()))
245 .find(|(_, custom)| !custom.op().is_debuginfo())
246 .filter(|(_, custom)| custom.op().is_terminator());
247 if let Some((
248 func_at_abort_inst,
249 CustomInst::Abort {
250 kind: abort_kind,
251 message_debug_printf,
252 },
253 )) = custom_terminator_inst
254 {
255 let abort_inst = func_at_abort_inst.position;
256 terminator.kind = cfg::ControlInstKind::ExitInvocation(
257 cfg::ExitInvocationKind::SpvInst(wk.OpReturn.into()),
258 );
259
260 match abort_strategy {
261 Some(Strategy::Unreachable) => {
262 terminator.kind = cfg::ControlInstKind::Unreachable;
263 }
264 Some(Strategy::DebugPrintf {
265 inputs: _,
266 backtrace,
267 }) => {
268 let const_kind = |v: Value| match v {
269 Value::Const(ct) => &cx[ct].kind,
270 _ => unreachable!(),
271 };
272 let const_str = |v: Value| match const_kind(v) {
273 &ConstKind::SpvStringLiteralForExtInst(s) => s,
274 _ => unreachable!(),
275 };
276 let const_u32 = |v: Value| match const_kind(v) {
277 ConstKind::SpvInst {
278 spv_inst_and_const_inputs,
279 } => {
280 let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs;
281 assert!(spv_inst.opcode == wk.OpConstant);
282 match spv_inst.imms[..] {
283 [spv::Imm::Short(_, x)] => x,
284 _ => unreachable!(),
285 }
286 }
287 _ => unreachable!(),
288 };
289 let mk_const_str = |s| {
290 cx.intern(ConstDef {
291 attrs: Default::default(),
292 ty: cx.intern(TypeDef {
293 attrs: Default::default(),
294 kind: TypeKind::SpvStringLiteralForExtInst,
295 }),
296 kind: ConstKind::SpvStringLiteralForExtInst(s),
297 })
298 };
299
300 let mut current_debug_src_loc = None;
301 let mut call_stack = SmallVec::<[_; 8]>::new();
302 let block_insts_custom = block_insts_maybe_custom
303 .filter_map(|(func_at_inst, custom)| Some((func_at_inst, custom?)));
304 for (func_at_inst, custom) in block_insts_custom {
305 if func_at_inst.position == abort_inst {
307 break;
308 }
309
310 match custom {
311 CustomInst::SetDebugSrcLoc {
312 file,
313 line_start,
314 line_end: _,
315 col_start,
316 col_end: _,
317 } => {
318 current_debug_src_loc = Some((
319 &cx[const_str(file)],
320 const_u32(line_start),
321 const_u32(col_start),
322 ));
323 }
324 CustomInst::ClearDebugSrcLoc => current_debug_src_loc = None,
325 CustomInst::PushInlinedCallFrame { callee_name } => {
326 if backtrace {
327 call_stack.push((
328 current_debug_src_loc.take(),
329 const_str(callee_name),
330 ));
331 }
332 }
333 CustomInst::PopInlinedCallFrame => {
334 if let Some((callsite_debug_src_loc, _)) = call_stack.pop() {
335 current_debug_src_loc = callsite_debug_src_loc;
336 }
337 }
338 CustomInst::Abort { .. } => {}
339 }
340 }
341
342 let mut fmt = String::new();
343
344 let (message_debug_printf_fmt_str, message_debug_printf_args) =
345 message_debug_printf
346 .split_first()
347 .map(|(&fmt_str, args)| (&cx[const_str(fmt_str)], args))
348 .unwrap_or_default();
349
350 let fmt_dbg_src_loc = |(file, line, col)| {
351 let col = col + 1;
360 format!("{file}:{line}:{col}").replace('%', "%%")
361 };
362
363 fmt += "\n";
365
366 fmt += "[Rust ";
367
368 match &cx[const_str(abort_kind)] {
371 "panic" => fmt += "panicked",
372 verb => {
373 fmt += verb;
374 fmt += "en";
375 }
376 };
377
378 if let Some(loc) = current_debug_src_loc.take() {
379 fmt += " at ";
380 fmt += &fmt_dbg_src_loc(loc);
381 }
382
383 fmt += "]\n ";
384 fmt += &message_debug_printf_fmt_str.replace('\n', "\n ");
385
386 let mut innermost = true;
387 let mut append_call = |callsite_debug_src_loc, callee: &str| {
388 if innermost {
389 innermost = false;
390 fmt += "\n in ";
391 } else if current_debug_src_loc.is_some() {
392 fmt += "\n by ";
393 } else {
394 fmt += "\n called by ";
396 }
397 fmt += callee;
398 if let Some(loc) = callsite_debug_src_loc {
399 fmt += "\n called at ";
400 fmt += &fmt_dbg_src_loc(loc);
401 }
402 current_debug_src_loc = callsite_debug_src_loc;
403 };
404 while let Some((callsite_debug_src_loc, callee)) = call_stack.pop() {
405 append_call(callsite_debug_src_loc, &cx[callee].replace('%', "%%"));
406 }
407 append_call(None, &debug_printf_context_fmt_str);
408
409 fmt += "\n";
410
411 let abort_inst_def = &mut func_def_body.data_insts[abort_inst];
412 abort_inst_def.form = cx.intern(DataInstFormDef {
413 kind: DataInstKind::SpvExtInst {
414 ext_set: cx.intern("NonSemantic.DebugPrintf"),
415 inst: 1,
416 },
417 output_type: cx[abort_inst_def.form].output_type,
418 });
419 abort_inst_def.inputs = [Value::Const(mk_const_str(cx.intern(fmt)))]
420 .into_iter()
421 .chain(message_debug_printf_args.iter().copied())
422 .chain(debug_printf_context_inputs.iter().copied())
423 .collect();
424
425 continue;
427 }
428 None => {}
429 }
430 block_insts.remove(abort_inst, &mut func_def_body.data_insts);
431 }
432 }
433 }
434}