1use super::apply_rewrite_rules;
8use super::ipo::CallGraph;
9use super::simple_passes::outgoing_edges;
10use super::{get_name, get_names};
11use crate::custom_decorations::SpanRegenerator;
12use crate::custom_insts::{self, CustomInst, CustomOp};
13use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
14use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
15use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
16use rustc_errors::ErrorGuaranteed;
17use rustc_session::Session;
18use smallvec::SmallVec;
19use std::cmp::Ordering;
20use std::mem;
21
22fn next_id(header: &mut ModuleHeader) -> Word {
24 let result = header.bound;
25 header.bound += 1;
26 result
27}
28
/// Inlines calls throughout `module`, then runs per-function cleanup passes.
///
/// A call is inlined when `should_inline` asks for it. That happens either
/// because the callee is marked `FunctionControl::INLINE`, or because keeping
/// the call would be illegal: the callee may abort, or its signature or
/// arguments involve pointers that can't be passed through a call. Functions
/// are processed in `call_graph.post_order()`. NOTE(review): this is presumably
/// callees-before-callers; confirm in `ipo::CallGraph`.
///
/// After inlining into a function, the function also goes through
/// `fuse_trivial_branches`, debuginfo deduplication, block reordering,
/// `mem2reg` and `destructure_composites`.
///
/// Finally, a warning is emitted for every `DONT_INLINE` callee that had to be
/// inlined anyway. `core::` callees inlined because of "panicking" are
/// excluded.
///
/// # Errors
///
/// Returns `Err` (after reporting it) if the module contains recursion. In
/// that case nothing is inlined.
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
    // Inlining assumes every call chain terminates, so reject recursion up front.
    deny_recursion_in_module(sess, module)?;

    // Call graph driving the processing order, plus a function ID -> index map.
    let (call_graph, func_id_to_idx) = CallGraph::collect_with_func_id_to_idx(module);

    // The module's existing `OpExtInstImport` for our custom extended
    // instruction set, if there is one.
    let custom_ext_inst_set_import = module
        .ext_inst_imports
        .iter()
        .find(|inst| {
            assert_eq!(inst.class.opcode, Op::ExtInstImport);
            inst.operands[0].unwrap_literal_string() == &custom_insts::CUSTOM_EXT_INST_SET[..]
        })
        .map(|inst| inst.result_id.unwrap());

    let legal_globals = LegalGlobal::gather_from_module(module);

    let header = module.header.as_mut().unwrap();

    // NOTE(review): `map` + `unwrap_or_else` are kept separate (instead of
    // `map_or_else`). Presumably this lets the `find` borrow of
    // `module.types_global_values` end before the fallback closure pushes to it.
    #[allow(clippy::map_unwrap_or)]
    let mut inliner = Inliner {
        // Reuse the module's `OpTypeVoid`, or declare a new one.
        op_type_void_id: module
            .types_global_values
            .iter()
            .find(|inst| inst.class.opcode == Op::TypeVoid)
            .map(|inst| inst.result_id.unwrap())
            .unwrap_or_else(|| {
                let id = next_id(header);
                let inst = Instruction::new(Op::TypeVoid, None, Some(id), vec![]);
                module.types_global_values.push(inst);
                id
            }),

        // Reuse the custom ext inst set import found above, or declare one.
        custom_ext_inst_set_import: custom_ext_inst_set_import.unwrap_or_else(|| {
            let id = next_id(header);
            let inst = Instruction::new(
                Op::ExtInstImport,
                None,
                Some(id),
                vec![Operand::LiteralString(
                    custom_insts::CUSTOM_EXT_INST_SET.to_string(),
                )],
            );
            module.ext_inst_imports.push(inst);
            id
        }),

        func_id_to_idx,

        // `OpName` target ID -> name, used for debuginfo and diagnostics.
        id_to_name: module
            .debug_names
            .iter()
            .filter(|inst| inst.class.opcode == Op::Name)
            .map(|inst| {
                (
                    inst.operands[0].unwrap_id_ref(),
                    inst.operands[1].unwrap_literal_string(),
                )
            })
            .collect(),

        cached_op_strings: FxHashMap::default(),

        header,
        debug_string_source: &mut module.debug_string_source,
        annotations: &mut module.annotations,

        legal_globals,

        // Functions containing a block whose last non-terminator instruction
        // is our custom `Abort` (asserted to be followed by `OpUnreachable`).
        // This set is empty if the module had no custom ext inst set import.
        functions_that_may_abort: module
            .functions
            .iter()
            .filter_map(|func| {
                let custom_ext_inst_set_import = custom_ext_inst_set_import?;
                func.blocks
                    .iter()
                    .any(|block| match &block.instructions[..] {
                        [.., last_normal_inst, terminator_inst]
                            if last_normal_inst.class.opcode == Op::ExtInst
                                && last_normal_inst.operands[0].unwrap_id_ref()
                                    == custom_ext_inst_set_import
                                && CustomOp::decode_from_ext_inst(last_normal_inst)
                                    == CustomOp::Abort =>
                        {
                            assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
                            true
                        }

                        _ => false,
                    })
                    .then_some(func.def_id().unwrap())
            })
            .collect(),

        inlined_dont_inlines_to_cause_and_callers: FxIndexMap::default(),
    };

    // Functions are moved out of the module. While a function is being
    // inlined into, its slot holds `Err(FuncIsBeingInlined)`.
    let mut functions: Vec<_> = mem::take(&mut module.functions)
        .into_iter()
        .map(Ok)
        .collect();

    // Inputs for `mem2reg`: pointer type -> pointee type, and the values of
    // `OpConstant`s whose type is the 32-bit unsigned integer type.
    let mut mem2reg_pointer_to_pointee = FxHashMap::default();
    let mut mem2reg_constants = FxHashMap::default();
    {
        // The module's `OpTypeInt 32 0` (asserted to be unique). It is found
        // while scanning in order, so only constants after it are recorded.
        let mut u32 = None;
        for inst in &module.types_global_values {
            match inst.class.opcode {
                Op::TypePointer => {
                    mem2reg_pointer_to_pointee
                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
                }
                Op::TypeInt
                    if inst.operands[0].unwrap_literal_bit32() == 32
                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
                {
                    assert!(u32.is_none());
                    u32 = Some(inst.result_id.unwrap());
                }
                Op::Constant if u32.is_some() && inst.result_type == u32 => {
                    let value = inst.operands[0].unwrap_literal_bit32();
                    mem2reg_constants.insert(inst.result_id.unwrap(), value);
                }
                _ => {}
            }
        }
    }

    for func_idx in call_graph.post_order() {
        // Take the function out, leaving a placeholder that panics if it is
        // looked up (as a callee) while being processed.
        let mut function = mem::replace(&mut functions[func_idx], Err(FuncIsBeingInlined)).unwrap();
        inliner.inline_fn(&mut function, &functions);
        fuse_trivial_branches(&mut function);

        // NOTE(review): this passes the *pre-inlining* `Option` import ID. If the
        // inliner had to create the import, custom debuginfo it inserted is not
        // identified here. Confirm whether that is intended.
        super::duplicates::DebuginfoDeduplicator {
            custom_ext_inst_set_import,
        }
        .remove_duplicate_debuginfo_in_function(&mut function);

        {
            super::simple_passes::block_ordering_pass(&mut function);
            super::mem2reg::mem2reg(
                inliner.header,
                &mut module.types_global_values,
                &mem2reg_pointer_to_pointee,
                &mem2reg_constants,
                &mut function,
            );
            super::destructure_composites::destructure_composites(&mut function);
        }

        functions[func_idx] = Ok(function);
    }

    module.functions = functions.into_iter().map(|func| func.unwrap()).collect();

    let Inliner {
        id_to_name,
        inlined_dont_inlines_to_cause_and_callers,
        ..
    } = inliner;

    // Warn about every `#[inline(never)]` (`DONT_INLINE`) function that was
    // inlined anyway because inlining was required for legality.
    let mut span_regen = SpanRegenerator::new(sess.source_map(), module);
    for (callee_id, (cause, callers)) in inlined_dont_inlines_to_cause_and_callers {
        let callee_name = get_name(&id_to_name, callee_id);

        // NOTE(review): `core::` functions inlined because they may panic are
        // not reported. Presumably this is because users can't change `core`.
        if cause == "panicking" && callee_name.starts_with("core::") {
            continue;
        }

        let callee_span = span_regen
            .src_loc_for_id(callee_id)
            .and_then(|src_loc| span_regen.src_loc_to_rustc(src_loc))
            .unwrap_or_default();
        sess.dcx()
            .struct_span_warn(
                callee_span,
                format!("`#[inline(never)]` function `{callee_name}` has been inlined"),
            )
            .with_note(format!("inlining was required due to {cause}"))
            .with_note(format!(
                "called from {}",
                callers
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &caller_id)| {
                        // Name up to 4 callers, then summarize the rest.
                        match i.cmp(&4) {
                            Ordering::Less => {
                                Some(format!("`{}`", get_name(&id_to_name, caller_id)))
                            }
                            Ordering::Equal => Some(format!("and {} more", callers.len() - i)),
                            Ordering::Greater => None,
                        }
                    })
                    .collect::<SmallVec<[_; 5]>>()
                    .join(", ")
            ))
            .emit();
    }

    Ok(())
}
241
242fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> {
244 let func_to_index: FxHashMap<Word, usize> = module
245 .functions
246 .iter()
247 .enumerate()
248 .map(|(index, func)| (func.def_id().unwrap(), index))
249 .collect();
250 let mut discovered = vec![false; module.functions.len()];
251 let mut finished = vec![false; module.functions.len()];
252 let mut has_recursion = None;
253 for index in 0..module.functions.len() {
254 if !discovered[index] && !finished[index] {
255 visit(
256 sess,
257 module,
258 index,
259 &mut discovered,
260 &mut finished,
261 &mut has_recursion,
262 &func_to_index,
263 );
264 }
265 }
266
267 fn visit(
268 sess: &Session,
269 module: &Module,
270 current: usize,
271 discovered: &mut Vec<bool>,
272 finished: &mut Vec<bool>,
273 has_recursion: &mut Option<ErrorGuaranteed>,
274 func_to_index: &FxHashMap<Word, usize>,
275 ) {
276 discovered[current] = true;
277
278 for next in calls(&module.functions[current], func_to_index) {
279 if discovered[next] {
280 let names = get_names(module);
281 let current_name = get_name(&names, module.functions[current].def_id().unwrap());
282 let next_name = get_name(&names, module.functions[next].def_id().unwrap());
283 *has_recursion = Some(sess.dcx().err(format!(
284 "module has recursion, which is not allowed: `{current_name}` calls `{next_name}`"
285 )));
286 break;
287 }
288
289 if !finished[next] {
290 visit(
291 sess,
292 module,
293 next,
294 discovered,
295 finished,
296 has_recursion,
297 func_to_index,
298 );
299 }
300 }
301
302 discovered[current] = false;
303 finished[current] = true;
304 }
305
306 fn calls<'a>(
307 func: &'a Function,
308 func_to_index: &'a FxHashMap<Word, usize>,
309 ) -> impl Iterator<Item = usize> + 'a {
310 func.all_inst_iter()
311 .filter(|inst| inst.class.opcode == Op::FunctionCall)
312 .map(move |inst| {
313 *func_to_index
314 .get(&inst.operands[0].id_ref_any().unwrap())
315 .unwrap()
316 })
317 }
318
319 match has_recursion {
320 Some(err) => Err(err),
321 None => Ok(()),
322 }
323}
324
/// Classification of a module-level instruction (from `types_global_values`)
/// that is considered legal for use in function signatures and call
/// arguments. See `LegalGlobal::gather_from_module` for when an instruction
/// qualifies.
enum LegalGlobal {
    /// An `OpTypePointer`, with its storage class.
    TypePointer(StorageClass),
    /// Any other type-declaring instruction.
    TypeNonPointer,
    /// A constant-declaring instruction.
    Const,
    /// A module-level `OpVariable`.
    Variable,
}
338
339impl LegalGlobal {
340 fn gather_from_module(module: &Module) -> FxHashMap<Word, Self> {
341 let mut legal_globals = FxHashMap::<_, Self>::default();
342 for inst in &module.types_global_values {
343 let global = match inst.class.opcode {
344 Op::TypePointer => Self::TypePointer(inst.operands[0].unwrap_storage_class()),
345 Op::Variable => Self::Variable,
346 op if rspirv::grammar::reflect::is_type(op) => Self::TypeNonPointer,
347 op if rspirv::grammar::reflect::is_constant(op) => Self::Const,
348
349 _ => continue,
351 };
352 let legal_result_type = match inst.result_type {
353 Some(result_type_id) => matches!(
354 (&global, legal_globals.get(&result_type_id)),
355 (Self::Variable, Some(Self::TypePointer(_)))
356 | (Self::Const, Some(Self::TypeNonPointer))
357 ),
358 None => matches!(global, Self::TypePointer(_) | Self::TypeNonPointer),
359 };
360 let legal_operands = inst.operands.iter().all(|operand| match operand {
361 Operand::IdRef(id) => matches!(
362 legal_globals.get(id),
363 Some(Self::TypeNonPointer | Self::Const)
364 ),
365
366 _ => operand.id_ref_any().is_none(),
368 });
369 if legal_result_type && legal_operands {
370 legal_globals.insert(inst.result_id.unwrap(), global);
371 }
372 }
373 legal_globals
374 }
375
376 fn legal_as_fn_param_ty(&self) -> bool {
377 match *self {
378 Self::TypePointer(storage_class) => matches!(
379 storage_class,
380 StorageClass::UniformConstant
381 | StorageClass::Function
382 | StorageClass::Private
383 | StorageClass::Workgroup
384 | StorageClass::AtomicCounter
385 ),
386 Self::TypeNonPointer => true,
387
388 Self::Const | Self::Variable => false,
390 }
391 }
392
393 fn legal_as_fn_ret_ty(&self) -> bool {
394 #[allow(clippy::match_same_arms)]
395 match *self {
396 Self::TypePointer(_) => false,
397 Self::TypeNonPointer => true,
398
399 Self::Const | Self::Variable => false,
401 }
402 }
403}
404
/// An `OpFunctionCall` instruction together with the function containing it.
#[derive(Copy, Clone)]
struct CallSite<'a> {
    /// The function whose body contains `call_inst`.
    caller: &'a Function,
    /// The `OpFunctionCall` instruction itself.
    call_inst: &'a Instruction,
}
411
412fn has_dont_inline(function: &Function) -> bool {
413 let def = function.def.as_ref().unwrap();
414 let control = def.operands[0].unwrap_function_control();
415 control.contains(FunctionControl::DONT_INLINE)
416}
417
/// Returned (as the error) by `should_inline` when a call must be inlined for
/// the module to be legal. The string is the human-readable cause (e.g.
/// `"panicking"`, `"illegal (pointer) argument"`), which is reported when a
/// `DONT_INLINE` function gets inlined anyway.
#[derive(Copy, Clone, PartialEq, Eq)]
struct MustInlineToLegalize(&'static str);
421
422fn should_inline(
431 legal_globals: &FxHashMap<Word, LegalGlobal>,
432 functions_that_may_abort: &FxHashSet<Word>,
433 callee: &Function,
434 call_site: CallSite<'_>,
435) -> Result<bool, MustInlineToLegalize> {
436 let callee_def = callee.def.as_ref().unwrap();
437 let callee_control = callee_def.operands[0].unwrap_function_control();
438
439 if functions_that_may_abort.contains(&callee.def_id().unwrap()) {
440 return Err(MustInlineToLegalize("panicking"));
441 }
442
443 let ret_ty = legal_globals
444 .get(&callee_def.result_type.unwrap())
445 .ok_or(MustInlineToLegalize("illegal return type"))?;
446 if !ret_ty.legal_as_fn_ret_ty() {
447 return Err(MustInlineToLegalize("illegal (pointer) return type"));
448 }
449
450 for (i, param) in callee.parameters.iter().enumerate() {
451 let param_ty = legal_globals
452 .get(param.result_type.as_ref().unwrap())
453 .ok_or(MustInlineToLegalize("illegal parameter type"))?;
454 if !param_ty.legal_as_fn_param_ty() {
455 return Err(MustInlineToLegalize("illegal (pointer) parameter type"));
456 }
457
458 if let LegalGlobal::TypePointer(_) = param_ty {
464 let ptr_arg = call_site.call_inst.operands[i + 1].unwrap_id_ref();
465 match legal_globals.get(&ptr_arg) {
466 Some(LegalGlobal::Variable) => {}
467
468 Some(_) => return Err(MustInlineToLegalize("illegal (pointer) argument")),
470
471 None => {
472 let mut caller_param_and_var_ids = call_site
473 .caller
474 .parameters
475 .iter()
476 .chain(
477 call_site.caller.blocks[0]
478 .instructions
479 .iter()
480 .filter(|caller_inst| {
481 let may_be_debuginfo = matches!(
485 caller_inst.class.opcode,
486 Op::Line | Op::NoLine | Op::ExtInst
487 );
488 !may_be_debuginfo
489 })
490 .take_while(|caller_inst| caller_inst.class.opcode == Op::Variable),
491 )
492 .map(|caller_inst| caller_inst.result_id.unwrap());
493
494 if !caller_param_and_var_ids.any(|id| ptr_arg == id) {
495 return Err(MustInlineToLegalize("illegal (pointer) argument"));
496 }
497 }
498 }
499 }
500 }
501
502 Ok(callee_control.contains(FunctionControl::INLINE))
503}
504
/// Placeholder stored (as `Err`) in the `functions` list while a function has
/// been taken out of it to be inlined into. Unwrapping it means a function was
/// looked up while it was itself being processed. `Debug` is required by those
/// `.unwrap()` calls on `Result<Function, FuncIsBeingInlined>`.
#[derive(Debug)]
struct FuncIsBeingInlined;
509
/// State shared across all inlining performed by `inline`.
struct Inliner<'a, 'b> {
    /// ID of the `OpExtInstImport` for our custom extended instruction set
    /// (found in the module, or created by `inline`).
    custom_ext_inst_set_import: Word,

    /// ID of the module's `OpTypeVoid` (found, or created by `inline`).
    op_type_void_id: Word,

    /// Map from each function's ID to its index in the `functions` list.
    func_id_to_idx: FxHashMap<Word, usize>,

    /// `OpName` target ID -> name, used to name inlined callees in debuginfo.
    id_to_name: FxHashMap<Word, &'a str>,

    /// Cache of `OpString`s created for callee names, so each distinct name
    /// gets only one `OpString` (existing `OpString`s are not reused).
    cached_op_strings: FxHashMap<&'a str, Word>,

    /// Module parts that inlining allocates IDs in / appends instructions to.
    header: &'b mut ModuleHeader,
    debug_string_source: &'b mut Vec<Instruction>,
    annotations: &'b mut Vec<Instruction>,

    legal_globals: FxHashMap<Word, LegalGlobal>,
    /// IDs of functions that may reach our custom `Abort` instruction. This
    /// grows as possibly-aborting callees get inlined into callers.
    functions_that_may_abort: FxHashSet<Word>,
    /// `DONT_INLINE` callees that had to be inlined for legality: callee ID ->
    /// (first cause recorded, IDs of the callers they were inlined into).
    inlined_dont_inlines_to_cause_and_callers: FxIndexMap<Word, (&'static str, FxIndexSet<Word>)>,
}
546
impl Inliner<'_, '_> {
    /// Allocates a fresh result ID (bumping the module header's ID bound).
    fn id(&mut self) -> Word {
        next_id(self.header)
    }

    /// For every annotation (decoration) instruction whose first operand is an
    /// `IdRef` that has an entry in `rewrite_rules`, appends a copy retargeted
    /// to the rewritten ID. The originals are kept. Only annotations present
    /// before this call are scanned, because the range is computed once.
    fn apply_rewrite_for_decorations(&mut self, rewrite_rules: &FxHashMap<Word, Word>) {
        for annotation_idx in 0..self.annotations.len() {
            let inst = &self.annotations[annotation_idx];
            if let [Operand::IdRef(target), ..] = inst.operands[..]
                && let Some(&rewritten_target) = rewrite_rules.get(&target)
            {
                // Duplicate the decoration onto the renamed (inlined) ID.
                let mut cloned_inst = inst.clone();
                cloned_inst.operands[0] = Operand::IdRef(rewritten_target);
                self.annotations.push(cloned_inst);
            }
        }
    }

    /// Inlines calls in `function` until no block has a call left to inline.
    ///
    /// After a successful inline the same block index is scanned again, since
    /// the inlined code may contain further calls. NOTE(review): rescanning
    /// makes this potentially quadratic in the function size.
    fn inline_fn(
        &mut self,
        function: &mut Function,
        functions: &[Result<Function, FuncIsBeingInlined>],
    ) {
        let mut block_idx = 0;
        while block_idx < function.blocks.len() {
            if !self.inline_block(function, block_idx, functions) {
                // Nothing (left) to inline in this block; move on.
                block_idx += 1;
            }
        }
    }

    /// Inlines the first call in `caller.blocks[block_idx]` that `should_inline`
    /// selects, including calls that must be inlined for legality. Returns
    /// whether a call was inlined.
    ///
    /// The block is split at the call. The callee's renamed entry block is
    /// appended to the part before the call, and the callee's other blocks
    /// follow. The part after the call becomes a new block labelled
    /// `return_jump`, which every inlined `OpReturn`/`OpReturnValue` branches to.
    ///
    /// A `DONT_INLINE` callee inlined for legality is recorded in
    /// `inlined_dont_inlines_to_cause_and_callers`.
    fn inline_block(
        &mut self,
        caller: &mut Function,
        block_idx: usize,
        functions: &[Result<Function, FuncIsBeingInlined>],
    ) -> bool {
        // Find the first call to inline, as (instruction index, call, callee).
        let call = caller.blocks[block_idx]
            .instructions
            .iter()
            .enumerate()
            .filter(|(_, inst)| inst.class.opcode == Op::FunctionCall)
            .map(|(index, inst)| {
                (
                    index,
                    inst,
                    functions[self.func_id_to_idx[&inst.operands[0].id_ref_any().unwrap()]]
                        .as_ref()
                        .unwrap(),
                )
            })
            .find(|(_, inst, f)| {
                let call_site = CallSite {
                    caller,
                    call_inst: inst,
                };
                match should_inline(
                    &self.legal_globals,
                    &self.functions_that_may_abort,
                    f,
                    call_site,
                ) {
                    Ok(inline) => inline,
                    Err(MustInlineToLegalize(cause)) => {
                        // Remember `#[inline(never)]` callees we're forced to inline,
                        // so `inline` can warn about them.
                        if has_dont_inline(f) {
                            self.inlined_dont_inlines_to_cause_and_callers
                                .entry(f.def_id().unwrap())
                                .or_insert_with(|| (cause, Default::default()))
                                .1
                                .insert(caller.def_id().unwrap());
                        }
                        true
                    }
                }
            });
        let (call_index, call_inst, callee) = match call {
            None => return false,
            Some(call) => call,
        };

        // Inlining a possibly-aborting callee makes the caller possibly-aborting.
        if self
            .functions_that_may_abort
            .contains(&callee.def_id().unwrap())
        {
            self.functions_that_may_abort
                .insert(caller.def_id().unwrap());
        }

        // For non-void calls, an `OpPhi` that takes over the call's result ID.
        // It collects one (value, source block) pair per inlined `OpReturnValue`.
        let mut maybe_call_result_phi = {
            let ty = call_inst.result_type.unwrap();
            if ty == self.op_type_void_id {
                None
            } else {
                Some(Instruction::new(
                    Op::Phi,
                    Some(ty),
                    Some(call_inst.result_id.unwrap()),
                    vec![],
                ))
            }
        };

        // The debug source location in effect at the call. This is the nearest
        // preceding `OpLine` / custom `SetDebugSrcLoc`, or `None` if an
        // `OpNoLine` / `ClearDebugSrcLoc` comes first (or there is none). All
        // other instructions are skipped over.
        let custom_ext_inst_set_import = self.custom_ext_inst_set_import;
        let call_debug_src_loc_inst = caller.blocks[block_idx].instructions[..call_index]
            .iter()
            .rev()
            .find_map(|inst| {
                Some(match inst.class.opcode {
                    Op::Line => Some(inst),
                    Op::NoLine => None,
                    Op::ExtInst
                        if inst.operands[0].unwrap_id_ref() == custom_ext_inst_set_import =>
                    {
                        match CustomOp::decode_from_ext_inst(inst) {
                            CustomOp::SetDebugSrcLoc => Some(inst),
                            CustomOp::ClearDebugSrcLoc => None,
                            _ => return None,
                        }
                    }
                    _ => return None,
                })
            })
            .flatten();

        // Callee parameters get renamed to the call's arguments. Operand 0 of
        // `OpFunctionCall` is the callee, so arguments start at operand 1.
        let call_arguments = call_inst
            .operands
            .iter()
            .skip(1)
            .map(|op| op.id_ref_any().unwrap());
        let callee_parameters = callee.parameters.iter().map(|inst| {
            assert!(inst.class.opcode == Op::FunctionParameter);
            inst.result_id.unwrap()
        });
        let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();

        // Label of the post-call block, which inlined returns branch to.
        let return_jump = self.id();
        let mut inlined_callee_blocks = self.get_inlined_blocks(
            callee,
            call_debug_src_loc_inst,
            maybe_call_result_phi.as_mut(),
            return_jump,
        );
        // Give every result ID (labels included) in the inlined copy a fresh ID,
        // apply all renames, and duplicate decorations onto renamed IDs.
        self.add_clone_id_rules(&mut rewrite_rules, &inlined_callee_blocks);
        apply_rewrite_rules(&rewrite_rules, &mut inlined_callee_blocks);
        self.apply_rewrite_for_decorations(&rewrite_rules);

        if let Some(call_result_phi) = &mut maybe_call_result_phi {
            // The phi's operands were built from the callee's IDs; rename them too.
            for op in &mut call_result_phi.operands {
                if let Some(id) = op.id_ref_any_mut()
                    && let Some(&rewrite) = rewrite_rules.get(id)
                {
                    *id = rewrite;
                }
            }

            // With exactly one returning block, the phi is unnecessary. Drop it
            // and use the returned value directly wherever the call result was used.
            if let [returned_value, _return_block] = &call_result_phi.operands[..] {
                let call_result_id = call_result_phi.result_id.unwrap();
                let returned_value_id = returned_value.unwrap_id_ref();

                maybe_call_result_phi = None;

                // HACK: conservative approximation of the instructions that may use
                // the call result. This is the leading `OpPhi`s of blocks before the
                // call's block, plus every instruction from the call's block onward.
                let reaching_insts = {
                    let (pre_call_blocks, call_and_post_call_blocks) =
                        caller.blocks.split_at_mut(block_idx);
                    (pre_call_blocks.iter_mut().flat_map(|block| {
                        block
                            .instructions
                            .iter_mut()
                            .take_while(|inst| inst.class.opcode == Op::Phi)
                    }))
                    .chain(
                        call_and_post_call_blocks
                            .iter_mut()
                            .flat_map(|block| &mut block.instructions),
                    )
                };
                for reaching_inst in reaching_insts {
                    for op in &mut reaching_inst.operands {
                        if let Some(id) = op.id_ref_any_mut()
                            && *id == call_result_id
                        {
                            *id = returned_value_id;
                        }
                    }
                }
            }
        }

        // From here on the call's block is `pre_call_block_idx`. Shadowing
        // `block_idx` with an uninitialized binding turns any accidental later
        // use of the old name into a compile error.
        let pre_call_block_idx = block_idx;
        #[expect(unused)]
        let block_idx: usize;
        // Everything after the call moves into the post-call block.
        let mut post_call_block_insts = caller.blocks[pre_call_block_idx]
            .instructions
            .split_off(call_index + 1);

        // Remove the call itself (now the pre-call block's last instruction).
        let call = caller.blocks[pre_call_block_idx]
            .instructions
            .pop()
            .unwrap();
        assert!(call.class.opcode == Op::FunctionCall);

        // Insert the callee's non-entry blocks right after the pre-call block.
        let non_entry_inlined_callee_blocks = inlined_callee_blocks.drain(1..);
        let num_non_entry_inlined_callee_blocks = non_entry_inlined_callee_blocks.len();
        caller.blocks.splice(
            (pre_call_block_idx + 1)..(pre_call_block_idx + 1),
            non_entry_inlined_callee_blocks,
        );

        // Zero or multiple returns: keep the phi at the start of the post-call block.
        if let Some(call_result_phi) = maybe_call_result_phi {
            post_call_block_insts.insert(0, call_result_phi);
        }

        {
            // The post-call block goes after the inlined blocks. It inherits the
            // pre-call block's original terminator, so `OpPhi`s in its
            // successors must now name it as their source.
            let post_call_block_idx = pre_call_block_idx + num_non_entry_inlined_callee_blocks + 1;
            let post_call_block = Block {
                label: Some(Instruction::new(Op::Label, None, Some(return_jump), vec![])),
                instructions: post_call_block_insts,
            };
            caller.blocks.insert(post_call_block_idx, post_call_block);

            rewrite_phi_sources(
                caller.blocks[pre_call_block_idx].label_id().unwrap(),
                &mut caller.blocks,
                post_call_block_idx,
            );
        }

        // Fuse the inlined callee entry block into the pre-call block. SPIR-V
        // forbids branching to a function's first block, so nothing in the
        // callee targets it.
        {
            // Clones a debuginfo instruction, giving it a fresh result ID if it has
            // one, because result IDs must stay unique.
            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
                let mut inst = inst.clone();
                if let Some(id) = &mut inst.result_id {
                    *id = this.id();
                }
                inst
            };

            // Encodes a `CustomInst` as an `OpExtInst` of our custom ext inst set
            // (void result type, fresh result ID).
            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
                Instruction::new(
                    Op::ExtInst,
                    Some(this.op_type_void_id),
                    Some(this.id()),
                    [
                        Operand::IdRef(this.custom_ext_inst_set_import),
                        Operand::LiteralExtInstInteger(inst.op() as u32),
                    ]
                    .into_iter()
                    .chain(inst.into_operands())
                    .collect(),
                )
            };

            // Splits off and returns the leading run of `OpVariable`s and debuginfo
            // in `insts`, with a `PopInlinedCallFrame` appended for every inlined
            // call frame still open at the end of that run.
            //
            // In `insts`, the run is replaced by a re-push of each of those open
            // frames (preceded by its call-site source location, if any), followed
            // by the source location in effect at the end of the run. The remaining
            // instructions therefore keep the same debuginfo context.
            let mut steal_vars = |insts: &mut Vec<Instruction>| {
                // Open inlined call frames: (src loc at the push, the push itself).
                let mut enclosing_inlined_frames = SmallVec::<[_; 8]>::new();
                let mut current_debug_src_loc_inst = None;
                let mut vars_and_debuginfo_range = 0..0;
                while vars_and_debuginfo_range.end < insts.len() {
                    let inst = &insts[vars_and_debuginfo_range.end];
                    match inst.class.opcode {
                        Op::Line => current_debug_src_loc_inst = Some(inst),
                        Op::NoLine => current_debug_src_loc_inst = None,
                        Op::ExtInst
                            if inst.operands[0].unwrap_id_ref()
                                == self.custom_ext_inst_set_import =>
                        {
                            match CustomOp::decode_from_ext_inst(inst) {
                                CustomOp::SetDebugSrcLoc => current_debug_src_loc_inst = Some(inst),
                                CustomOp::ClearDebugSrcLoc => current_debug_src_loc_inst = None,
                                CustomOp::PushInlinedCallFrame => {
                                    enclosing_inlined_frames
                                        .push((current_debug_src_loc_inst.take(), inst));
                                }
                                CustomOp::PopInlinedCallFrame => {
                                    if let Some((callsite_debug_src_loc_inst, _)) =
                                        enclosing_inlined_frames.pop()
                                    {
                                        current_debug_src_loc_inst = callsite_debug_src_loc_inst;
                                    }
                                }
                                CustomOp::Abort => break,
                            }
                        }
                        Op::Variable => {}
                        // First instruction that's neither a variable nor debuginfo.
                        _ => break,
                    }
                    vars_and_debuginfo_range.end += 1;
                }

                // Both sides of the split need repairs. The stolen run needs pops
                // for the frames it leaves open. The rest needs re-pushes so its
                // later pops aren't left dangling. These are collected eagerly to
                // avoid borrow conflicts with `insts` and `self`.
                let all_pops_after_vars: SmallVec<[_; 8]> = enclosing_inlined_frames
                    .iter()
                    .map(|_| custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame))
                    .collect();
                let all_repushes_before_non_vars: SmallVec<[_; 8]> =
                    (enclosing_inlined_frames.into_iter().flat_map(
                        |(callsite_debug_src_loc_inst, push_inlined_call_frame_inst)| {
                            (callsite_debug_src_loc_inst.into_iter())
                                .chain([push_inlined_call_frame_inst])
                        },
                    ))
                    .chain(current_debug_src_loc_inst)
                    .map(|inst| instantiate_debuginfo(self, inst))
                    .collect();

                let vars_and_debuginfo =
                    insts.splice(vars_and_debuginfo_range, all_repushes_before_non_vars);
                let repaired_vars_and_debuginfo = vars_and_debuginfo.chain(all_pops_after_vars);

                // Collected because `splice` keeps `insts` borrowed.
                repaired_vars_and_debuginfo.collect::<SmallVec<[_; 8]>>()
            };

            // Only the callee's entry block is left (the rest were drained above).
            let [mut inlined_callee_entry_block]: [_; 1] =
                inlined_callee_blocks.try_into().unwrap();

            // `OpVariable`s must be at the start of a function's first block, so
            // the callee's variables (and their debuginfo) move to the caller's
            // entry block.
            let callee_vars_and_debuginfo =
                steal_vars(&mut inlined_callee_entry_block.instructions);
            self.insert_opvariables(&mut caller.blocks[0], callee_vars_and_debuginfo);

            // The rest of the callee entry block goes at the end of the pre-call block.
            caller.blocks[pre_call_block_idx]
                .instructions
                .append(&mut inlined_callee_entry_block.instructions);

            // The callee entry block's terminator now ends the pre-call block, so
            // fix up `OpPhi` sources in its successors.
            rewrite_phi_sources(
                inlined_callee_entry_block.label_id().unwrap(),
                &mut caller.blocks,
                pre_call_block_idx,
            );
        }

        true
    }

    /// Adds a rule mapping every result ID (block labels included) in `blocks`
    /// to a freshly allocated ID.
    ///
    /// # Panics
    ///
    /// Panics if any of those IDs already has a rule in `rewrite_rules`.
    fn add_clone_id_rules(&mut self, rewrite_rules: &mut FxHashMap<Word, Word>, blocks: &[Block]) {
        for block in blocks {
            for inst in block.label.iter().chain(&block.instructions) {
                if let Some(result_id) = inst.result_id {
                    let new_id = self.id();
                    let old = rewrite_rules.insert(result_id, new_id);
                    assert!(old.is_none());
                }
            }
        }
    }

    /// Returns a copy of `callee`'s blocks prepared for splicing into a caller.
    /// IDs are not renamed here; the caller does that.
    ///
    /// For each block:
    /// - trailing `OpLine`/`OpNoLine`/`SetDebugSrcLoc`/`ClearDebugSrcLoc`
    ///   before the terminator are removed;
    /// - `OpReturn`/`OpReturnValue` become `OpBranch` to `return_jump`. Each
    ///   returned value is added to `maybe_call_result_phi` as a
    ///   (value, block label) pair;
    /// - if the block has anything besides leading `OpPhi`s, the non-phi part
    ///   is wrapped in debuginfo. The prefix is the call-site source location
    ///   (if any) plus `PushInlinedCallFrame` naming the callee, and the suffix
    ///   is `PopInlinedCallFrame`.
    ///
    /// # Panics
    ///
    /// Panics on an `OpReturnValue` without a phi, or an `OpReturn` with one.
    fn get_inlined_blocks(
        &mut self,
        callee: &Function,
        call_debug_src_loc_inst: Option<&Instruction>,
        mut maybe_call_result_phi: Option<&mut Instruction>,
        return_jump: Word,
    ) -> Vec<Block> {
        let Self {
            custom_ext_inst_set_import,
            op_type_void_id,
            ..
        } = *self;

        // `OpString` holding the callee's name (empty if it has no `OpName`),
        // created once per distinct name.
        let callee_name = self
            .id_to_name
            .get(&callee.def_id().unwrap())
            .copied()
            .unwrap_or("");
        let callee_name_id = *self
            .cached_op_strings
            .entry(callee_name)
            .or_insert_with(|| {
                let id = next_id(self.header);
                self.debug_string_source.push(Instruction::new(
                    Op::String,
                    None,
                    Some(id),
                    vec![Operand::LiteralString(callee_name.to_string())],
                ));
                id
            });
        // Builds a fresh (prefix, suffix) pair of debuginfo instructions. Every
        // call allocates new result IDs, so each block gets unique instructions.
        let mut mk_debuginfo_prefix_and_suffix = || {
            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
                let mut inst = inst.clone();
                if let Some(id) = &mut inst.result_id {
                    *id = this.id();
                }
                inst
            };
            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
                Instruction::new(
                    Op::ExtInst,
                    Some(op_type_void_id),
                    Some(this.id()),
                    [
                        Operand::IdRef(custom_ext_inst_set_import),
                        Operand::LiteralExtInstInteger(inst.op() as u32),
                    ]
                    .into_iter()
                    .chain(inst.into_operands())
                    .collect(),
                )
            };

            (
                (call_debug_src_loc_inst.map(|inst| instantiate_debuginfo(self, inst)))
                    .into_iter()
                    .chain([custom_inst_to_inst(
                        self,
                        CustomInst::PushInlinedCallFrame {
                            callee_name: Operand::IdRef(callee_name_id),
                        },
                    )]),
                [custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame)],
            )
        };

        let mut blocks = callee.blocks.clone();
        for block in &mut blocks {
            let mut terminator = block.instructions.pop().unwrap();

            // Strip trailing source-location debuginfo; nothing follows it
            // except the terminator.
            while let Some(last) = block.instructions.last() {
                let can_remove = match last.class.opcode {
                    Op::Line | Op::NoLine => true,
                    Op::ExtInst => {
                        last.operands[0].unwrap_id_ref() == custom_ext_inst_set_import
                            && matches!(
                                CustomOp::decode_from_ext_inst(last),
                                CustomOp::SetDebugSrcLoc | CustomOp::ClearDebugSrcLoc
                            )
                    }
                    _ => false,
                };
                if can_remove {
                    block.instructions.pop();
                } else {
                    break;
                }
            }

            // Returns become branches to the post-call block.
            if let Op::Return | Op::ReturnValue = terminator.class.opcode {
                if Op::ReturnValue == terminator.class.opcode {
                    let return_value = terminator.operands[0].id_ref_any().unwrap();
                    let call_result_phi = maybe_call_result_phi.as_deref_mut().unwrap();
                    call_result_phi.operands.extend([
                        Operand::IdRef(return_value),
                        Operand::IdRef(block.label_id().unwrap()),
                    ]);
                } else {
                    assert!(maybe_call_result_phi.is_none());
                }
                terminator =
                    Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
            }

            let num_phis = block
                .instructions
                .iter()
                .take_while(|inst| inst.class.opcode == Op::Phi)
                .count();

            // Only add debuginfo to blocks with something besides `OpPhi`s.
            if block.instructions.len() > num_phis {
                let (debuginfo_prefix, debuginfo_suffix) = mk_debuginfo_prefix_and_suffix();
                // The prefix goes after the `OpPhi`s, which must stay first.
                block
                    .instructions
                    .splice(num_phis..num_phis, debuginfo_prefix);
                // The suffix goes before the terminator (pushed back below).
                block.instructions.extend(debuginfo_suffix);
            }

            block.instructions.push(terminator);
        }

        blocks
    }

    /// Inserts `insts` (the callee's stolen `OpVariable`s and debuginfo) into
    /// `block`, among its leading `OpVariable`s and debuginfo.
    ///
    /// The scan tracks the inlined call frame depth (push/pop) and whether a
    /// source location is active outside any frame. It stops at the first
    /// instruction that is neither an `OpVariable` nor debuginfo (or at
    /// `Abort`). The insertion happens at that stop position only if no frame
    /// is open and no outermost source location is active there. Otherwise it
    /// falls back to the start of the block, so the new instructions are never
    /// covered by unrelated debuginfo.
    fn insert_opvariables(&self, block: &mut Block, insts: impl IntoIterator<Item = Instruction>) {
        let mut inlined_frames_depth = 0usize;
        let mut outermost_has_debug_src_loc = false;
        let mut last_debugless_var_insertion_point_candidate = None;
        for (i, inst) in block.instructions.iter().enumerate() {
            // Reflects the state just before instruction `i`.
            last_debugless_var_insertion_point_candidate =
                (inlined_frames_depth == 0 && !outermost_has_debug_src_loc).then_some(i);

            let changed_has_debug_src_loc = match inst.class.opcode {
                Op::Line => true,
                Op::NoLine => false,
                Op::ExtInst
                    if inst.operands[0].unwrap_id_ref() == self.custom_ext_inst_set_import =>
                {
                    match CustomOp::decode_from_ext_inst(inst) {
                        CustomOp::SetDebugSrcLoc => true,
                        CustomOp::ClearDebugSrcLoc => false,
                        CustomOp::PushInlinedCallFrame => {
                            inlined_frames_depth += 1;
                            continue;
                        }
                        CustomOp::PopInlinedCallFrame => {
                            // Tolerate unbalanced pops.
                            inlined_frames_depth = inlined_frames_depth.saturating_sub(1);
                            continue;
                        }
                        CustomOp::Abort => break,
                    }
                }
                Op::Variable => continue,
                _ => break,
            };

            // Source locations set inside inlined frames don't affect the outermost state.
            if inlined_frames_depth == 0 {
                outermost_has_debug_src_loc = changed_has_debug_src_loc;
            }
        }

        // Fallback: inserting at the very start of the block.
        let i = last_debugless_var_insertion_point_candidate.unwrap_or(0);
        block.instructions.splice(i..i, insts);
    }
}
1130
1131fn fuse_trivial_branches(function: &mut Function) {
1132 let all_preds = compute_preds(&function.blocks);
1133 'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
1134 let any_phis = function.blocks[dest_block]
1136 .instructions
1137 .iter()
1138 .filter(|inst| {
1139 !matches!(inst.class.opcode, Op::Line | Op::NoLine)
1141 })
1142 .take_while(|inst| inst.class.opcode == Op::Phi)
1143 .next()
1144 .is_some();
1145 if any_phis {
1146 continue;
1147 }
1148
1149 let pred = loop {
1152 if preds.len() != 1 || preds[0] == dest_block {
1153 continue 'outer;
1154 }
1155 let pred = preds[0];
1156 if !function.blocks[pred].instructions.is_empty() {
1157 break pred;
1158 }
1159 preds = &all_preds[pred];
1160 };
1161 let pred_insts = &function.blocks[pred].instructions;
1162 if pred_insts.last().unwrap().class.opcode == Op::Branch {
1163 let mut dest_insts = mem::take(&mut function.blocks[dest_block].instructions);
1164 let pred_insts = &mut function.blocks[pred].instructions;
1165 pred_insts.pop(); pred_insts.append(&mut dest_insts);
1167
1168 rewrite_phi_sources(
1171 function.blocks[dest_block].label_id().unwrap(),
1172 &mut function.blocks,
1173 pred,
1174 );
1175 }
1176 }
1177 function.blocks.retain(|b| !b.instructions.is_empty());
1178}
1179
1180fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
1181 let mut result = vec![vec![]; blocks.len()];
1182 for (source_idx, source) in blocks.iter().enumerate() {
1183 for dest_id in outgoing_edges(source) {
1184 let dest_idx = blocks
1185 .iter()
1186 .position(|b| b.label_id().unwrap() == dest_id)
1187 .unwrap();
1188 result[dest_idx].push(source_idx);
1189 }
1190 }
1191 result
1192}
1193
1194fn rewrite_phi_sources(original_label_id: Word, blocks: &mut [Block], new_block_idx: usize) {
1197 let new_label_id = blocks[new_block_idx].label_id().unwrap();
1198
1199 let target_ids: SmallVec<[_; 4]> = outgoing_edges(&blocks[new_block_idx]).collect();
1201
1202 for target_id in target_ids {
1203 let target_block = blocks
1204 .iter_mut()
1205 .find(|b| b.label_id().unwrap() == target_id)
1206 .unwrap();
1207 let phis = target_block
1208 .instructions
1209 .iter_mut()
1210 .filter(|inst| {
1211 !matches!(inst.class.opcode, Op::Line | Op::NoLine)
1213 })
1214 .take_while(|inst| inst.class.opcode == Op::Phi);
1215 for phi in phis {
1216 for value_and_source_id in phi.operands.chunks_mut(2) {
1217 let source_id = value_and_source_id[1].id_ref_any_mut().unwrap();
1218 if *source_id == original_label_id {
1219 *source_id = new_label_id;
1220 break;
1221 }
1222 }
1223 }
1224 }
1225}