1use super::apply_rewrite_rules;
8use super::ipo::CallGraph;
9use super::simple_passes::outgoing_edges;
10use super::{get_name, get_names};
11use crate::custom_decorations::SpanRegenerator;
12use crate::custom_insts::{self, CustomInst, CustomOp};
13use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
14use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
15use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
16use rustc_errors::ErrorGuaranteed;
17use rustc_session::Session;
18use smallvec::SmallVec;
19use std::cmp::Ordering;
20use std::mem;
21
22fn next_id(header: &mut ModuleHeader) -> Word {
24 let result = header.bound;
25 header.bound += 1;
26 result
27}
28
/// Runs the inlining pass over every function of `module`.
///
/// The pass:
/// 1. rejects modules whose call graph contains recursion,
/// 2. builds an [`Inliner`] holding the shared state (ID allocator, name
///    table, legality information, ...),
/// 3. visits the functions in the call graph's post-order, inlining calls in
///    each one and then fusing trivial branches,
/// 4. emits a warning for every `DONT_INLINE` function that had to be inlined
///    anyway.
///
/// # Errors
///
/// Returns the error produced by `deny_recursion_in_module` if the module
/// contains recursive calls.
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
    // Recursion is diagnosed (and rejected) before any inlining is attempted.
    deny_recursion_in_module(sess, module)?;

    let (call_graph, func_id_to_idx) = CallGraph::collect_with_func_id_to_idx(module);

    // Result ID of an already-present `OpExtInstImport` of the custom
    // instruction set, if the module has one.
    let custom_ext_inst_set_import = module
        .ext_inst_imports
        .iter()
        .find(|inst| {
            assert_eq!(inst.class.opcode, Op::ExtInstImport);
            inst.operands[0].unwrap_literal_string() == &custom_insts::CUSTOM_EXT_INST_SET[..]
        })
        .map(|inst| inst.result_id.unwrap());

    let legal_globals = LegalGlobal::gather_from_module(module);

    let header = module.header.as_mut().unwrap();

    // The `find(..).map(..).unwrap_or_else(..)` shape below is deliberate,
    // hence the lint being silenced.
    #[allow(clippy::map_unwrap_or)]
    let mut inliner = Inliner {
        // Reuse the module's `OpTypeVoid` if it has one, otherwise define it.
        op_type_void_id: module
            .types_global_values
            .iter()
            .find(|inst| inst.class.opcode == Op::TypeVoid)
            .map(|inst| inst.result_id.unwrap())
            .unwrap_or_else(|| {
                let id = next_id(header);
                let inst = Instruction::new(Op::TypeVoid, None, Some(id), vec![]);
                module.types_global_values.push(inst);
                id
            }),

        // Likewise, add the custom `OpExtInstImport` only when it is missing.
        custom_ext_inst_set_import: custom_ext_inst_set_import.unwrap_or_else(|| {
            let id = next_id(header);
            let inst = Instruction::new(
                Op::ExtInstImport,
                None,
                Some(id),
                vec![Operand::LiteralString(
                    custom_insts::CUSTOM_EXT_INST_SET.to_string(),
                )],
            );
            module.ext_inst_imports.push(inst);
            id
        }),

        func_id_to_idx,

        // `OpName` target ID -> name string, borrowed from the module.
        id_to_name: module
            .debug_names
            .iter()
            .filter(|inst| inst.class.opcode == Op::Name)
            .map(|inst| {
                (
                    inst.operands[0].unwrap_id_ref(),
                    inst.operands[1].unwrap_literal_string(),
                )
            })
            .collect(),

        cached_op_strings: FxHashMap::default(),

        header,
        debug_string_source: &mut module.debug_string_source,
        annotations: &mut module.annotations,

        legal_globals,

        // IDs of functions with at least one block that ends in a custom
        // `Abort` instruction followed by its `OpUnreachable` terminator.
        // This uses the *pre-existing* import found above: if the import had
        // to be created just now, no function can contain custom
        // instructions yet and the set stays empty.
        functions_that_may_abort: module
            .functions
            .iter()
            .filter_map(|func| {
                let custom_ext_inst_set_import = custom_ext_inst_set_import?;
                func.blocks
                    .iter()
                    .any(|block| match &block.instructions[..] {
                        [.., last_normal_inst, terminator_inst]
                            if last_normal_inst.class.opcode == Op::ExtInst
                                && last_normal_inst.operands[0].unwrap_id_ref()
                                    == custom_ext_inst_set_import
                                && CustomOp::decode_from_ext_inst(last_normal_inst)
                                    == CustomOp::Abort =>
                        {
                            assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
                            true
                        }

                        _ => false,
                    })
                    .then_some(func.def_id().unwrap())
            })
            .collect(),

        inlined_dont_inlines_to_cause_and_callers: FxIndexMap::default(),
    };

    // Each function is wrapped in `Ok` so that the one currently being
    // processed can be temporarily swapped out for `Err(FuncIsBeingInlined)`.
    let mut functions: Vec<_> = mem::take(&mut module.functions)
        .into_iter()
        .map(Ok)
        .collect();

    // NOTE(review): presumably post-order visits callees before their
    // callers, so callees are already fully inlined when they get copied —
    // confirm against `CallGraph::post_order`.
    for func_idx in call_graph.post_order() {
        let mut function = mem::replace(&mut functions[func_idx], Err(FuncIsBeingInlined)).unwrap();
        inliner.inline_fn(&mut function, &functions);
        fuse_trivial_branches(&mut function);
        functions[func_idx] = Ok(function);
    }

    module.functions = functions.into_iter().map(|func| func.unwrap()).collect();

    // Only these two pieces of state are needed for diagnostics; dropping the
    // rest of the `Inliner` releases its borrows of `module`.
    let Inliner {
        id_to_name,
        inlined_dont_inlines_to_cause_and_callers,
        ..
    } = inliner;

    let mut span_regen = SpanRegenerator::new(sess.source_map(), module);
    for (callee_id, (cause, callers)) in inlined_dont_inlines_to_cause_and_callers {
        let callee_name = get_name(&id_to_name, callee_id);

        // No warning for `core::` functions that were inlined because they
        // may panic.
        if cause == "panicking" && callee_name.starts_with("core::") {
            continue;
        }

        let callee_span = span_regen
            .src_loc_for_id(callee_id)
            .and_then(|src_loc| span_regen.src_loc_to_rustc(src_loc))
            .unwrap_or_default();
        sess.dcx()
            .struct_span_warn(
                callee_span,
                format!("`#[inline(never)]` function `{callee_name}` has been inlined"),
            )
            .with_note(format!("inlining was required due to {cause}"))
            .with_note(format!(
                "called from {}",
                callers
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &caller_id)| {
                        // List the first four callers by name, then
                        // summarize the remainder as "and N more".
                        match i.cmp(&4) {
                            Ordering::Less => {
                                Some(format!("`{}`", get_name(&id_to_name, caller_id)))
                            }
                            Ordering::Equal => Some(format!("and {} more", callers.len() - i)),
                            Ordering::Greater => None,
                        }
                    })
                    .collect::<SmallVec<[_; 5]>>()
                    .join(", ")
            ))
            .emit();
    }

    Ok(())
}
196
197fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> {
199 let func_to_index: FxHashMap<Word, usize> = module
200 .functions
201 .iter()
202 .enumerate()
203 .map(|(index, func)| (func.def_id().unwrap(), index))
204 .collect();
205 let mut discovered = vec![false; module.functions.len()];
206 let mut finished = vec![false; module.functions.len()];
207 let mut has_recursion = None;
208 for index in 0..module.functions.len() {
209 if !discovered[index] && !finished[index] {
210 visit(
211 sess,
212 module,
213 index,
214 &mut discovered,
215 &mut finished,
216 &mut has_recursion,
217 &func_to_index,
218 );
219 }
220 }
221
222 fn visit(
223 sess: &Session,
224 module: &Module,
225 current: usize,
226 discovered: &mut Vec<bool>,
227 finished: &mut Vec<bool>,
228 has_recursion: &mut Option<ErrorGuaranteed>,
229 func_to_index: &FxHashMap<Word, usize>,
230 ) {
231 discovered[current] = true;
232
233 for next in calls(&module.functions[current], func_to_index) {
234 if discovered[next] {
235 let names = get_names(module);
236 let current_name = get_name(&names, module.functions[current].def_id().unwrap());
237 let next_name = get_name(&names, module.functions[next].def_id().unwrap());
238 *has_recursion = Some(sess.dcx().err(format!(
239 "module has recursion, which is not allowed: `{current_name}` calls `{next_name}`"
240 )));
241 break;
242 }
243
244 if !finished[next] {
245 visit(
246 sess,
247 module,
248 next,
249 discovered,
250 finished,
251 has_recursion,
252 func_to_index,
253 );
254 }
255 }
256
257 discovered[current] = false;
258 finished[current] = true;
259 }
260
261 fn calls<'a>(
262 func: &'a Function,
263 func_to_index: &'a FxHashMap<Word, usize>,
264 ) -> impl Iterator<Item = usize> + 'a {
265 func.all_inst_iter()
266 .filter(|inst| inst.class.opcode == Op::FunctionCall)
267 .map(move |inst| {
268 *func_to_index
269 .get(&inst.operands[0].id_ref_any().unwrap())
270 .unwrap()
271 })
272 }
273
274 match has_recursion {
275 Some(err) => Err(err),
276 None => Ok(()),
277 }
278}
279
/// Kind of a module-level definition (an entry of `types_global_values`) that
/// `LegalGlobal::gather_from_module` accepted as "legal".
///
/// Definitions that are not accepted simply have no entry in the resulting
/// map; `should_inline` treats types without an entry as requiring inlining.
enum LegalGlobal {
    /// An `OpTypePointer`, together with its storage class.
    TypePointer(StorageClass),
    /// Any other type-declaring instruction.
    TypeNonPointer,
    /// A constant-declaring instruction.
    Const,
    /// A module-level `OpVariable`.
    Variable,
}
293
294impl LegalGlobal {
295 fn gather_from_module(module: &Module) -> FxHashMap<Word, Self> {
296 let mut legal_globals = FxHashMap::<_, Self>::default();
297 for inst in &module.types_global_values {
298 let global = match inst.class.opcode {
299 Op::TypePointer => Self::TypePointer(inst.operands[0].unwrap_storage_class()),
300 Op::Variable => Self::Variable,
301 op if rspirv::grammar::reflect::is_type(op) => Self::TypeNonPointer,
302 op if rspirv::grammar::reflect::is_constant(op) => Self::Const,
303
304 _ => continue,
306 };
307 let legal_result_type = match inst.result_type {
308 Some(result_type_id) => matches!(
309 (&global, legal_globals.get(&result_type_id)),
310 (Self::Variable, Some(Self::TypePointer(_)))
311 | (Self::Const, Some(Self::TypeNonPointer))
312 ),
313 None => matches!(global, Self::TypePointer(_) | Self::TypeNonPointer),
314 };
315 let legal_operands = inst.operands.iter().all(|operand| match operand {
316 Operand::IdRef(id) => matches!(
317 legal_globals.get(id),
318 Some(Self::TypeNonPointer | Self::Const)
319 ),
320
321 _ => operand.id_ref_any().is_none(),
323 });
324 if legal_result_type && legal_operands {
325 legal_globals.insert(inst.result_id.unwrap(), global);
326 }
327 }
328 legal_globals
329 }
330
331 fn legal_as_fn_param_ty(&self) -> bool {
332 match *self {
333 Self::TypePointer(storage_class) => matches!(
334 storage_class,
335 StorageClass::UniformConstant
336 | StorageClass::Function
337 | StorageClass::Private
338 | StorageClass::Workgroup
339 | StorageClass::AtomicCounter
340 ),
341 Self::TypeNonPointer => true,
342
343 Self::Const | Self::Variable => false,
345 }
346 }
347
348 fn legal_as_fn_ret_ty(&self) -> bool {
349 #[allow(clippy::match_same_arms)]
350 match *self {
351 Self::TypePointer(_) => false,
352 Self::TypeNonPointer => true,
353
354 Self::Const | Self::Variable => false,
356 }
357 }
358}
359
/// A call instruction together with the function that contains it, as
/// consumed by `should_inline`.
#[derive(Copy, Clone)]
struct CallSite<'a> {
    /// The function in which the call appears.
    caller: &'a Function,
    /// The call instruction itself; `should_inline` reads the argument for
    /// parameter `i` from `operands[i + 1]`.
    call_inst: &'a Instruction,
}
366
367fn has_dont_inline(function: &Function) -> bool {
368 let def = function.def.as_ref().unwrap();
369 let control = def.operands[0].unwrap_function_control();
370 control.contains(FunctionControl::DONT_INLINE)
371}
372
/// Error type of `should_inline`: the call has to be inlined regardless of
/// the callee's inlining hints. The string is a short human-readable cause
/// (e.g. `"panicking"`), which `inline` includes in its warning when the
/// callee is marked `DONT_INLINE`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct MustInlineToLegalize(&'static str);
376
377fn should_inline(
386 legal_globals: &FxHashMap<Word, LegalGlobal>,
387 functions_that_may_abort: &FxHashSet<Word>,
388 callee: &Function,
389 call_site: CallSite<'_>,
390) -> Result<bool, MustInlineToLegalize> {
391 let callee_def = callee.def.as_ref().unwrap();
392 let callee_control = callee_def.operands[0].unwrap_function_control();
393
394 if functions_that_may_abort.contains(&callee.def_id().unwrap()) {
395 return Err(MustInlineToLegalize("panicking"));
396 }
397
398 let ret_ty = legal_globals
399 .get(&callee_def.result_type.unwrap())
400 .ok_or(MustInlineToLegalize("illegal return type"))?;
401 if !ret_ty.legal_as_fn_ret_ty() {
402 return Err(MustInlineToLegalize("illegal (pointer) return type"));
403 }
404
405 for (i, param) in callee.parameters.iter().enumerate() {
406 let param_ty = legal_globals
407 .get(param.result_type.as_ref().unwrap())
408 .ok_or(MustInlineToLegalize("illegal parameter type"))?;
409 if !param_ty.legal_as_fn_param_ty() {
410 return Err(MustInlineToLegalize("illegal (pointer) parameter type"));
411 }
412
413 if let LegalGlobal::TypePointer(_) = param_ty {
419 let ptr_arg = call_site.call_inst.operands[i + 1].unwrap_id_ref();
420 match legal_globals.get(&ptr_arg) {
421 Some(LegalGlobal::Variable) => {}
422
423 Some(_) => return Err(MustInlineToLegalize("illegal (pointer) argument")),
425
426 None => {
427 let mut caller_param_and_var_ids = call_site
428 .caller
429 .parameters
430 .iter()
431 .chain(
432 call_site.caller.blocks[0]
433 .instructions
434 .iter()
435 .filter(|caller_inst| {
436 let may_be_debuginfo = matches!(
440 caller_inst.class.opcode,
441 Op::Line | Op::NoLine | Op::ExtInst
442 );
443 !may_be_debuginfo
444 })
445 .take_while(|caller_inst| caller_inst.class.opcode == Op::Variable),
446 )
447 .map(|caller_inst| caller_inst.result_id.unwrap());
448
449 if !caller_param_and_var_ids.any(|id| ptr_arg == id) {
450 return Err(MustInlineToLegalize("illegal (pointer) argument"));
451 }
452 }
453 }
454 }
455 }
456
457 Ok(callee_control.contains(FunctionControl::INLINE))
458}
459
/// Placeholder stored (as `Err`) in the function list in place of the
/// function that is currently being processed by `Inliner::inline_fn`.
/// Looking up that slot as a callee therefore panics on `unwrap`, showing
/// this type's `Debug` output.
#[derive(Debug)]
struct FuncIsBeingInlined;
464
/// State shared across all inlining performed on one module.
///
/// `'a` is the lifetime of name strings borrowed from the module's debug
/// names, `'b` that of the mutable borrows of module sections.
struct Inliner<'a, 'b> {
    /// Result ID of the `OpExtInstImport` of the custom instruction set
    /// (`custom_insts::CUSTOM_EXT_INST_SET`).
    custom_ext_inst_set_import: Word,

