rustc_codegen_spirv/linker/inline.rs

//! This algorithm is not intended to be an optimization; rather, it exists for legalization.
//! Specifically, SPIR-V disallows things like a `StorageClass::Function` pointer to a
//! `StorageClass::Input` pointer. Our frontend definitely allows it, though - it's like taking a
//! `&Input<T>` in a function! So, we inline all functions that take these "illegal" pointers, then
//! run mem2reg (see mem2reg.rs) on the result to "unwrap" the Function pointer.
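//!
//! For example, a helper with a signature like `fn load(ptr: &Input<f32>) -> f32`
//! (schematic - the exact frontend types don't matter here) receives a
//! `Function`-storage pointer to an `Input`-storage pointer, which SPIR-V cannot
//! express as a function parameter; every call to such a helper is therefore
//! inlined, and the leftover `Function` pointer is then eliminated by mem2reg.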

use super::apply_rewrite_rules;
use super::ipo::CallGraph;
use super::simple_passes::outgoing_edges;
use super::{get_name, get_names};
use crate::custom_decorations::SpanRegenerator;
use crate::custom_insts::{self, CustomInst, CustomOp};
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
use rustc_errors::ErrorGuaranteed;
use rustc_session::Session;
use smallvec::SmallVec;
use std::cmp::Ordering;
use std::mem;

// FIXME(eddyb) this is a bit silly, but this keeps being repeated everywhere.
fn next_id(header: &mut ModuleHeader) -> Word {
    let result = header.bound;
    header.bound += 1;
    result
}

pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
    // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
    deny_recursion_in_module(sess, module)?;

    // Compute the call-graph that will drive (inside-out, aka bottom-up) inlining.
    let (call_graph, func_id_to_idx) = CallGraph::collect_with_func_id_to_idx(module);

    let custom_ext_inst_set_import = module
        .ext_inst_imports
        .iter()
        .find(|inst| {
            assert_eq!(inst.class.opcode, Op::ExtInstImport);
            inst.operands[0].unwrap_literal_string() == &custom_insts::CUSTOM_EXT_INST_SET[..]
        })
        .map(|inst| inst.result_id.unwrap());

    let legal_globals = LegalGlobal::gather_from_module(module);

    let header = module.header.as_mut().unwrap();

    // FIXME(eddyb) clippy false positive (separate `map` required for borrowck).
    #[allow(clippy::map_unwrap_or)]
    let mut inliner = Inliner {
        op_type_void_id: module
            .types_global_values
            .iter()
            .find(|inst| inst.class.opcode == Op::TypeVoid)
            .map(|inst| inst.result_id.unwrap())
            .unwrap_or_else(|| {
                let id = next_id(header);
                let inst = Instruction::new(Op::TypeVoid, None, Some(id), vec![]);
                module.types_global_values.push(inst);
                id
            }),

        custom_ext_inst_set_import: custom_ext_inst_set_import.unwrap_or_else(|| {
            let id = next_id(header);
            let inst = Instruction::new(
                Op::ExtInstImport,
                None,
                Some(id),
                vec![Operand::LiteralString(
                    custom_insts::CUSTOM_EXT_INST_SET.to_string(),
                )],
            );
            module.ext_inst_imports.push(inst);
            id
        }),

        func_id_to_idx,

        id_to_name: module
            .debug_names
            .iter()
            .filter(|inst| inst.class.opcode == Op::Name)
            .map(|inst| {
                (
                    inst.operands[0].unwrap_id_ref(),
                    inst.operands[1].unwrap_literal_string(),
                )
            })
            .collect(),

        cached_op_strings: FxHashMap::default(),

        header,
        debug_string_source: &mut module.debug_string_source,
        annotations: &mut module.annotations,

        legal_globals,

        // NOTE(eddyb) this is needed because our custom `Abort` instructions get
        // lowered to a simple `OpReturn` in entry-points, but that requires that
        // they get inlined all the way up to the entry-points in the first place.
        functions_that_may_abort: module
            .functions
            .iter()
            .filter_map(|func| {
                let custom_ext_inst_set_import = custom_ext_inst_set_import?;
                func.blocks
                    .iter()
                    .any(|block| match &block.instructions[..] {
                        [.., last_normal_inst, terminator_inst]
                            if last_normal_inst.class.opcode == Op::ExtInst
                                && last_normal_inst.operands[0].unwrap_id_ref()
                                    == custom_ext_inst_set_import
                                && CustomOp::decode_from_ext_inst(last_normal_inst)
                                    == CustomOp::Abort =>
                        {
                            assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
                            true
                        }

                        _ => false,
                    })
                    .then_some(func.def_id().unwrap())
            })
            .collect(),

        inlined_dont_inlines_to_cause_and_callers: FxIndexMap::default(),
    };

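    // Wrap each function in `Ok`, so that the post-order loop below can
    // temporarily take a function out of its slot (leaving `Err(FuncIsBeingInlined)`
    // behind) while it is being mutated.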
    let mut functions: Vec<_> = mem::take(&mut module.functions)
        .into_iter()
        .map(Ok)
        .collect();

    // Inline functions in post-order (aka inside-out, aka bottom-up) - that is,
    // callees are processed before their callers, to avoid duplicating work.
    for func_idx in call_graph.post_order() {
        let mut function = mem::replace(&mut functions[func_idx], Err(FuncIsBeingInlined)).unwrap();
        inliner.inline_fn(&mut function, &functions);
        fuse_trivial_branches(&mut function);
        functions[func_idx] = Ok(function);
    }

    module.functions = functions.into_iter().map(|func| func.unwrap()).collect();

    let Inliner {
        id_to_name,
        inlined_dont_inlines_to_cause_and_callers,
        ..
    } = inliner;

    let mut span_regen = SpanRegenerator::new(sess.source_map(), module);
    for (callee_id, (cause, callers)) in inlined_dont_inlines_to_cause_and_callers {
        let callee_name = get_name(&id_to_name, callee_id);

        // HACK(eddyb) `libcore` hides panics behind `#[inline(never)]` `fn`s,
        // making this too noisy and useless (since it's an impl detail).
        if cause == "panicking" && callee_name.starts_with("core::") {
            continue;
        }

        let callee_span = span_regen
            .src_loc_for_id(callee_id)
            .and_then(|src_loc| span_regen.src_loc_to_rustc(src_loc))
            .unwrap_or_default();
        sess.dcx()
            .struct_span_warn(
                callee_span,
                format!("`#[inline(never)]` function `{callee_name}` has been inlined"),
            )
            .with_note(format!("inlining was required due to {cause}"))
            .with_note(format!(
                "called from {}",
                callers
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &caller_id)| {
                        // HACK(eddyb) avoid showing too many names.
                        match i.cmp(&4) {
                            Ordering::Less => {
                                Some(format!("`{}`", get_name(&id_to_name, caller_id)))
                            }
                            Ordering::Equal => Some(format!("and {} more", callers.len() - i)),
                            Ordering::Greater => None,
                        }
                    })
                    .collect::<SmallVec<[_; 5]>>()
                    .join(", ")
            ))
            .emit();
    }

    Ok(())
}

// https://stackoverflow.com/a/53995651
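// A straightforward DFS cycle check: `discovered` marks functions currently on
// the DFS stack, `finished` marks fully-processed ones, and reaching a
// `discovered` function again means there is a call cycle (i.e. recursion).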
fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> {
    let func_to_index: FxHashMap<Word, usize> = module
        .functions
        .iter()
        .enumerate()
        .map(|(index, func)| (func.def_id().unwrap(), index))
        .collect();
    let mut discovered = vec![false; module.functions.len()];
    let mut finished = vec![false; module.functions.len()];
    let mut has_recursion = None;
    for index in 0..module.functions.len() {
        if !discovered[index] && !finished[index] {
            visit(
                sess,
                module,
                index,
                &mut discovered,
                &mut finished,
                &mut has_recursion,
                &func_to_index,
            );
        }
    }

    fn visit(
        sess: &Session,
        module: &Module,
        current: usize,
        discovered: &mut Vec<bool>,
        finished: &mut Vec<bool>,
        has_recursion: &mut Option<ErrorGuaranteed>,
        func_to_index: &FxHashMap<Word, usize>,
    ) {
        discovered[current] = true;

        for next in calls(&module.functions[current], func_to_index) {
            if discovered[next] {
                let names = get_names(module);
                let current_name = get_name(&names, module.functions[current].def_id().unwrap());
                let next_name = get_name(&names, module.functions[next].def_id().unwrap());
                *has_recursion = Some(sess.dcx().err(format!(
                    "module has recursion, which is not allowed: `{current_name}` calls `{next_name}`"
                )));
                break;
            }

            if !finished[next] {
                visit(
                    sess,
                    module,
                    next,
                    discovered,
                    finished,
                    has_recursion,
                    func_to_index,
                );
            }
        }

        discovered[current] = false;
        finished[current] = true;
    }

    fn calls<'a>(
        func: &'a Function,
        func_to_index: &'a FxHashMap<Word, usize>,
    ) -> impl Iterator<Item = usize> + 'a {
        func.all_inst_iter()
            .filter(|inst| inst.class.opcode == Op::FunctionCall)
            .map(move |inst| {
                *func_to_index
                    .get(&inst.operands[0].id_ref_any().unwrap())
                    .unwrap()
            })
    }

    match has_recursion {
        Some(err) => Err(err),
        None => Ok(()),
    }
}

/// Any type/const/global variable that is "legal" (i.e. can be kept in SPIR-V).
///
/// For the purposes of the inliner, a legal global cannot:
/// - refer to any illegal globals
/// - (if a type) refer to any pointer types
///   - this rules out both pointers in composites, and pointers to pointers
///     (the latter itself *also* rules out variables containing pointers)
enum LegalGlobal {
    TypePointer(StorageClass),
    TypeNonPointer,
    Const,
    Variable,
}

impl LegalGlobal {
    fn gather_from_module(module: &Module) -> FxHashMap<Word, Self> {
        let mut legal_globals = FxHashMap::<_, Self>::default();
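        // Globals are visited in module order; since SPIR-V requires types and
        // constants to be defined before they are used, every ID operand seen
        // below has already been classified (or rejected) by the time it's read.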
        for inst in &module.types_global_values {
            let global = match inst.class.opcode {
                Op::TypePointer => Self::TypePointer(inst.operands[0].unwrap_storage_class()),
                Op::Variable => Self::Variable,
                op if rspirv::grammar::reflect::is_type(op) => Self::TypeNonPointer,
                op if rspirv::grammar::reflect::is_constant(op) => Self::Const,

                // FIXME(eddyb) should this be `unreachable!()`?
                _ => continue,
            };
            let legal_result_type = match inst.result_type {
                Some(result_type_id) => matches!(
                    (&global, legal_globals.get(&result_type_id)),
                    (Self::Variable, Some(Self::TypePointer(_)))
                        | (Self::Const, Some(Self::TypeNonPointer))
                ),
                None => matches!(global, Self::TypePointer(_) | Self::TypeNonPointer),
            };
            let legal_operands = inst.operands.iter().all(|operand| match operand {
                Operand::IdRef(id) => matches!(
                    legal_globals.get(id),
                    Some(Self::TypeNonPointer | Self::Const)
                ),

                // NOTE(eddyb) this assumes non-ID operands are always legal.
                _ => operand.id_ref_any().is_none(),
            });
            if legal_result_type && legal_operands {
                legal_globals.insert(inst.result_id.unwrap(), global);
            }
        }
        legal_globals
    }

    fn legal_as_fn_param_ty(&self) -> bool {
        match *self {
            Self::TypePointer(storage_class) => matches!(
                storage_class,
                StorageClass::UniformConstant
                    | StorageClass::Function
                    | StorageClass::Private
                    | StorageClass::Workgroup
                    | StorageClass::AtomicCounter
            ),
            Self::TypeNonPointer => true,

            // FIXME(eddyb) should this be an `unreachable!()`?
            Self::Const | Self::Variable => false,
        }
    }

    fn legal_as_fn_ret_ty(&self) -> bool {
        #[allow(clippy::match_same_arms)]
        match *self {
            Self::TypePointer(_) => false,
            Self::TypeNonPointer => true,

            // FIXME(eddyb) should this be an `unreachable!()`?
            Self::Const | Self::Variable => false,
        }
    }
}

/// Helper type which encapsulates all the information about one specific call.
#[derive(Copy, Clone)]
struct CallSite<'a> {
    caller: &'a Function,
    call_inst: &'a Instruction,
}

fn has_dont_inline(function: &Function) -> bool {
    let def = function.def.as_ref().unwrap();
    let control = def.operands[0].unwrap_function_control();
    control.contains(FunctionControl::DONT_INLINE)
}

/// Helper error type for `should_inline` (see its doc comment).
#[derive(Copy, Clone, PartialEq, Eq)]
struct MustInlineToLegalize(&'static str);

/// Returns `Ok(true)`/`Err(MustInlineToLegalize(_))` if `callee` should/must be
/// inlined (either in general, or specifically from `call_site`, if provided).
///
/// The distinction made here is that `Err(MustInlineToLegalize(cause))` is
/// very much *not* a heuristic, and inlining is *mandatory* due to `cause`
/// (usually illegal signature/arguments, but also the panicking mechanism).
//
// FIXME(eddyb) the causes here are not fine-grained enough.
fn should_inline(
    legal_globals: &FxHashMap<Word, LegalGlobal>,
    functions_that_may_abort: &FxHashSet<Word>,
    callee: &Function,
    call_site: CallSite<'_>,
) -> Result<bool, MustInlineToLegalize> {
    let callee_def = callee.def.as_ref().unwrap();
    let callee_control = callee_def.operands[0].unwrap_function_control();

    if functions_that_may_abort.contains(&callee.def_id().unwrap()) {
        return Err(MustInlineToLegalize("panicking"));
    }

    let ret_ty = legal_globals
        .get(&callee_def.result_type.unwrap())
        .ok_or(MustInlineToLegalize("illegal return type"))?;
    if !ret_ty.legal_as_fn_ret_ty() {
        return Err(MustInlineToLegalize("illegal (pointer) return type"));
    }

    for (i, param) in callee.parameters.iter().enumerate() {
        let param_ty = legal_globals
            .get(param.result_type.as_ref().unwrap())
            .ok_or(MustInlineToLegalize("illegal parameter type"))?;
        if !param_ty.legal_as_fn_param_ty() {
            return Err(MustInlineToLegalize("illegal (pointer) parameter type"));
        }

        // If the call isn't passing a legal pointer argument (a "memory object",
        // i.e. an `OpVariable` or one of the caller's `OpFunctionParameter`s),
        // then inlining is required to have a chance at producing legal SPIR-V.
        //
        // FIXME(eddyb) rewriting away the pointer could be another alternative.
        if let LegalGlobal::TypePointer(_) = param_ty {
            let ptr_arg = call_site.call_inst.operands[i + 1].unwrap_id_ref();
            match legal_globals.get(&ptr_arg) {
                Some(LegalGlobal::Variable) => {}

                // FIXME(eddyb) should some constants (undef/null) be allowed?
                Some(_) => return Err(MustInlineToLegalize("illegal (pointer) argument")),

                None => {
                    let mut caller_param_and_var_ids = call_site
                        .caller
                        .parameters
                        .iter()
                        .chain(
                            call_site.caller.blocks[0]
                                .instructions
                                .iter()
                                .filter(|caller_inst| {
                                    // HACK(eddyb) this only avoids scanning the
                                    // whole entry block for `OpVariable`s, so
                                    // it can overapproximate debuginfo insts.
                                    let may_be_debuginfo = matches!(
                                        caller_inst.class.opcode,
                                        Op::Line | Op::NoLine | Op::ExtInst
                                    );
                                    !may_be_debuginfo
                                })
                                .take_while(|caller_inst| caller_inst.class.opcode == Op::Variable),
                        )
                        .map(|caller_inst| caller_inst.result_id.unwrap());

                    if !caller_param_and_var_ids.any(|id| ptr_arg == id) {
                        return Err(MustInlineToLegalize("illegal (pointer) argument"));
                    }
                }
            }
        }
    }

    Ok(callee_control.contains(FunctionControl::INLINE))
}

/// Helper error type for `Inliner`'s `functions` field, indicating a `Function`
/// was taken out of its slot because it's being inlined.
#[derive(Debug)]
struct FuncIsBeingInlined;

// Steps:
// Move OpVariable decls
// Rewrite return
// Renumber IDs
// Insert blocks

struct Inliner<'a, 'b> {
    /// ID of `OpExtInstImport` for our custom "extended instruction set"
    /// (see `crate::custom_insts` for more details).
    custom_ext_inst_set_import: Word,

    op_type_void_id: Word,

    /// Map from each function's ID to its index in `functions`.
    func_id_to_idx: FxHashMap<Word, usize>,

    /// Pre-collected `OpName`s, that can be used to find any function's name
    /// during inlining (to be able to generate debuginfo that uses names).
    id_to_name: FxHashMap<Word, &'a str>,

    /// `OpString` cache (for deduplicating `OpString`s for the same string).
    //
    // FIXME(eddyb) currently this doesn't reuse existing `OpString`s, but since
    // this is mostly for inlined callee names, it's expected almost no overlap
    // exists between existing `OpString`s and new ones, anyway.
    cached_op_strings: FxHashMap<&'a str, Word>,

    header: &'b mut ModuleHeader,
    debug_string_source: &'b mut Vec<Instruction>,
    annotations: &'b mut Vec<Instruction>,

    legal_globals: FxHashMap<Word, LegalGlobal>,
    functions_that_may_abort: FxHashSet<Word>,
    inlined_dont_inlines_to_cause_and_callers: FxIndexMap<Word, (&'static str, FxIndexSet<Word>)>,
    // rewrite_rules: FxHashMap<Word, Word>,
}

impl Inliner<'_, '_> {
    fn id(&mut self) -> Word {
        next_id(self.header)
    }

    /// Applies all rewrite rules to the decorations in the module.
    fn apply_rewrite_for_decorations(&mut self, rewrite_rules: &FxHashMap<Word, Word>) {
        // NOTE(siebencorgie): We don't care *what* decoration we rewrite atm.
        // AFAIK there is no case where keeping decorations on inline wouldn't be valid.
        for annotation_idx in 0..self.annotations.len() {
            let inst = &self.annotations[annotation_idx];
            if let [Operand::IdRef(target), ..] = inst.operands[..]
                && let Some(&rewritten_target) = rewrite_rules.get(&target)
            {
                // Copy decoration instruction and push it.
                let mut cloned_inst = inst.clone();
                cloned_inst.operands[0] = Operand::IdRef(rewritten_target);
                self.annotations.push(cloned_inst);
            }
        }
    }

    fn inline_fn(
        &mut self,
        function: &mut Function,
        functions: &[Result<Function, FuncIsBeingInlined>],
    ) {
        let mut block_idx = 0;
        while block_idx < function.blocks.len() {
            // If we successfully inlined a block, then repeat processing on the same block, in
            // case the newly inlined block has more inlined calls.
            // TODO: This is quadratic
            if !self.inline_block(function, block_idx, functions) {
                // TODO(eddyb) skip past the inlined callee without rescanning it.
                block_idx += 1;
            }
        }
    }

    fn inline_block(
        &mut self,
        caller: &mut Function,
        block_idx: usize,
        functions: &[Result<Function, FuncIsBeingInlined>],
    ) -> bool {
        // Find the first `OpFunctionCall` in this block that should (or must) be inlined.
        let call = caller.blocks[block_idx]
            .instructions
            .iter()
            .enumerate()
            .filter(|(_, inst)| inst.class.opcode == Op::FunctionCall)
            .map(|(index, inst)| {
                (
                    index,
                    inst,
                    functions[self.func_id_to_idx[&inst.operands[0].id_ref_any().unwrap()]]
                        .as_ref()
                        .unwrap(),
                )
            })
            .find(|(_, inst, f)| {
                let call_site = CallSite {
                    caller,
                    call_inst: inst,
                };
                match should_inline(
                    &self.legal_globals,
                    &self.functions_that_may_abort,
                    f,
                    call_site,
                ) {
                    Ok(inline) => inline,
                    Err(MustInlineToLegalize(cause)) => {
                        if has_dont_inline(f) {
                            self.inlined_dont_inlines_to_cause_and_callers
                                .entry(f.def_id().unwrap())
                                .or_insert_with(|| (cause, Default::default()))
                                .1
                                .insert(caller.def_id().unwrap());
                        }
                        true
                    }
                }
            });
        let (call_index, call_inst, callee) = match call {
            None => return false,
            Some(call) => call,
        };

        // Propagate "may abort" from callee to caller (i.e. as aborts get inlined).
        if self
            .functions_that_may_abort
            .contains(&callee.def_id().unwrap())
        {
            self.functions_that_may_abort
                .insert(caller.def_id().unwrap());
        }

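        // A non-void call result is replaced with an `OpPhi` in the post-call
        // block, merging the values returned by every `OpReturnValue` in the
        // inlined callee (see `get_inlined_blocks` for where those are added).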
        let mut maybe_call_result_phi = {
            let ty = call_inst.result_type.unwrap();
            if ty == self.op_type_void_id {
                None
            } else {
                Some(Instruction::new(
                    Op::Phi,
                    Some(ty),
                    Some(call_inst.result_id.unwrap()),
                    vec![],
                ))
            }
        };

        // Get the debug "source location" instruction that applies to the call.
        let custom_ext_inst_set_import = self.custom_ext_inst_set_import;
        let call_debug_src_loc_inst = caller.blocks[block_idx].instructions[..call_index]
            .iter()
            .rev()
            .find_map(|inst| {
                Some(match inst.class.opcode {
                    Op::Line => Some(inst),
                    Op::NoLine => None,
                    Op::ExtInst
                        if inst.operands[0].unwrap_id_ref() == custom_ext_inst_set_import =>
                    {
                        match CustomOp::decode_from_ext_inst(inst) {
                            CustomOp::SetDebugSrcLoc => Some(inst),
                            CustomOp::ClearDebugSrcLoc => None,
                            _ => return None,
                        }
                    }
                    _ => return None,
                })
            })
            .flatten();

        // Rewrite parameters to arguments
        let call_arguments = call_inst
            .operands
            .iter()
            .skip(1)
            .map(|op| op.id_ref_any().unwrap());
        let callee_parameters = callee.parameters.iter().map(|inst| {
            assert!(inst.class.opcode == Op::FunctionParameter);
            inst.result_id.unwrap()
        });
        let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();

        let return_jump = self.id();
        // Rewrite OpReturns of the callee.
        let mut inlined_callee_blocks = self.get_inlined_blocks(
            callee,
            call_debug_src_loc_inst,
            maybe_call_result_phi.as_mut(),
            return_jump,
        );
        // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
        // fn is inlined multiple times.
        self.add_clone_id_rules(&mut rewrite_rules, &inlined_callee_blocks);
        apply_rewrite_rules(&rewrite_rules, &mut inlined_callee_blocks);
        self.apply_rewrite_for_decorations(&rewrite_rules);

        if let Some(call_result_phi) = &mut maybe_call_result_phi {
            // HACK(eddyb) new IDs should be generated earlier, to avoid pushing
            // callee IDs to `call_result_phi.operands` only to rewrite them here.
            for op in &mut call_result_phi.operands {
                if let Some(id) = op.id_ref_any_mut()
                    && let Some(&rewrite) = rewrite_rules.get(id)
                {
                    *id = rewrite;
                }
            }

            // HACK(eddyb) this special-casing of the single-return case is
            // really necessary for passes like `mem2reg` which are not capable
            // of skipping through the extraneous `OpPhi`s on their own.
            if let [returned_value, _return_block] = &call_result_phi.operands[..] {
                let call_result_id = call_result_phi.result_id.unwrap();
                let returned_value_id = returned_value.unwrap_id_ref();

                maybe_call_result_phi = None;

                // HACK(eddyb) this is a conservative approximation of all the
                // instructions that could potentially reference the call result.
                let reaching_insts = {
                    let (pre_call_blocks, call_and_post_call_blocks) =
                        caller.blocks.split_at_mut(block_idx);
                    (pre_call_blocks.iter_mut().flat_map(|block| {
                        block
                            .instructions
                            .iter_mut()
                            .take_while(|inst| inst.class.opcode == Op::Phi)
                    }))
                    .chain(
                        call_and_post_call_blocks
                            .iter_mut()
                            .flat_map(|block| &mut block.instructions),
                    )
                };
                for reaching_inst in reaching_insts {
                    for op in &mut reaching_inst.operands {
                        if let Some(id) = op.id_ref_any_mut()
                            && *id == call_result_id
                        {
                            *id = returned_value_id;
                        }
                    }
                }
            }
        }

        // Split the block containing the `OpFunctionCall` into pre-call vs post-call.
        let pre_call_block_idx = block_idx;
        #[expect(unused)]
        let block_idx: usize; // HACK(eddyb) disallowing using the unrenamed variable.
        let mut post_call_block_insts = caller.blocks[pre_call_block_idx]
            .instructions
            .split_off(call_index + 1);

        // pop off OpFunctionCall
        let call = caller.blocks[pre_call_block_idx]
            .instructions
            .pop()
            .unwrap();
        assert!(call.class.opcode == Op::FunctionCall);

        // Insert non-entry inlined callee blocks just after the pre-call block.
        let non_entry_inlined_callee_blocks = inlined_callee_blocks.drain(1..);
        let num_non_entry_inlined_callee_blocks = non_entry_inlined_callee_blocks.len();
        caller.blocks.splice(
            (pre_call_block_idx + 1)..(pre_call_block_idx + 1),
            non_entry_inlined_callee_blocks,
        );

        if let Some(call_result_phi) = maybe_call_result_phi {
            // Add the `OpPhi` for the call result value, after the inlined function.
            post_call_block_insts.insert(0, call_result_phi);
        }

        // Insert the post-call block, after all the inlined callee blocks.
        {
            let post_call_block_idx = pre_call_block_idx + num_non_entry_inlined_callee_blocks + 1;
            let post_call_block = Block {
                label: Some(Instruction::new(Op::Label, None, Some(return_jump), vec![])),
                instructions: post_call_block_insts,
            };
            caller.blocks.insert(post_call_block_idx, post_call_block);

            // Adjust any `OpPhi`s in the (caller) targets of the original call block,
            // to refer to the post-call block (the new source of those CFG edges).
            rewrite_phi_sources(
                caller.blocks[pre_call_block_idx].label_id().unwrap(),
                &mut caller.blocks,
                post_call_block_idx,
            );
        }

        // Fuse the inlined callee entry block into the pre-call block.
        // This is okay because it's illegal to branch to the first BB in a function.
        {
            // NOTE(eddyb) `OpExtInst`s have a result ID, even if unused, and
            // it has to be unique, so this allocates new IDs as-needed.
            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
                let mut inst = inst.clone();
                if let Some(id) = &mut inst.result_id {
                    *id = this.id();
                }
                inst
            };

            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
                Instruction::new(
                    Op::ExtInst,
                    Some(this.op_type_void_id),
                    Some(this.id()),
                    [
                        Operand::IdRef(this.custom_ext_inst_set_import),
                        Operand::LiteralExtInstInteger(inst.op() as u32),
                    ]
                    .into_iter()
                    .chain(inst.into_operands())
                    .collect(),
                )
            };

            // Return the subsequence of `insts` made from `OpVariable`s, and any
            // debuginfo instructions (which may apply to them), while removing
            // *only* `OpVariable`s from `insts` (and keeping debuginfo in both).
            let mut steal_vars = |insts: &mut Vec<Instruction>| {
                // HACK(eddyb) this duplicates some code from `get_inlined_blocks`,
                // but that will be removed once the inliner is refactored to be
                // inside-out instead of outside-in (already finished in a branch).
                let mut enclosing_inlined_frames = SmallVec::<[_; 8]>::new();
                let mut current_debug_src_loc_inst = None;
                let mut vars_and_debuginfo_range = 0..0;
                while vars_and_debuginfo_range.end < insts.len() {
                    let inst = &insts[vars_and_debuginfo_range.end];
                    match inst.class.opcode {
                        Op::Line => current_debug_src_loc_inst = Some(inst),
                        Op::NoLine => current_debug_src_loc_inst = None,
                        Op::ExtInst
                            if inst.operands[0].unwrap_id_ref()
                                == self.custom_ext_inst_set_import =>
                        {
                            match CustomOp::decode_from_ext_inst(inst) {
                                CustomOp::SetDebugSrcLoc => current_debug_src_loc_inst = Some(inst),
                                CustomOp::ClearDebugSrcLoc => current_debug_src_loc_inst = None,
                                CustomOp::PushInlinedCallFrame => {
                                    enclosing_inlined_frames
                                        .push((current_debug_src_loc_inst.take(), inst));
                                }
                                CustomOp::PopInlinedCallFrame => {
                                    if let Some((callsite_debug_src_loc_inst, _)) =
                                        enclosing_inlined_frames.pop()
                                    {
                                        current_debug_src_loc_inst = callsite_debug_src_loc_inst;
                                    }
                                }
                                CustomOp::Abort => break,
                            }
                        }
                        Op::Variable => {}
                        _ => break,
                    }
                    vars_and_debuginfo_range.end += 1;
                }

                // `vars_and_debuginfo_range.end` indicates where `OpVariable`s
                // end and other instructions start (modulo debuginfo), but to
                // split the block in two, both sides of the "cut" need "repair":
                // - the variables are missing "inlined call frames" pops, that
                //   may happen later in the block, and have to be synthesized
                // - the non-variables are missing "inlined call frames" pushes,
                //   that must be recreated to avoid ending up with dangling pops
                //
                // FIXME(eddyb) this only collects to avoid borrow conflicts,
                // between e.g. `enclosing_inlined_frames` and mutating `insts`,
                // but also between different uses of `self`.
                let all_pops_after_vars: SmallVec<[_; 8]> = enclosing_inlined_frames
                    .iter()
                    .map(|_| custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame))
                    .collect();
                let all_repushes_before_non_vars: SmallVec<[_; 8]> =
                    (enclosing_inlined_frames.into_iter().flat_map(
                        |(callsite_debug_src_loc_inst, push_inlined_call_frame_inst)| {
                            (callsite_debug_src_loc_inst.into_iter())
                                .chain([push_inlined_call_frame_inst])
                        },
                    ))
                    .chain(current_debug_src_loc_inst)
                    .map(|inst| instantiate_debuginfo(self, inst))
                    .collect();

                let vars_and_debuginfo =
                    insts.splice(vars_and_debuginfo_range, all_repushes_before_non_vars);
                let repaired_vars_and_debuginfo = vars_and_debuginfo.chain(all_pops_after_vars);

                // FIXME(eddyb) collecting shouldn't be necessary but this is
                // nested in a closure, and `splice` borrows the original `Vec`.
                repaired_vars_and_debuginfo.collect::<SmallVec<[_; 8]>>()
            };

            let [mut inlined_callee_entry_block]: [_; 1] =
                inlined_callee_blocks.try_into().unwrap();

            // Move the `OpVariable`s of the callee to the caller.
            let callee_vars_and_debuginfo =
                steal_vars(&mut inlined_callee_entry_block.instructions);
            self.insert_opvariables(&mut caller.blocks[0], callee_vars_and_debuginfo);

            caller.blocks[pre_call_block_idx]
                .instructions
                .append(&mut inlined_callee_entry_block.instructions);

            // Adjust any `OpPhi`s in the (inlined callee) targets of the
            // inlined callee entry block, to refer to the pre-call block
            // (the new source of those CFG edges).
            rewrite_phi_sources(
                inlined_callee_entry_block.label_id().unwrap(),
                &mut caller.blocks,
                pre_call_block_idx,
            );
        }

        true
    }

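    /// Allocates a fresh ID for every result ID defined in `blocks`, recording
    /// the old-to-new mapping in `rewrite_rules`, so that inlining the same
    /// callee more than once never redefines an already-used ID.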
    fn add_clone_id_rules(&mut self, rewrite_rules: &mut FxHashMap<Word, Word>, blocks: &[Block]) {
        for block in blocks {
            for inst in block.label.iter().chain(&block.instructions) {
                if let Some(result_id) = inst.result_id {
                    let new_id = self.id();
                    let old = rewrite_rules.insert(result_id, new_id);
                    assert!(old.is_none());
                }
            }
        }
    }

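    /// Clones the callee's blocks, rewrites `OpReturn`/`OpReturnValue` into
    /// branches to `return_jump` (recording returned values as operands of the
    /// call-result `OpPhi`, if any), and wraps each non-trivial block in
    /// `PushInlinedCallFrame`/`PopInlinedCallFrame` debuginfo.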
    fn get_inlined_blocks(
        &mut self,
        callee: &Function,
        call_debug_src_loc_inst: Option<&Instruction>,
        mut maybe_call_result_phi: Option<&mut Instruction>,
        return_jump: Word,
    ) -> Vec<Block> {
        let Self {
            custom_ext_inst_set_import,
            op_type_void_id,
            ..
        } = *self;

        // Prepare the debuginfo insts to prepend/append to every block.
        // FIXME(eddyb) this could be more efficient if we only used one pair of
        // `{Push,Pop}InlinedCallFrame` for the whole inlined callee, but there
        // is no way to hint the SPIR-T CFG (re)structurizer that it should keep
        // the entire callee in one region - a SPIR-T inliner wouldn't have this
        // issue, as it would require a fully structured callee.
        let callee_name = self
            .id_to_name
            .get(&callee.def_id().unwrap())
            .copied()
            .unwrap_or("");
        let callee_name_id = *self
            .cached_op_strings
            .entry(callee_name)
            .or_insert_with(|| {
                let id = next_id(self.header);
                self.debug_string_source.push(Instruction::new(
                    Op::String,
                    None,
                    Some(id),
                    vec![Operand::LiteralString(callee_name.to_string())],
                ));
                id
            });
        let mut mk_debuginfo_prefix_and_suffix = || {
            // NOTE(eddyb) `OpExtInst`s have a result ID, even if unused, and
            // it has to be unique (same goes for the other instructions below).
            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
                let mut inst = inst.clone();
                if let Some(id) = &mut inst.result_id {
                    *id = this.id();
                }
                inst
            };
            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
                Instruction::new(
                    Op::ExtInst,
                    Some(op_type_void_id),
                    Some(this.id()),
                    [
                        Operand::IdRef(custom_ext_inst_set_import),
                        Operand::LiteralExtInstInteger(inst.op() as u32),
                    ]
                    .into_iter()
                    .chain(inst.into_operands())
                    .collect(),
                )
            };

            (
                (call_debug_src_loc_inst.map(|inst| instantiate_debuginfo(self, inst)))
                    .into_iter()
                    .chain([custom_inst_to_inst(
                        self,
                        CustomInst::PushInlinedCallFrame {
                            callee_name: Operand::IdRef(callee_name_id),
                        },
                    )]),
                [custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame)],
            )
        };

        let mut blocks = callee.blocks.clone();
        for block in &mut blocks {
            let mut terminator = block.instructions.pop().unwrap();

            // HACK(eddyb) strip trailing debuginfo (as it can't impact terminators).
            while let Some(last) = block.instructions.last() {
                let can_remove = match last.class.opcode {
                    Op::Line | Op::NoLine => true,
                    Op::ExtInst => {
                        last.operands[0].unwrap_id_ref() == custom_ext_inst_set_import
                            && matches!(
                                CustomOp::decode_from_ext_inst(last),
                                CustomOp::SetDebugSrcLoc | CustomOp::ClearDebugSrcLoc
                            )
                    }
                    _ => false,
                };
                if can_remove {
                    block.instructions.pop();
                } else {
                    break;
                }
            }

            if let Op::Return | Op::ReturnValue = terminator.class.opcode {
                if Op::ReturnValue == terminator.class.opcode {
                    let return_value = terminator.operands[0].id_ref_any().unwrap();
                    let call_result_phi = maybe_call_result_phi.as_deref_mut().unwrap();
                    call_result_phi.operands.extend([
                        Operand::IdRef(return_value),
                        Operand::IdRef(block.label_id().unwrap()),
                    ]);
                } else {
                    assert!(maybe_call_result_phi.is_none());
                }
                terminator =
                    Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
            }

            let num_phis = block
                .instructions
                .iter()
                .take_while(|inst| inst.class.opcode == Op::Phi)
                .count();

            // HACK(eddyb) avoid adding debuginfo to otherwise-empty blocks.
            if block.instructions.len() > num_phis {
                let (debuginfo_prefix, debuginfo_suffix) = mk_debuginfo_prefix_and_suffix();
                // Insert the prefix debuginfo instructions after `OpPhi`s,
                // which sadly can't be covered by them.
                block
                    .instructions
                    .splice(num_phis..num_phis, debuginfo_prefix);
                // Insert the suffix debuginfo instructions before the terminator,
                // which sadly can't be covered by them.
                block.instructions.extend(debuginfo_suffix);
            }

            block.instructions.push(terminator);
        }

        blocks
    }

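    /// Inserts `insts` (the callee's `OpVariable`s, plus surrounding debuginfo)
    /// into the caller's entry `block`, at the last point that is outside any
    /// inlined call frame and not covered by an active debug source location.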
    fn insert_opvariables(&self, block: &mut Block, insts: impl IntoIterator<Item = Instruction>) {
        // HACK(eddyb) this isn't as efficient as it could be in theory, but it's
        // very important to make sure to never insert new instructions in
        // the middle of debuginfo (as they would be affected by it).
        let mut inlined_frames_depth = 0usize;
        let mut outermost_has_debug_src_loc = false;
        let mut last_debugless_var_insertion_point_candidate = None;
        for (i, inst) in block.instructions.iter().enumerate() {
            last_debugless_var_insertion_point_candidate =
                (inlined_frames_depth == 0 && !outermost_has_debug_src_loc).then_some(i);

            let changed_has_debug_src_loc = match inst.class.opcode {
                Op::Line => true,
                Op::NoLine => false,
                Op::ExtInst
                    if inst.operands[0].unwrap_id_ref() == self.custom_ext_inst_set_import =>
                {
                    match CustomOp::decode_from_ext_inst(inst) {
                        CustomOp::SetDebugSrcLoc => true,
                        CustomOp::ClearDebugSrcLoc => false,
                        CustomOp::PushInlinedCallFrame => {
                            inlined_frames_depth += 1;
                            continue;
                        }
                        CustomOp::PopInlinedCallFrame => {
                            inlined_frames_depth = inlined_frames_depth.saturating_sub(1);
                            continue;
                        }
                        CustomOp::Abort => break,
                    }
                }
                Op::Variable => continue,
                _ => break,
            };

            if inlined_frames_depth == 0 {
                outermost_has_debug_src_loc = changed_has_debug_src_loc;
            }
        }

        // HACK(eddyb) fallback to inserting at the start, which should be correct.
        // FIXME(eddyb) some level of debuginfo repair could prevent needing this.
        let i = last_debugless_var_insertion_point_candidate.unwrap_or(0);
        block.instructions.splice(i..i, insts);
    }
}

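/// Fuses each block into its sole predecessor, when that predecessor ends in an
/// unconditional `OpBranch` (and the block itself starts with no `OpPhi`s),
/// then drops the now-empty blocks left behind by the fusion.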
fn fuse_trivial_branches(function: &mut Function) {
    let all_preds = compute_preds(&function.blocks);
    'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
        // Don't fuse branches into blocks with `OpPhi`s.
        let any_phis = function.blocks[dest_block]
            .instructions
            .iter()
            .filter(|inst| {
                // These are the only instructions that are allowed before `OpPhi`.
                !matches!(inst.class.opcode, Op::Line | Op::NoLine)
            })
            .take_while(|inst| inst.class.opcode == Op::Phi)
            .next()
            .is_some();
        if any_phis {
            continue;
        }

        // If there are two trivial branches in a row, the middle one might get inlined before the
        // last one, so when processing the last one, skip through to the first one.
        let pred = loop {
            if preds.len() != 1 || preds[0] == dest_block {
                continue 'outer;
            }
            let pred = preds[0];
            if !function.blocks[pred].instructions.is_empty() {
                break pred;
            }
            preds = &all_preds[pred];
        };
        let pred_insts = &function.blocks[pred].instructions;
        if pred_insts.last().unwrap().class.opcode == Op::Branch {
            let mut dest_insts = mem::take(&mut function.blocks[dest_block].instructions);
            let pred_insts = &mut function.blocks[pred].instructions;
            pred_insts.pop(); // pop the branch
            pred_insts.append(&mut dest_insts);

            // Adjust any `OpPhi`s in the targets of the original block, to refer
            // to the sole predecessor (the new source of those CFG edges).
            rewrite_phi_sources(
                function.blocks[dest_block].label_id().unwrap(),
                &mut function.blocks,
                pred,
            );
        }
    }
    function.blocks.retain(|b| !b.instructions.is_empty());
}

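/// Computes, for each block (by index), the indices of every block that can
/// branch to it, by scanning all outgoing edges in `blocks`.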
fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
    let mut result = vec![vec![]; blocks.len()];
    for (source_idx, source) in blocks.iter().enumerate() {
        for dest_id in outgoing_edges(source) {
            let dest_idx = blocks
                .iter()
                .position(|b| b.label_id().unwrap() == dest_id)
                .unwrap();
            result[dest_idx].push(source_idx);
        }
    }
    result
}

/// Helper for adjusting `OpPhi` source label IDs, when the terminator of the
/// `original_label_id`-labeled block got moved to `blocks[new_block_idx]`.
fn rewrite_phi_sources(original_label_id: Word, blocks: &mut [Block], new_block_idx: usize) {
    let new_label_id = blocks[new_block_idx].label_id().unwrap();

    // HACK(eddyb) can't keep `blocks` borrowed, the loop needs mutable access.
    let target_ids: SmallVec<[_; 4]> = outgoing_edges(&blocks[new_block_idx]).collect();

    for target_id in target_ids {
        let target_block = blocks
            .iter_mut()
            .find(|b| b.label_id().unwrap() == target_id)
            .unwrap();
        let phis = target_block
            .instructions
            .iter_mut()
            .filter(|inst| {
                // These are the only instructions that are allowed before `OpPhi`.
                !matches!(inst.class.opcode, Op::Line | Op::NoLine)
            })
            .take_while(|inst| inst.class.opcode == Op::Phi);
        for phi in phis {
            for value_and_source_id in phi.operands.chunks_mut(2) {
                let source_id = value_and_source_id[1].id_ref_any_mut().unwrap();
                if *source_id == original_label_id {
                    *source_id = new_label_id;
                    break;
                }
            }
        }
    }
}