    /// Result ID of the module's `OpTypeVoid`.
    op_type_void_id: Word,

    /// Maps a function's ID to its index in the function list handed to
    /// `inline_fn` / `inline_block`.
    func_id_to_idx: FxHashMap<Word, usize>,

    /// Maps an ID to the name given to it by an `OpName`.
    id_to_name: FxHashMap<Word, &'a str>,

    /// Maps a string to the ID of the `OpString` created for it by
    /// `get_inlined_blocks`, so each distinct string gets only one `OpString`.
    cached_op_strings: FxHashMap<&'a str, Word>,

    /// Module header; its `bound` is the source of fresh IDs.
    header: &'b mut ModuleHeader,
    /// Module section that newly created `OpString`s are appended to.
    debug_string_source: &'b mut Vec<Instruction>,
    /// Module annotations; duplicated for cloned IDs by
    /// `apply_rewrite_for_decorations`.
    annotations: &'b mut Vec<Instruction>,

    /// Module-level definitions considered legal, see `LegalGlobal`.
    legal_globals: FxHashMap<Word, LegalGlobal>,
    /// IDs of functions that may abort; grows as aborting callees are
    /// inlined into their callers.
    functions_that_may_abort: FxHashSet<Word>,
    /// For every `DONT_INLINE` function that had to be inlined anyway: the
    /// cause recorded the first time, and the IDs of the callers it was
    /// inlined into.
    inlined_dont_inlines_to_cause_and_callers: FxIndexMap<Word, (&'static str, FxIndexSet<Word>)>,
}
501
502impl Inliner<'_, '_> {
503 fn id(&mut self) -> Word {
504 next_id(self.header)
505 }
506
507 fn apply_rewrite_for_decorations(&mut self, rewrite_rules: &FxHashMap<Word, Word>) {
509 for annotation_idx in 0..self.annotations.len() {
512 let inst = &self.annotations[annotation_idx];
513 if let [Operand::IdRef(target), ..] = inst.operands[..]
514 && let Some(&rewritten_target) = rewrite_rules.get(&target)
515 {
516 let mut cloned_inst = inst.clone();
518 cloned_inst.operands[0] = Operand::IdRef(rewritten_target);
519 self.annotations.push(cloned_inst);
520 }
521 }
522 }
523
524 fn inline_fn(
525 &mut self,
526 function: &mut Function,
527 functions: &[Result<Function, FuncIsBeingInlined>],
528 ) {
529 let mut block_idx = 0;
530 while block_idx < function.blocks.len() {
531 if !self.inline_block(function, block_idx, functions) {
535 block_idx += 1;
537 }
538 }
539 }
540
    /// Inlines the first call in block `block_idx` of `caller` that
    /// `should_inline` approves (or requires), if there is one.
    ///
    /// Returns `true` if a call was inlined (the block then ends where the
    /// call used to be, and new blocks have been inserted after it), `false`
    /// if the block contains no call to inline.
    ///
    /// `functions` is indexed through `self.func_id_to_idx`; the callee's
    /// entry must be `Ok`.
    fn inline_block(
        &mut self,
        caller: &mut Function,
        block_idx: usize,
        functions: &[Result<Function, FuncIsBeingInlined>],
    ) -> bool {
        // Find the first call that is to be inlined, as a tuple of
        // (instruction index, call instruction, callee).
        let call = caller.blocks[block_idx]
            .instructions
            .iter()
            .enumerate()
            .filter(|(_, inst)| inst.class.opcode == Op::FunctionCall)
            .map(|(index, inst)| {
                (
                    index,
                    inst,
                    functions[self.func_id_to_idx[&inst.operands[0].id_ref_any().unwrap()]]
                        .as_ref()
                        .unwrap(),
                )
            })
            .find(|(_, inst, f)| {
                let call_site = CallSite {
                    caller,
                    call_inst: inst,
                };
                match should_inline(
                    &self.legal_globals,
                    &self.functions_that_may_abort,
                    f,
                    call_site,
                ) {
                    Ok(inline) => inline,
                    Err(MustInlineToLegalize(cause)) => {
                        // Forced inlining overrides `DONT_INLINE`; remember
                        // it so that `inline` can warn about it afterwards.
                        if has_dont_inline(f) {
                            self.inlined_dont_inlines_to_cause_and_callers
                                .entry(f.def_id().unwrap())
                                .or_insert_with(|| (cause, Default::default()))
                                .1
                                .insert(caller.def_id().unwrap());
                        }
                        true
                    }
                }
            });
        let (call_index, call_inst, callee) = match call {
            None => return false,
            Some(call) => call,
        };

        // Inlining a callee that may abort makes the caller one that may
        // abort, too.
        if self
            .functions_that_may_abort
            .contains(&callee.def_id().unwrap())
        {
            self.functions_that_may_abort
                .insert(caller.def_id().unwrap());
        }

        // For non-void calls, an `OpPhi` reusing the call's result ID will
        // gather the values returned by the callee's returning blocks.
        let mut maybe_call_result_phi = {
            let ty = call_inst.result_type.unwrap();
            if ty == self.op_type_void_id {
                None
            } else {
                Some(Instruction::new(
                    Op::Phi,
                    Some(ty),
                    Some(call_inst.result_id.unwrap()),
                    vec![],
                ))
            }
        };

        // Determine the source location in effect at the call by scanning
        // backwards from it: the scan only looks through a trailing run of
        // debuginfo instructions and gives up at the first other instruction.
        let custom_ext_inst_set_import = self.custom_ext_inst_set_import;
        let call_debug_src_loc_inst = caller.blocks[block_idx].instructions[..call_index]
            .iter()
            .rev()
            .find_map(|inst| {
                Some(match inst.class.opcode {
                    Op::Line => Some(inst),
                    Op::NoLine => None,
                    Op::ExtInst
                        if inst.operands[0].unwrap_id_ref() == custom_ext_inst_set_import =>
                    {
                        match CustomOp::decode_from_ext_inst(inst) {
                            CustomOp::SetDebugSrcLoc => Some(inst),
                            CustomOp::ClearDebugSrcLoc => None,
                            _ => return None,
                        }
                    }
                    _ => return None,
                })
            })
            .flatten();

        // Rewrite rules start out mapping each callee parameter to the
        // corresponding call argument.
        let call_arguments = call_inst
            .operands
            .iter()
            .skip(1)
            .map(|op| op.id_ref_any().unwrap());
        let callee_parameters = callee.parameters.iter().map(|inst| {
            assert!(inst.class.opcode == Op::FunctionParameter);
            inst.result_id.unwrap()
        });
        let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();

        // `return_jump` is the label of the block that will hold everything
        // after the call; the callee's returns become branches to it.
        let return_jump = self.id();
        let mut inlined_callee_blocks = self.get_inlined_blocks(
            callee,
            call_debug_src_loc_inst,
            maybe_call_result_phi.as_mut(),
            return_jump,
        );
        // Give every ID defined in the copied blocks a fresh ID, then apply
        // all rules to the copies (and duplicate affected annotations).
        self.add_clone_id_rules(&mut rewrite_rules, &inlined_callee_blocks);
        apply_rewrite_rules(&rewrite_rules, &mut inlined_callee_blocks);
        self.apply_rewrite_for_decorations(&rewrite_rules);

        if let Some(call_result_phi) = &mut maybe_call_result_phi {
            // The phi's operands were recorded using the callee's original
            // IDs, so they need the same rewriting as the copied blocks.
            for op in &mut call_result_phi.operands {
                if let Some(id) = op.id_ref_any_mut()
                    && let Some(&rewrite) = rewrite_rules.get(id)
                {
                    *id = rewrite;
                }
            }

            // With a single returning block the phi is redundant: drop it and
            // replace uses of the call's result with the returned value.
            if let [returned_value, _return_block] = &call_result_phi.operands[..] {
                let call_result_id = call_result_phi.result_id.unwrap();
                let returned_value_id = returned_value.unwrap_id_ref();

                maybe_call_result_phi = None;

                // Instructions visited for the replacement: only the leading
                // `OpPhi`s of blocks before the call's block, but every
                // instruction from the call's block onwards.
                let reaching_insts = {
                    let (pre_call_blocks, call_and_post_call_blocks) =
                        caller.blocks.split_at_mut(block_idx);
                    (pre_call_blocks.iter_mut().flat_map(|block| {
                        block
                            .instructions
                            .iter_mut()
                            .take_while(|inst| inst.class.opcode == Op::Phi)
                    }))
                    .chain(
                        call_and_post_call_blocks
                            .iter_mut()
                            .flat_map(|block| &mut block.instructions),
                    )
                };
                for reaching_inst in reaching_insts {
                    for op in &mut reaching_inst.operands {
                        if let Some(id) = op.id_ref_any_mut()
                            && *id == call_result_id
                        {
                            *id = returned_value_id;
                        }
                    }
                }
            }
        }

        // From here on the block containing the call is referred to as the
        // "pre-call" block; `block_idx` is shadowed by an uninitialized
        // binding so that any further use of it fails to compile.
        let pre_call_block_idx = block_idx;
        #[expect(unused)]
        let block_idx: usize;
        // Split the block right after the call...
        let mut post_call_block_insts = caller.blocks[pre_call_block_idx]
            .instructions
            .split_off(call_index + 1);

        // ... and remove the call itself, now the block's last instruction.
        let call = caller.blocks[pre_call_block_idx]
            .instructions
            .pop()
            .unwrap();
        assert!(call.class.opcode == Op::FunctionCall);

        // All callee blocks but the entry block go right after the pre-call
        // block; the entry block is merged into the pre-call block below.
        let non_entry_inlined_callee_blocks = inlined_callee_blocks.drain(1..);
        let num_non_entry_inlined_callee_blocks = non_entry_inlined_callee_blocks.len();
        caller.blocks.splice(
            (pre_call_block_idx + 1)..(pre_call_block_idx + 1),
            non_entry_inlined_callee_blocks,
        );

        // If the result phi is still needed, it leads the post-call block.
        if let Some(call_result_phi) = maybe_call_result_phi {
            post_call_block_insts.insert(0, call_result_phi);
        }

        {
            // Insert the post-call block (labelled `return_jump`) after the
            // inlined callee blocks.
            let post_call_block_idx = pre_call_block_idx + num_non_entry_inlined_callee_blocks + 1;
            let post_call_block = Block {
                label: Some(Instruction::new(Op::Label, None, Some(return_jump), vec![])),
                instructions: post_call_block_insts,
            };
            caller.blocks.insert(post_call_block_idx, post_call_block);

            // The original terminator now lives in the post-call block, so
            // phis in its targets that named the pre-call block as their
            // source have to name the post-call block instead.
            rewrite_phi_sources(
                caller.blocks[pre_call_block_idx].label_id().unwrap(),
                &mut caller.blocks,
                post_call_block_idx,
            );
        }

        {
            // Clones a debuginfo instruction, giving the clone a fresh result
            // ID if the original had one.
            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
                let mut inst = inst.clone();
                if let Some(id) = &mut inst.result_id {
                    *id = this.id();
                }
                inst
            };

            // Builds the `OpExtInst` (of void type, with a fresh result ID)
            // that encodes the given custom instruction.
            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
                Instruction::new(
                    Op::ExtInst,
                    Some(this.op_type_void_id),
                    Some(this.id()),
                    [
                        Operand::IdRef(this.custom_ext_inst_set_import),
                        Operand::LiteralExtInstInteger(inst.op() as u32),
                    ]
                    .into_iter()
                    .chain(inst.into_operands())
                    .collect(),
                )
            };

            // Removes the leading run of `OpVariable`s and debuginfo from
            // `insts` and returns it, repaired so that both halves keep
            // balanced inlined call frames:
            // - the returned run gets one `PopInlinedCallFrame` appended for
            //   every frame still open at its end,
            // - in `insts`, the run is replaced by fresh copies of the
            //   instructions needed to re-enter those frames and restore the
            //   current source location.
            let mut steal_vars = |insts: &mut Vec<Instruction>| {
                // Stack of open frames: (source location in effect when the
                // frame was pushed, the push instruction itself).
                let mut enclosing_inlined_frames = SmallVec::<[_; 8]>::new();
                let mut current_debug_src_loc_inst = None;
                let mut vars_and_debuginfo_range = 0..0;
                while vars_and_debuginfo_range.end < insts.len() {
                    let inst = &insts[vars_and_debuginfo_range.end];
                    match inst.class.opcode {
                        Op::Line => current_debug_src_loc_inst = Some(inst),
                        Op::NoLine => current_debug_src_loc_inst = None,
                        Op::ExtInst
                            if inst.operands[0].unwrap_id_ref()
                                == self.custom_ext_inst_set_import =>
                        {
                            match CustomOp::decode_from_ext_inst(inst) {
                                CustomOp::SetDebugSrcLoc => current_debug_src_loc_inst = Some(inst),
                                CustomOp::ClearDebugSrcLoc => current_debug_src_loc_inst = None,
                                CustomOp::PushInlinedCallFrame => {
                                    // Entering a frame also resets the
                                    // current location (`take`).
                                    enclosing_inlined_frames
                                        .push((current_debug_src_loc_inst.take(), inst));
                                }
                                CustomOp::PopInlinedCallFrame => {
                                    // Leaving a frame restores the location
                                    // of its call site.
                                    if let Some((callsite_debug_src_loc_inst, _)) =
                                        enclosing_inlined_frames.pop()
                                    {
                                        current_debug_src_loc_inst = callsite_debug_src_loc_inst;
                                    }
                                }
                                CustomOp::Abort => break,
                            }
                        }
                        Op::Variable => {}
                        _ => break,
                    }
                    vars_and_debuginfo_range.end += 1;
                }

                // One pop per frame that is still open after the run.
                let all_pops_after_vars: SmallVec<[_; 8]> = enclosing_inlined_frames
                    .iter()
                    .map(|_| custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame))
                    .collect();
                // For each open frame (outermost first): its call site's
                // location (if any) and its push; finally the location that
                // was current at the end of the run.
                let all_repushes_before_non_vars: SmallVec<[_; 8]> =
                    (enclosing_inlined_frames.into_iter().flat_map(
                        |(callsite_debug_src_loc_inst, push_inlined_call_frame_inst)| {
                            (callsite_debug_src_loc_inst.into_iter())
                                .chain([push_inlined_call_frame_inst])
                        },
                    ))
                    .chain(current_debug_src_loc_inst)
                    .map(|inst| instantiate_debuginfo(self, inst))
                    .collect();

                let vars_and_debuginfo =
                    insts.splice(vars_and_debuginfo_range, all_repushes_before_non_vars);
                let repaired_vars_and_debuginfo = vars_and_debuginfo.chain(all_pops_after_vars);

                repaired_vars_and_debuginfo.collect::<SmallVec<[_; 8]>>()
            };

            // Only the callee's entry block is left at this point.
            let [mut inlined_callee_entry_block]: [_; 1] =
                inlined_callee_blocks.try_into().unwrap();

            // Move the callee's variables into the caller's entry block.
            let callee_vars_and_debuginfo =
                steal_vars(&mut inlined_callee_entry_block.instructions);
            self.insert_opvariables(&mut caller.blocks[0], callee_vars_and_debuginfo);

            // The remainder of the callee's entry block continues the
            // pre-call block.
            caller.blocks[pre_call_block_idx]
                .instructions
                .append(&mut inlined_callee_entry_block.instructions);

            // Phis that named the callee's entry block as their source now
            // have to name the pre-call block, which absorbed it.
            rewrite_phi_sources(
                inlined_callee_entry_block.label_id().unwrap(),
                &mut caller.blocks,
                pre_call_block_idx,
            );
        }

        true
    }
887
888 fn add_clone_id_rules(&mut self, rewrite_rules: &mut FxHashMap<Word, Word>, blocks: &[Block]) {
889 for block in blocks {
890 for inst in block.label.iter().chain(&block.instructions) {
891 if let Some(result_id) = inst.result_id {
892 let new_id = self.id();
893 let old = rewrite_rules.insert(result_id, new_id);
894 assert!(old.is_none());
895 }
896 }
897 }
898 }
899
    /// Returns a copy of `callee`'s blocks, prepared for insertion into a
    /// caller.
    ///
    /// In the copy:
    /// - `OpReturn` / `OpReturnValue` terminators are replaced by a branch to
    ///   `return_jump`; for `OpReturnValue`, the returned value and the
    ///   block's label are appended to `maybe_call_result_phi`,
    /// - debug source location instructions directly preceding a terminator
    ///   are removed,
    /// - the non-phi instructions of every block that has any (besides the
    ///   terminator) are wrapped in `PushInlinedCallFrame` (preceded by a
    ///   copy of `call_debug_src_loc_inst`, if given) and
    ///   `PopInlinedCallFrame`.
    ///
    /// IDs are *not* renamed here: the blocks, and the operands added to the
    /// phi, still use the callee's IDs.
    ///
    /// # Panics
    ///
    /// Panics if a block is empty, if the callee returns a value but no phi
    /// was passed, or if a phi was passed but the callee has a plain
    /// `OpReturn`.
    fn get_inlined_blocks(
        &mut self,
        callee: &Function,
        call_debug_src_loc_inst: Option<&Instruction>,
        mut maybe_call_result_phi: Option<&mut Instruction>,
        return_jump: Word,
    ) -> Vec<Block> {
        // Copied out so that the closures below do not need to borrow `self`
        // for them.
        let Self {
            custom_ext_inst_set_import,
            op_type_void_id,
            ..
        } = *self;

        // The callee's name (empty if it has none) is referenced from the
        // `PushInlinedCallFrame`s through an `OpString`, which is created on
        // first use and cached per name.
        let callee_name = self
            .id_to_name
            .get(&callee.def_id().unwrap())
            .copied()
            .unwrap_or("");
        let callee_name_id = *self
            .cached_op_strings
            .entry(callee_name)
            .or_insert_with(|| {
                let id = next_id(self.header);
                self.debug_string_source.push(Instruction::new(
                    Op::String,
                    None,
                    Some(id),
                    vec![Operand::LiteralString(callee_name.to_string())],
                ));
                id
            });
        // Creates the instructions to put before and after the body of one
        // inlined block; every invocation allocates fresh result IDs.
        let mut mk_debuginfo_prefix_and_suffix = || {
            // Clones a debuginfo instruction, giving the clone a fresh result
            // ID if the original had one.
            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
                let mut inst = inst.clone();
                if let Some(id) = &mut inst.result_id {
                    *id = this.id();
                }
                inst
            };
            // Builds the `OpExtInst` (of void type, with a fresh result ID)
            // that encodes the given custom instruction.
            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
                Instruction::new(
                    Op::ExtInst,
                    Some(op_type_void_id),
                    Some(this.id()),
                    [
                        Operand::IdRef(custom_ext_inst_set_import),
                        Operand::LiteralExtInstInteger(inst.op() as u32),
                    ]
                    .into_iter()
                    .chain(inst.into_operands())
                    .collect(),
                )
            };

            (
                // Prefix: the call's source location (if any), then the push.
                (call_debug_src_loc_inst.map(|inst| instantiate_debuginfo(self, inst)))
                    .into_iter()
                    .chain([custom_inst_to_inst(
                        self,
                        CustomInst::PushInlinedCallFrame {
                            callee_name: Operand::IdRef(callee_name_id),
                        },
                    )]),
                // Suffix: the matching pop.
                [custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame)],
            )
        };

        let mut blocks = callee.blocks.clone();
        for block in &mut blocks {
            // The terminator is set aside and pushed back once the block's
            // body has been adjusted.
            let mut terminator = block.instructions.pop().unwrap();

            // Drop source location instructions that only applied to the
            // terminator.
            while let Some(last) = block.instructions.last() {
                let can_remove = match last.class.opcode {
                    Op::Line | Op::NoLine => true,
                    Op::ExtInst => {
                        last.operands[0].unwrap_id_ref() == custom_ext_inst_set_import
                            && matches!(
                                CustomOp::decode_from_ext_inst(last),
                                CustomOp::SetDebugSrcLoc | CustomOp::ClearDebugSrcLoc
                            )
                    }
                    _ => false,
                };
                if can_remove {
                    block.instructions.pop();
                } else {
                    break;
                }
            }

            if let Op::Return | Op::ReturnValue = terminator.class.opcode {
                if Op::ReturnValue == terminator.class.opcode {
                    // Record (value, source block) as an incoming pair of the
                    // call's result phi.
                    let return_value = terminator.operands[0].id_ref_any().unwrap();
                    let call_result_phi = maybe_call_result_phi.as_deref_mut().unwrap();
                    call_result_phi.operands.extend([
                        Operand::IdRef(return_value),
                        Operand::IdRef(block.label_id().unwrap()),
                    ]);
                } else {
                    assert!(maybe_call_result_phi.is_none());
                }
                // Returning from the callee becomes jumping to the block
                // that follows the call.
                terminator =
                    Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
            }

            // `OpPhi`s have to stay at the very start of the block, so the
            // debuginfo prefix goes after them.
            let num_phis = block
                .instructions
                .iter()
                .take_while(|inst| inst.class.opcode == Op::Phi)
                .count();

            // Blocks without any non-phi instructions get no call frame.
            if block.instructions.len() > num_phis {
                let (debuginfo_prefix, debuginfo_suffix) = mk_debuginfo_prefix_and_suffix();
                block
                    .instructions
                    .splice(num_phis..num_phis, debuginfo_prefix);
                block.instructions.extend(debuginfo_suffix);
            }

            block.instructions.push(terminator);
        }

        blocks
    }
1038
    /// Inserts `insts` into `block`, at a position within its leading run of
    /// `OpVariable`s and debuginfo.
    ///
    /// The block is scanned from the start while tracking the depth of
    /// inlined call frames and whether a source location is set outside of
    /// any such frame. The scan stops at the first instruction that is
    /// neither a variable nor debuginfo (or at a custom `Abort`). The
    /// insertion index is the index of the last instruction the scan looked
    /// at, provided that at that point no frame was open and no outermost
    /// source location was set; otherwise (or if the block is empty) it is 0.
    fn insert_opvariables(&self, block: &mut Block, insts: impl IntoIterator<Item = Instruction>) {
        let mut inlined_frames_depth = 0usize;
        let mut outermost_has_debug_src_loc = false;
        let mut last_debugless_var_insertion_point_candidate = None;
        for (i, inst) in block.instructions.iter().enumerate() {
            // Recomputed (not accumulated) at every instruction, based on the
            // state *before* `inst` takes effect.
            last_debugless_var_insertion_point_candidate =
                (inlined_frames_depth == 0 && !outermost_has_debug_src_loc).then_some(i);

            // `Some` new value for "a source location is set", for the
            // instructions that change it; the other arms leave the loop
            // iteration (or the loop) directly.
            let changed_has_debug_src_loc = match inst.class.opcode {
                Op::Line => true,
                Op::NoLine => false,
                Op::ExtInst
                    if inst.operands[0].unwrap_id_ref() == self.custom_ext_inst_set_import =>
                {
                    match CustomOp::decode_from_ext_inst(inst) {
                        CustomOp::SetDebugSrcLoc => true,
                        CustomOp::ClearDebugSrcLoc => false,
                        CustomOp::PushInlinedCallFrame => {
                            inlined_frames_depth += 1;
                            continue;
                        }
                        CustomOp::PopInlinedCallFrame => {
                            inlined_frames_depth = inlined_frames_depth.saturating_sub(1);
                            continue;
                        }
                        CustomOp::Abort => break,
                    }
                }
                Op::Variable => continue,
                _ => break,
            };

            // Source locations set inside an inlined frame do not affect the
            // outermost state.
            if inlined_frames_depth == 0 {
                outermost_has_debug_src_loc = changed_has_debug_src_loc;
            }
        }

        let i = last_debugless_var_insertion_point_candidate.unwrap_or(0);
        block.instructions.splice(i..i, insts);
    }
1084}
1085
1086fn fuse_trivial_branches(function: &mut Function) {
1087 let all_preds = compute_preds(&function.blocks);
1088 'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
1089 let any_phis = function.blocks[dest_block]
1091 .instructions
1092 .iter()
1093 .filter(|inst| {
1094 !matches!(inst.class.opcode, Op::Line | Op::NoLine)
1096 })
1097 .take_while(|inst| inst.class.opcode == Op::Phi)
1098 .next()
1099 .is_some();
1100 if any_phis {
1101 continue;
1102 }
1103
1104 let pred = loop {
1107 if preds.len() != 1 || preds[0] == dest_block {
1108 continue 'outer;
1109 }
1110 let pred = preds[0];
1111 if !function.blocks[pred].instructions.is_empty() {
1112 break pred;
1113 }
1114 preds = &all_preds[pred];
1115 };
1116 let pred_insts = &function.blocks[pred].instructions;
1117 if pred_insts.last().unwrap().class.opcode == Op::Branch {
1118 let mut dest_insts = mem::take(&mut function.blocks[dest_block].instructions);
1119 let pred_insts = &mut function.blocks[pred].instructions;
1120 pred_insts.pop(); pred_insts.append(&mut dest_insts);
1122
1123 rewrite_phi_sources(
1126 function.blocks[dest_block].label_id().unwrap(),
1127 &mut function.blocks,
1128 pred,
1129 );
1130 }
1131 }
1132 function.blocks.retain(|b| !b.instructions.is_empty());
1133}
1134
1135fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
1136 let mut result = vec![vec![]; blocks.len()];
1137 for (source_idx, source) in blocks.iter().enumerate() {
1138 for dest_id in outgoing_edges(source) {
1139 let dest_idx = blocks
1140 .iter()
1141 .position(|b| b.label_id().unwrap() == dest_id)
1142 .unwrap();
1143 result[dest_idx].push(source_idx);
1144 }
1145 }
1146 result
1147}
1148
1149fn rewrite_phi_sources(original_label_id: Word, blocks: &mut [Block], new_block_idx: usize) {
1152 let new_label_id = blocks[new_block_idx].label_id().unwrap();
1153
1154 let target_ids: SmallVec<[_; 4]> = outgoing_edges(&blocks[new_block_idx]).collect();
1156
1157 for target_id in target_ids {
1158 let target_block = blocks
1159 .iter_mut()
1160 .find(|b| b.label_id().unwrap() == target_id)
1161 .unwrap();
1162 let phis = target_block
1163 .instructions
1164 .iter_mut()
1165 .filter(|inst| {
1166 !matches!(inst.class.opcode, Op::Line | Op::NoLine)
1168 })
1169 .take_while(|inst| inst.class.opcode == Op::Phi);
1170 for phi in phis {
1171 for value_and_source_id in phi.operands.chunks_mut(2) {
1172 let source_id = value_and_source_id[1].id_ref_any_mut().unwrap();
1173 if *source_id == original_label_id {
1174 *source_id = new_label_id;
1175 break;
1176 }
1177 }
1178 }
1179 }
1180}