// rustc_codegen_spirv/linker/inline.rs
1//! This algorithm is not intended to be an optimization, it is rather for legalization.
2//! Specifically, spir-v disallows things like a `StorageClass::Function` pointer to a
3//! `StorageClass::Input` pointer. Our frontend definitely allows it, though, this is like taking a
4//! `&Input<T>` in a function! So, we inline all functions that take these "illegal" pointers, then
5//! run mem2reg (see mem2reg.rs) on the result to "unwrap" the Function pointer.
6
7use super::apply_rewrite_rules;
8use super::ipo::CallGraph;
9use super::simple_passes::outgoing_edges;
10use super::{get_name, get_names};
11use crate::custom_decorations::SpanRegenerator;
12use crate::custom_insts::{self, CustomInst, CustomOp};
13use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
14use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
15use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
16use rustc_errors::ErrorGuaranteed;
17use rustc_session::Session;
18use smallvec::SmallVec;
19use std::cmp::Ordering;
20use std::mem;
21
22// FIXME(eddyb) this is a bit silly, but this keeps being repeated everywhere.
23fn next_id(header: &mut ModuleHeader) -> Word {
24    let result = header.bound;
25    header.bound += 1;
26    result
27}
28
/// Entry point: inline every call that either *must* be inlined (to legalize
/// pointer signatures/arguments, or to propagate our custom `Abort`s up to
/// entry-points) or is marked `Inline`, processing functions bottom-up over
/// the call graph, and running cleanup passes (`mem2reg`, etc.) per function.
///
/// Errors only if the module contains recursion (disallowed by SPIR-V).
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
    // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
    deny_recursion_in_module(sess, module)?;

    // Compute the call-graph that will drive (inside-out, aka bottom-up) inlining.
    let (call_graph, func_id_to_idx) = CallGraph::collect_with_func_id_to_idx(module);

    // Find the (optional) existing import of our custom extended instruction
    // set - if absent, the `Inliner` below creates one on demand.
    let custom_ext_inst_set_import = module
        .ext_inst_imports
        .iter()
        .find(|inst| {
            assert_eq!(inst.class.opcode, Op::ExtInstImport);
            inst.operands[0].unwrap_literal_string() == &custom_insts::CUSTOM_EXT_INST_SET[..]
        })
        .map(|inst| inst.result_id.unwrap());

    let legal_globals = LegalGlobal::gather_from_module(module);

    let header = module.header.as_mut().unwrap();

    // FIXME(eddyb) clippy false positive (separate `map` required for borrowck).
    #[allow(clippy::map_unwrap_or)]
    let mut inliner = Inliner {
        // Reuse an existing `OpTypeVoid`, or create one if none exists.
        op_type_void_id: module
            .types_global_values
            .iter()
            .find(|inst| inst.class.opcode == Op::TypeVoid)
            .map(|inst| inst.result_id.unwrap())
            .unwrap_or_else(|| {
                let id = next_id(header);
                let inst = Instruction::new(Op::TypeVoid, None, Some(id), vec![]);
                module.types_global_values.push(inst);
                id
            }),

        custom_ext_inst_set_import: custom_ext_inst_set_import.unwrap_or_else(|| {
            let id = next_id(header);
            let inst = Instruction::new(
                Op::ExtInstImport,
                None,
                Some(id),
                vec![Operand::LiteralString(
                    custom_insts::CUSTOM_EXT_INST_SET.to_string(),
                )],
            );
            module.ext_inst_imports.push(inst);
            id
        }),

        func_id_to_idx,

        id_to_name: module
            .debug_names
            .iter()
            .filter(|inst| inst.class.opcode == Op::Name)
            .map(|inst| {
                (
                    inst.operands[0].unwrap_id_ref(),
                    inst.operands[1].unwrap_literal_string(),
                )
            })
            .collect(),

        cached_op_strings: FxHashMap::default(),

        header,
        debug_string_source: &mut module.debug_string_source,
        annotations: &mut module.annotations,

        legal_globals,

        // NOTE(eddyb) this is needed because our custom `Abort` instructions get
        // lowered to a simple `OpReturn` in entry-points, but that requires that
        // they get inlined all the way up to the entry-points in the first place.
        //
        // Seed set: functions that directly contain an `Abort` (detected as a
        // final `ExtInst` from our custom set, followed by `OpUnreachable`).
        functions_that_may_abort: module
            .functions
            .iter()
            .filter_map(|func| {
                let custom_ext_inst_set_import = custom_ext_inst_set_import?;
                func.blocks
                    .iter()
                    .any(|block| match &block.instructions[..] {
                        [.., last_normal_inst, terminator_inst]
                            if last_normal_inst.class.opcode == Op::ExtInst
                                && last_normal_inst.operands[0].unwrap_id_ref()
                                    == custom_ext_inst_set_import
                                && CustomOp::decode_from_ext_inst(last_normal_inst)
                                    == CustomOp::Abort =>
                        {
                            assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
                            true
                        }

                        _ => false,
                    })
                    .then_some(func.def_id().unwrap())
            })
            .collect(),

        inlined_dont_inlines_to_cause_and_callers: FxIndexMap::default(),
    };

    // Move functions out of the module, into individually-takeable slots:
    // while a function is being inlined *into*, its slot holds
    // `Err(FuncIsBeingInlined)`, and all other slots remain readable.
    let mut functions: Vec<_> = mem::take(&mut module.functions)
        .into_iter()
        .map(Ok)
        .collect();

    // Pre-collect the type/constant information `mem2reg` needs:
    // pointer type ID -> pointee type ID, and known `u32` constant values.
    let mut mem2reg_pointer_to_pointee = FxHashMap::default();
    let mut mem2reg_constants = FxHashMap::default();
    {
        // ID of the (unique, if present) unsigned 32-bit integer type.
        let mut u32 = None;
        for inst in &module.types_global_values {
            match inst.class.opcode {
                Op::TypePointer => {
                    mem2reg_pointer_to_pointee
                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
                }
                Op::TypeInt
                    if inst.operands[0].unwrap_literal_bit32() == 32
                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
                {
                    assert!(u32.is_none());
                    u32 = Some(inst.result_id.unwrap());
                }
                Op::Constant if u32.is_some() && inst.result_type == u32 => {
                    let value = inst.operands[0].unwrap_literal_bit32();
                    mem2reg_constants.insert(inst.result_id.unwrap(), value);
                }
                _ => {}
            }
        }
    }

    // Inline functions in post-order (aka inside-out aka bottom-up) - that is,
    // callees are processed before their callers, to avoid duplicating work.
    for func_idx in call_graph.post_order() {
        // Take the function out of its slot while mutating it, so the
        // (immutable) `functions` slice can still be read for callees.
        let mut function = mem::replace(&mut functions[func_idx], Err(FuncIsBeingInlined)).unwrap();
        inliner.inline_fn(&mut function, &functions);
        fuse_trivial_branches(&mut function);

        super::duplicates::DebuginfoDeduplicator {
            custom_ext_inst_set_import,
        }
        .remove_duplicate_debuginfo_in_function(&mut function);

        {
            super::simple_passes::block_ordering_pass(&mut function);
            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
            super::mem2reg::mem2reg(
                inliner.header,
                &mut module.types_global_values,
                &mem2reg_pointer_to_pointee,
                &mem2reg_constants,
                &mut function,
            );
            super::destructure_composites::destructure_composites(&mut function);
        }

        functions[func_idx] = Ok(function);
    }

    // All functions are done being inlined - put them back into the module.
    module.functions = functions.into_iter().map(|func| func.unwrap()).collect();

    let Inliner {
        id_to_name,
        inlined_dont_inlines_to_cause_and_callers,
        ..
    } = inliner;

    // Emit one warning per `#[inline(never)]` function that had to be inlined
    // anyway, with the legalization cause and (up to 4 of) its callers.
    let mut span_regen = SpanRegenerator::new(sess.source_map(), module);
    for (callee_id, (cause, callers)) in inlined_dont_inlines_to_cause_and_callers {
        let callee_name = get_name(&id_to_name, callee_id);

        // HACK(eddyb) `libcore` hides panics behind `#[inline(never)]` `fn`s,
        // making this too noisy and useless (since it's an impl detail).
        if cause == "panicking" && callee_name.starts_with("core::") {
            continue;
        }

        let callee_span = span_regen
            .src_loc_for_id(callee_id)
            .and_then(|src_loc| span_regen.src_loc_to_rustc(src_loc))
            .unwrap_or_default();
        sess.dcx()
            .struct_span_warn(
                callee_span,
                format!("`#[inline(never)]` function `{callee_name}` has been inlined"),
            )
            .with_note(format!("inlining was required due to {cause}"))
            .with_note(format!(
                "called from {}",
                callers
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &caller_id)| {
                        // HACK(eddyb) avoid showing too many names.
                        match i.cmp(&4) {
                            Ordering::Less => {
                                Some(format!("`{}`", get_name(&id_to_name, caller_id)))
                            }
                            Ordering::Equal => Some(format!("and {} more", callers.len() - i)),
                            Ordering::Greater => None,
                        }
                    })
                    .collect::<SmallVec<[_; 5]>>()
                    .join(", ")
            ))
            .emit();
    }

    Ok(())
}
241
// https://stackoverflow.com/a/53995651
/// Emits a compile error (and returns `Err`) if the module's static call
/// graph contains a cycle - SPIR-V disallows recursion outright.
fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> {
    // Map each function's ID to its index in `module.functions`.
    let func_to_index: FxHashMap<Word, usize> = module
        .functions
        .iter()
        .enumerate()
        .map(|(index, func)| (func.def_id().unwrap(), index))
        .collect();
    // DFS coloring: `discovered` = currently on the DFS stack,
    // `finished` = fully explored (all transitive callees visited).
    let mut discovered = vec![false; module.functions.len()];
    let mut finished = vec![false; module.functions.len()];
    let mut has_recursion = None;
    for index in 0..module.functions.len() {
        if !discovered[index] && !finished[index] {
            visit(
                sess,
                module,
                index,
                &mut discovered,
                &mut finished,
                &mut has_recursion,
                &func_to_index,
            );
        }
    }

    // DFS from `current`, recording an error upon finding a back-edge
    // (i.e. a call to a function still on the DFS stack).
    fn visit(
        sess: &Session,
        module: &Module,
        current: usize,
        discovered: &mut Vec<bool>,
        finished: &mut Vec<bool>,
        has_recursion: &mut Option<ErrorGuaranteed>,
        func_to_index: &FxHashMap<Word, usize>,
    ) {
        discovered[current] = true;

        for next in calls(&module.functions[current], func_to_index) {
            if discovered[next] {
                // `next` is still on the DFS stack, so this edge closes a
                // cycle - report it and stop scanning this function.
                let names = get_names(module);
                let current_name = get_name(&names, module.functions[current].def_id().unwrap());
                let next_name = get_name(&names, module.functions[next].def_id().unwrap());
                *has_recursion = Some(sess.dcx().err(format!(
                    "module has recursion, which is not allowed: `{current_name}` calls `{next_name}`"
                )));
                break;
            }

            if !finished[next] {
                visit(
                    sess,
                    module,
                    next,
                    discovered,
                    finished,
                    has_recursion,
                    func_to_index,
                );
            }
        }

        // Off the stack; mark as fully explored so it's never revisited.
        discovered[current] = false;
        finished[current] = true;
    }

    // All direct callees of `func`, as indices into `module.functions`.
    fn calls<'a>(
        func: &'a Function,
        func_to_index: &'a FxHashMap<Word, usize>,
    ) -> impl Iterator<Item = usize> + 'a {
        func.all_inst_iter()
            .filter(|inst| inst.class.opcode == Op::FunctionCall)
            .map(move |inst| {
                *func_to_index
                    .get(&inst.operands[0].id_ref_any().unwrap())
                    .unwrap()
            })
    }

    match has_recursion {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
324
/// Any type/const/global variable, which is "legal" (i.e. can be kept in SPIR-V).
///
/// For the purposes of the inliner, a legal global cannot:
/// - refer to any illegal globals
/// - (if a type) refer to any pointer types
///   - this rules out both pointers in composites, and pointers to pointers
///     (the latter itself *also* rules out variables containing pointers)
enum LegalGlobal {
    /// An `OpTypePointer`, tagged with its storage class.
    TypePointer(StorageClass),
    /// Any other (non-pointer) type.
    TypeNonPointer,
    /// A constant (of a legal non-pointer type).
    Const,
    /// A module-level `OpVariable` (of a legal pointer type).
    Variable,
}
338
impl LegalGlobal {
    /// Scans `module.types_global_values` in order (SPIR-V requires globals to
    /// be defined before use), returning a map from each *legal* global's ID
    /// to its `LegalGlobal` category - illegal globals are simply absent.
    fn gather_from_module(module: &Module) -> FxHashMap<Word, Self> {
        let mut legal_globals = FxHashMap::<_, Self>::default();
        for inst in &module.types_global_values {
            // NOTE: the `TypePointer` arm must precede the generic `is_type`
            // guard below, which would otherwise also match pointer types.
            let global = match inst.class.opcode {
                Op::TypePointer => Self::TypePointer(inst.operands[0].unwrap_storage_class()),
                Op::Variable => Self::Variable,
                op if rspirv::grammar::reflect::is_type(op) => Self::TypeNonPointer,
                op if rspirv::grammar::reflect::is_constant(op) => Self::Const,

                // FIXME(eddyb) should this be `unreachable!()`?
                _ => continue,
            };
            // The result type (when present) must itself already be legal, and
            // of the matching kind: variables have pointer types, constants
            // have non-pointer types; types themselves carry no result type.
            let legal_result_type = match inst.result_type {
                Some(result_type_id) => matches!(
                    (&global, legal_globals.get(&result_type_id)),
                    (Self::Variable, Some(Self::TypePointer(_)))
                        | (Self::Const, Some(Self::TypeNonPointer))
                ),
                None => matches!(global, Self::TypePointer(_) | Self::TypeNonPointer),
            };
            // Every ID operand must refer to an already-recorded legal
            // non-pointer type or constant (this is what rules out e.g.
            // pointers nested inside composite types - see the type doc).
            let legal_operands = inst.operands.iter().all(|operand| match operand {
                Operand::IdRef(id) => matches!(
                    legal_globals.get(id),
                    Some(Self::TypeNonPointer | Self::Const)
                ),

                // NOTE(eddyb) this assumes non-ID operands are always legal.
                _ => operand.id_ref_any().is_none(),
            });
            if legal_result_type && legal_operands {
                legal_globals.insert(inst.result_id.unwrap(), global);
            }
        }
        legal_globals
    }

    /// Whether this global (as a type) is legal as a function parameter type:
    /// pointers only in a limited set of storage classes, non-pointers always.
    fn legal_as_fn_param_ty(&self) -> bool {
        match *self {
            Self::TypePointer(storage_class) => matches!(
                storage_class,
                StorageClass::UniformConstant
                    | StorageClass::Function
                    | StorageClass::Private
                    | StorageClass::Workgroup
                    | StorageClass::AtomicCounter
            ),
            Self::TypeNonPointer => true,

            // FIXME(eddyb) should this be an `unreachable!()`?
            Self::Const | Self::Variable => false,
        }
    }

    /// Whether this global (as a type) is legal as a function *return* type:
    /// unlike parameters, pointers are never legal here.
    fn legal_as_fn_ret_ty(&self) -> bool {
        #[allow(clippy::match_same_arms)]
        match *self {
            Self::TypePointer(_) => false,
            Self::TypeNonPointer => true,

            // FIXME(eddyb) should this be an `unreachable!()`?
            Self::Const | Self::Variable => false,
        }
    }
}
404
/// Helper type which encapsulates all the information about one specific call.
#[derive(Copy, Clone)]
struct CallSite<'a> {
    /// The function containing the call.
    caller: &'a Function,
    /// The `OpFunctionCall` instruction itself (within `caller`).
    call_inst: &'a Instruction,
}
411
412fn has_dont_inline(function: &Function) -> bool {
413    let def = function.def.as_ref().unwrap();
414    let control = def.operands[0].unwrap_function_control();
415    control.contains(FunctionControl::DONT_INLINE)
416}
417
/// Helper error type for `should_inline` (see its doc comment).
///
/// The payload is a short human-readable "cause" string, later surfaced in
/// the notes of the `#[inline(never)]`-function-was-inlined warning.
#[derive(Copy, Clone, PartialEq, Eq)]
struct MustInlineToLegalize(&'static str);
421
422/// Returns `Ok(true)`/`Err(MustInlineToLegalize(_))` if `callee` should/must be
423/// inlined (either in general, or specifically from `call_site`, if provided).
424///
425/// The distinction made here is that `Err(MustInlineToLegalize(cause))` is
426/// very much *not* a heuristic, and inlining is *mandatory* due to `cause`
427/// (usually illegal signature/arguments, but also the panicking mechanism).
428//
429// FIXME(eddyb) the causes here are not fine-grained enough.
fn should_inline(
    legal_globals: &FxHashMap<Word, LegalGlobal>,
    functions_that_may_abort: &FxHashSet<Word>,
    callee: &Function,
    call_site: CallSite<'_>,
) -> Result<bool, MustInlineToLegalize> {
    let callee_def = callee.def.as_ref().unwrap();
    // `operands[0]` of `OpFunction` holds the function-control bitmask.
    let callee_control = callee_def.operands[0].unwrap_function_control();

    // Calls to potentially-aborting functions must always be inlined, so that
    // `Abort`s can eventually be lowered in entry-points (see `inline`).
    if functions_that_may_abort.contains(&callee.def_id().unwrap()) {
        return Err(MustInlineToLegalize("panicking"));
    }

    // A return type absent from `legal_globals` is itself illegal, and a
    // pointer return type is never legal for a function kept in the output.
    let ret_ty = legal_globals
        .get(&callee_def.result_type.unwrap())
        .ok_or(MustInlineToLegalize("illegal return type"))?;
    if !ret_ty.legal_as_fn_ret_ty() {
        return Err(MustInlineToLegalize("illegal (pointer) return type"));
    }

    for (i, param) in callee.parameters.iter().enumerate() {
        let param_ty = legal_globals
            .get(param.result_type.as_ref().unwrap())
            .ok_or(MustInlineToLegalize("illegal parameter type"))?;
        if !param_ty.legal_as_fn_param_ty() {
            return Err(MustInlineToLegalize("illegal (pointer) parameter type"));
        }

        // If the call isn't passing a legal pointer argument (a "memory object",
        // i.e. an `OpVariable` or one of the caller's `OpFunctionParameter`s),
        // then inlining is required to have a chance at producing legal SPIR-V.
        //
        // FIXME(eddyb) rewriting away the pointer could be another alternative.
        if let LegalGlobal::TypePointer(_) = param_ty {
            // `operands[0]` of `OpFunctionCall` is the callee ID, so the
            // argument for parameter `i` is at operand index `i + 1`.
            let ptr_arg = call_site.call_inst.operands[i + 1].unwrap_id_ref();
            match legal_globals.get(&ptr_arg) {
                // A legal module-level `OpVariable` is a fine pointer argument.
                Some(LegalGlobal::Variable) => {}

                // FIXME(eddyb) should some constants (undef/null) be allowed?
                Some(_) => return Err(MustInlineToLegalize("illegal (pointer) argument")),

                // Not a global at all: only the caller's own parameters, or
                // `OpVariable`s at the start of its entry block, qualify.
                None => {
                    let mut caller_param_and_var_ids = call_site
                        .caller
                        .parameters
                        .iter()
                        .chain(
                            call_site.caller.blocks[0]
                                .instructions
                                .iter()
                                .filter(|caller_inst| {
                                    // HACK(eddyb) this only avoids scanning the
                                    // whole entry block for `OpVariable`s, so
                                    // it can overapproximate debuginfo insts.
                                    let may_be_debuginfo = matches!(
                                        caller_inst.class.opcode,
                                        Op::Line | Op::NoLine | Op::ExtInst
                                    );
                                    !may_be_debuginfo
                                })
                                .take_while(|caller_inst| caller_inst.class.opcode == Op::Variable),
                        )
                        .map(|caller_inst| caller_inst.result_id.unwrap());

                    if !caller_param_and_var_ids.any(|id| ptr_arg == id) {
                        return Err(MustInlineToLegalize("illegal (pointer) argument"));
                    }
                }
            }
        }
    }

    // No legalization reason forces inlining: fall back to the callee's
    // explicit `Inline` function-control hint (if any).
    Ok(callee_control.contains(FunctionControl::INLINE))
}
504
/// Helper error type for `Inliner`'s `functions` field, indicating a `Function`
/// was taken out of its slot because it's being inlined.
///
/// (The slot is restored to `Ok(function)` once that function is finished.)
#[derive(Debug)]
struct FuncIsBeingInlined;
509
510// Steps:
511// Move OpVariable decls
512// Rewrite return
513// Renumber IDs
514// Insert blocks
515
struct Inliner<'a, 'b> {
    /// ID of `OpExtInstImport` for our custom "extended instruction set"
    /// (see `crate::custom_insts` for more details).
    custom_ext_inst_set_import: Word,

    /// ID of `OpTypeVoid`, used to recognize calls without a result value
    /// (which therefore need no result `OpPhi` when inlined).
    op_type_void_id: Word,

    /// Map from each function's ID to its index in `functions`.
    func_id_to_idx: FxHashMap<Word, usize>,

    /// Pre-collected `OpName`s, that can be used to find any function's name
    /// during inlining (to be able to generate debuginfo that uses names).
    id_to_name: FxHashMap<Word, &'a str>,

    /// `OpString` cache (for deduplicating `OpString`s for the same string).
    //
    // FIXME(eddyb) currently this doesn't reuse existing `OpString`s, but since
    // this is mostly for inlined callee names, it's expected almost no overlap
    // exists between existing `OpString`s and new ones, anyway.
    cached_op_strings: FxHashMap<&'a str, Word>,

    /// Module header, used to allocate fresh result IDs (see `Inliner::id`).
    header: &'b mut ModuleHeader,
    /// Module-level debug-string instructions (presumably appended to when new
    /// `OpString`s are interned via `cached_op_strings` - TODO confirm).
    debug_string_source: &'b mut Vec<Instruction>,
    /// Module-level decorations, extended with copies targeting inlined IDs
    /// (see `apply_rewrite_for_decorations`).
    annotations: &'b mut Vec<Instruction>,

    /// Which globals (types/constants/module-level variables) are legal,
    /// as pre-computed by `LegalGlobal::gather_from_module`.
    legal_globals: FxHashMap<Word, LegalGlobal>,
    /// IDs of functions that may reach our custom `Abort` instruction; calls
    /// to these must always be inlined, and membership propagates from callee
    /// to caller as inlining proceeds.
    functions_that_may_abort: FxHashSet<Word>,
    /// For each `#[inline(never)]` function that had to be inlined anyway:
    /// the legalization cause, and the set of its callers (used to emit
    /// warnings once all inlining is done - see `inline`).
    inlined_dont_inlines_to_cause_and_callers: FxIndexMap<Word, (&'static str, FxIndexSet<Word>)>,
    // rewrite_rules: FxHashMap<Word, Word>,
}
546
547impl Inliner<'_, '_> {
    /// Allocates a fresh result ID, by bumping the module header's ID bound.
    fn id(&mut self) -> Word {
        next_id(self.header)
    }
551
552    /// Applies all rewrite rules to the decorations in the header.
553    fn apply_rewrite_for_decorations(&mut self, rewrite_rules: &FxHashMap<Word, Word>) {
554        // NOTE(siebencorgie): We don't care *what* decoration we rewrite atm.
555        // AFAIK there is no case where keeping decorations on inline wouldn't be valid.
556        for annotation_idx in 0..self.annotations.len() {
557            let inst = &self.annotations[annotation_idx];
558            if let [Operand::IdRef(target), ..] = inst.operands[..]
559                && let Some(&rewritten_target) = rewrite_rules.get(&target)
560            {
561                // Copy decoration instruction and push it.
562                let mut cloned_inst = inst.clone();
563                cloned_inst.operands[0] = Operand::IdRef(rewritten_target);
564                self.annotations.push(cloned_inst);
565            }
566        }
567    }
568
569    fn inline_fn(
570        &mut self,
571        function: &mut Function,
572        functions: &[Result<Function, FuncIsBeingInlined>],
573    ) {
574        let mut block_idx = 0;
575        while block_idx < function.blocks.len() {
576            // If we successfully inlined a block, then repeat processing on the same block, in
577            // case the newly inlined block has more inlined calls.
578            // TODO: This is quadratic
579            if !self.inline_block(function, block_idx, functions) {
580                // TODO(eddyb) skip past the inlined callee without rescanning it.
581                block_idx += 1;
582            }
583        }
584    }
585
586    fn inline_block(
587        &mut self,
588        caller: &mut Function,
589        block_idx: usize,
590        functions: &[Result<Function, FuncIsBeingInlined>],
591    ) -> bool {
592        // Find the first inlined OpFunctionCall
593        let call = caller.blocks[block_idx]
594            .instructions
595            .iter()
596            .enumerate()
597            .filter(|(_, inst)| inst.class.opcode == Op::FunctionCall)
598            .map(|(index, inst)| {
599                (
600                    index,
601                    inst,
602                    functions[self.func_id_to_idx[&inst.operands[0].id_ref_any().unwrap()]]
603                        .as_ref()
604                        .unwrap(),
605                )
606            })
607            .find(|(_, inst, f)| {
608                let call_site = CallSite {
609                    caller,
610                    call_inst: inst,
611                };
612                match should_inline(
613                    &self.legal_globals,
614                    &self.functions_that_may_abort,
615                    f,
616                    call_site,
617                ) {
618                    Ok(inline) => inline,
619                    Err(MustInlineToLegalize(cause)) => {
620                        if has_dont_inline(f) {
621                            self.inlined_dont_inlines_to_cause_and_callers
622                                .entry(f.def_id().unwrap())
623                                .or_insert_with(|| (cause, Default::default()))
624                                .1
625                                .insert(caller.def_id().unwrap());
626                        }
627                        true
628                    }
629                }
630            });
631        let (call_index, call_inst, callee) = match call {
632            None => return false,
633            Some(call) => call,
634        };
635
636        // Propagate "may abort" from callee to caller (i.e. as aborts get inlined).
637        if self
638            .functions_that_may_abort
639            .contains(&callee.def_id().unwrap())
640        {
641            self.functions_that_may_abort
642                .insert(caller.def_id().unwrap());
643        }
644
645        let mut maybe_call_result_phi = {
646            let ty = call_inst.result_type.unwrap();
647            if ty == self.op_type_void_id {
648                None
649            } else {
650                Some(Instruction::new(
651                    Op::Phi,
652                    Some(ty),
653                    Some(call_inst.result_id.unwrap()),
654                    vec![],
655                ))
656            }
657        };
658
659        // Get the debug "source location" instruction that applies to the call.
660        let custom_ext_inst_set_import = self.custom_ext_inst_set_import;
661        let call_debug_src_loc_inst = caller.blocks[block_idx].instructions[..call_index]
662            .iter()
663            .rev()
664            .find_map(|inst| {
665                Some(match inst.class.opcode {
666                    Op::Line => Some(inst),
667                    Op::NoLine => None,
668                    Op::ExtInst
669                        if inst.operands[0].unwrap_id_ref() == custom_ext_inst_set_import =>
670                    {
671                        match CustomOp::decode_from_ext_inst(inst) {
672                            CustomOp::SetDebugSrcLoc => Some(inst),
673                            CustomOp::ClearDebugSrcLoc => None,
674                            _ => return None,
675                        }
676                    }
677                    _ => return None,
678                })
679            })
680            .flatten();
681
682        // Rewrite parameters to arguments
683        let call_arguments = call_inst
684            .operands
685            .iter()
686            .skip(1)
687            .map(|op| op.id_ref_any().unwrap());
688        let callee_parameters = callee.parameters.iter().map(|inst| {
689            assert!(inst.class.opcode == Op::FunctionParameter);
690            inst.result_id.unwrap()
691        });
692        let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();
693
694        let return_jump = self.id();
695        // Rewrite OpReturns of the callee.
696        let mut inlined_callee_blocks = self.get_inlined_blocks(
697            callee,
698            call_debug_src_loc_inst,
699            maybe_call_result_phi.as_mut(),
700            return_jump,
701        );
702        // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
703        // fn is inlined multiple times.
704        self.add_clone_id_rules(&mut rewrite_rules, &inlined_callee_blocks);
705        apply_rewrite_rules(&rewrite_rules, &mut inlined_callee_blocks);
706        self.apply_rewrite_for_decorations(&rewrite_rules);
707
708        if let Some(call_result_phi) = &mut maybe_call_result_phi {
709            // HACK(eddyb) new IDs should be generated earlier, to avoid pushing
710            // callee IDs to `call_result_phi.operands` only to rewrite them here.
711            for op in &mut call_result_phi.operands {
712                if let Some(id) = op.id_ref_any_mut()
713                    && let Some(&rewrite) = rewrite_rules.get(id)
714                {
715                    *id = rewrite;
716                }
717            }
718
719            // HACK(eddyb) this special-casing of the single-return case is
720            // really necessary for passes like `mem2reg` which are not capable
721            // of skipping through the extraneous `OpPhi`s on their own.
722            if let [returned_value, _return_block] = &call_result_phi.operands[..] {
723                let call_result_id = call_result_phi.result_id.unwrap();
724                let returned_value_id = returned_value.unwrap_id_ref();
725
726                maybe_call_result_phi = None;
727
728                // HACK(eddyb) this is a conservative approximation of all the
729                // instructions that could potentially reference the call result.
730                let reaching_insts = {
731                    let (pre_call_blocks, call_and_post_call_blocks) =
732                        caller.blocks.split_at_mut(block_idx);
733                    (pre_call_blocks.iter_mut().flat_map(|block| {
734                        block
735                            .instructions
736                            .iter_mut()
737                            .take_while(|inst| inst.class.opcode == Op::Phi)
738                    }))
739                    .chain(
740                        call_and_post_call_blocks
741                            .iter_mut()
742                            .flat_map(|block| &mut block.instructions),
743                    )
744                };
745                for reaching_inst in reaching_insts {
746                    for op in &mut reaching_inst.operands {
747                        if let Some(id) = op.id_ref_any_mut()
748                            && *id == call_result_id
749                        {
750                            *id = returned_value_id;
751                        }
752                    }
753                }
754            }
755        }
756
757        // Split the block containing the `OpFunctionCall` into pre-call vs post-call.
758        let pre_call_block_idx = block_idx;
759        #[expect(unused)]
760        let block_idx: usize; // HACK(eddyb) disallowing using the unrenamed variable.
761        let mut post_call_block_insts = caller.blocks[pre_call_block_idx]
762            .instructions
763            .split_off(call_index + 1);
764
765        // pop off OpFunctionCall
766        let call = caller.blocks[pre_call_block_idx]
767            .instructions
768            .pop()
769            .unwrap();
770        assert!(call.class.opcode == Op::FunctionCall);
771
772        // Insert non-entry inlined callee blocks just after the pre-call block.
773        let non_entry_inlined_callee_blocks = inlined_callee_blocks.drain(1..);
774        let num_non_entry_inlined_callee_blocks = non_entry_inlined_callee_blocks.len();
775        caller.blocks.splice(
776            (pre_call_block_idx + 1)..(pre_call_block_idx + 1),
777            non_entry_inlined_callee_blocks,
778        );
779
780        if let Some(call_result_phi) = maybe_call_result_phi {
781            // Add the `OpPhi` for the call result value, after the inlined function.
782            post_call_block_insts.insert(0, call_result_phi);
783        }
784
785        // Insert the post-call block, after all the inlined callee blocks.
786        {
787            let post_call_block_idx = pre_call_block_idx + num_non_entry_inlined_callee_blocks + 1;
788            let post_call_block = Block {
789                label: Some(Instruction::new(Op::Label, None, Some(return_jump), vec![])),
790                instructions: post_call_block_insts,
791            };
792            caller.blocks.insert(post_call_block_idx, post_call_block);
793
794            // Adjust any `OpPhi`s in the (caller) targets of the original call block,
795            // to refer to post-call block (the new source of those CFG edges).
796            rewrite_phi_sources(
797                caller.blocks[pre_call_block_idx].label_id().unwrap(),
798                &mut caller.blocks,
799                post_call_block_idx,
800            );
801        }
802
803        // Fuse the inlined callee entry block into the pre-call block.
804        // This is okay because it's illegal to branch to the first BB in a function.
805        {
806            // NOTE(eddyb) `OpExtInst`s have a result ID, even if unused, and
807            // it has to be unique, so this allocates new IDs as-needed.
808            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
809                let mut inst = inst.clone();
810                if let Some(id) = &mut inst.result_id {
811                    *id = this.id();
812                }
813                inst
814            };
815
816            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
817                Instruction::new(
818                    Op::ExtInst,
819                    Some(this.op_type_void_id),
820                    Some(this.id()),
821                    [
822                        Operand::IdRef(this.custom_ext_inst_set_import),
823                        Operand::LiteralExtInstInteger(inst.op() as u32),
824                    ]
825                    .into_iter()
826                    .chain(inst.into_operands())
827                    .collect(),
828                )
829            };
830
831            // Return the subsequence of `insts` made from `OpVariable`s, and any
832            // debuginfo instructions (which may apply to them), while removing
833            // *only* `OpVariable`s from `insts` (and keeping debuginfo in both).
834            let mut steal_vars = |insts: &mut Vec<Instruction>| {
835                // HACK(eddyb) this duplicates some code from `get_inlined_blocks`,
836                // but that will be removed once the inliner is refactored to be
837                // inside-out instead of outside-in (already finished in a branch).
838                let mut enclosing_inlined_frames = SmallVec::<[_; 8]>::new();
839                let mut current_debug_src_loc_inst = None;
840                let mut vars_and_debuginfo_range = 0..0;
841                while vars_and_debuginfo_range.end < insts.len() {
842                    let inst = &insts[vars_and_debuginfo_range.end];
843                    match inst.class.opcode {
844                        Op::Line => current_debug_src_loc_inst = Some(inst),
845                        Op::NoLine => current_debug_src_loc_inst = None,
846                        Op::ExtInst
847                            if inst.operands[0].unwrap_id_ref()
848                                == self.custom_ext_inst_set_import =>
849                        {
850                            match CustomOp::decode_from_ext_inst(inst) {
851                                CustomOp::SetDebugSrcLoc => current_debug_src_loc_inst = Some(inst),
852                                CustomOp::ClearDebugSrcLoc => current_debug_src_loc_inst = None,
853                                CustomOp::PushInlinedCallFrame => {
854                                    enclosing_inlined_frames
855                                        .push((current_debug_src_loc_inst.take(), inst));
856                                }
857                                CustomOp::PopInlinedCallFrame => {
858                                    if let Some((callsite_debug_src_loc_inst, _)) =
859                                        enclosing_inlined_frames.pop()
860                                    {
861                                        current_debug_src_loc_inst = callsite_debug_src_loc_inst;
862                                    }
863                                }
864                                CustomOp::Abort => break,
865                            }
866                        }
867                        Op::Variable => {}
868                        _ => break,
869                    }
870                    vars_and_debuginfo_range.end += 1;
871                }
872
873                // `vars_and_debuginfo_range.end` indicates where `OpVariable`s
874                // end and other instructions start (modulo debuginfo), but to
875                // split the block in two, both sides of the "cut" need "repair":
876                // - the variables are missing "inlined call frames" pops, that
877                //   may happen later in the block, and have to be synthesized
878                // - the non-variables are missing "inlined call frames" pushes,
879                //   that must be recreated to avoid ending up with dangling pops
880                //
881                // FIXME(eddyb) this only collects to avoid borrow conflicts,
882                // between e.g. `enclosing_inlined_frames` and mutating `insts`,
883                // but also between different uses of `self`.
884                let all_pops_after_vars: SmallVec<[_; 8]> = enclosing_inlined_frames
885                    .iter()
886                    .map(|_| custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame))
887                    .collect();
888                let all_repushes_before_non_vars: SmallVec<[_; 8]> =
889                    (enclosing_inlined_frames.into_iter().flat_map(
890                        |(callsite_debug_src_loc_inst, push_inlined_call_frame_inst)| {
891                            (callsite_debug_src_loc_inst.into_iter())
892                                .chain([push_inlined_call_frame_inst])
893                        },
894                    ))
895                    .chain(current_debug_src_loc_inst)
896                    .map(|inst| instantiate_debuginfo(self, inst))
897                    .collect();
898
899                let vars_and_debuginfo =
900                    insts.splice(vars_and_debuginfo_range, all_repushes_before_non_vars);
901                let repaired_vars_and_debuginfo = vars_and_debuginfo.chain(all_pops_after_vars);
902
903                // FIXME(eddyb) collecting shouldn't be necessary but this is
904                // nested in a closure, and `splice` borrows the original `Vec`.
905                repaired_vars_and_debuginfo.collect::<SmallVec<[_; 8]>>()
906            };
907
908            let [mut inlined_callee_entry_block]: [_; 1] =
909                inlined_callee_blocks.try_into().unwrap();
910
911            // Move the `OpVariable`s of the callee to the caller.
912            let callee_vars_and_debuginfo =
913                steal_vars(&mut inlined_callee_entry_block.instructions);
914            self.insert_opvariables(&mut caller.blocks[0], callee_vars_and_debuginfo);
915
916            caller.blocks[pre_call_block_idx]
917                .instructions
918                .append(&mut inlined_callee_entry_block.instructions);
919
920            // Adjust any `OpPhi`s in the (inlined callee) targets of the
921            // inlined callee entry block, to refer to the pre-call block
922            // (the new source of those CFG edges).
923            rewrite_phi_sources(
924                inlined_callee_entry_block.label_id().unwrap(),
925                &mut caller.blocks,
926                pre_call_block_idx,
927            );
928        }
929
930        true
931    }
932
933    fn add_clone_id_rules(&mut self, rewrite_rules: &mut FxHashMap<Word, Word>, blocks: &[Block]) {
934        for block in blocks {
935            for inst in block.label.iter().chain(&block.instructions) {
936                if let Some(result_id) = inst.result_id {
937                    let new_id = self.id();
938                    let old = rewrite_rules.insert(result_id, new_id);
939                    assert!(old.is_none());
940                }
941            }
942        }
943    }
944
    /// Clone `callee`'s blocks, rewritten for splicing into a caller:
    /// - `OpReturn`/`OpReturnValue` terminators are replaced with `OpBranch`es
    ///   to `return_jump`, with any returned value (and its source block label)
    ///   appended as an operand pair to `maybe_call_result_phi` (when present)
    /// - each block's non-`OpPhi` body (if non-empty) is wrapped in a
    ///   `PushInlinedCallFrame`/`PopInlinedCallFrame` debuginfo pair, preceded
    ///   by the callsite's debug source location (`call_debug_src_loc_inst`),
    ///   if one was provided
    fn get_inlined_blocks(
        &mut self,
        callee: &Function,
        call_debug_src_loc_inst: Option<&Instruction>,
        mut maybe_call_result_phi: Option<&mut Instruction>,
        return_jump: Word,
    ) -> Vec<Block> {
        let Self {
            custom_ext_inst_set_import,
            op_type_void_id,
            ..
        } = *self;

        // Prepare the debuginfo insts to prepend/append to every block.
        // FIXME(eddyb) this could be more efficient if we only used one pair of
        // `{Push,Pop}InlinedCallFrame` for the whole inlined callee, but there
        // is no way to hint the SPIR-T CFG (re)structurizer that it should keep
        // the entire callee in one region - a SPIR-T inliner wouldn't have this
        // issue, as it would require a fully structured callee.
        //
        // Intern an `OpString` for the callee's name (used as the
        // `PushInlinedCallFrame` operand below), caching it so repeated
        // inlining of the same callee reuses one `OpString`.
        let callee_name = self
            .id_to_name
            .get(&callee.def_id().unwrap())
            .copied()
            .unwrap_or("");
        let callee_name_id = *self
            .cached_op_strings
            .entry(callee_name)
            .or_insert_with(|| {
                let id = next_id(self.header);
                self.debug_string_source.push(Instruction::new(
                    Op::String,
                    None,
                    Some(id),
                    vec![Operand::LiteralString(callee_name.to_string())],
                ));
                id
            });
        let mut mk_debuginfo_prefix_and_suffix = || {
            // NOTE(eddyb) `OpExtInst`s have a result ID, even if unused, and
            // it has to be unique (same goes for the other instructions below).
            let instantiate_debuginfo = |this: &mut Self, inst: &Instruction| {
                let mut inst = inst.clone();
                if let Some(id) = &mut inst.result_id {
                    *id = this.id();
                }
                inst
            };
            // Encode a `CustomInst` as an `OpExtInst` of the custom ext inst set.
            let custom_inst_to_inst = |this: &mut Self, inst: CustomInst<_>| {
                Instruction::new(
                    Op::ExtInst,
                    Some(op_type_void_id),
                    Some(this.id()),
                    [
                        Operand::IdRef(custom_ext_inst_set_import),
                        Operand::LiteralExtInstInteger(inst.op() as u32),
                    ]
                    .into_iter()
                    .chain(inst.into_operands())
                    .collect(),
                )
            };

            // Prefix: callsite debug location (if any) + `PushInlinedCallFrame`;
            // suffix: the matching `PopInlinedCallFrame`.
            (
                (call_debug_src_loc_inst.map(|inst| instantiate_debuginfo(self, inst)))
                    .into_iter()
                    .chain([custom_inst_to_inst(
                        self,
                        CustomInst::PushInlinedCallFrame {
                            callee_name: Operand::IdRef(callee_name_id),
                        },
                    )]),
                [custom_inst_to_inst(self, CustomInst::PopInlinedCallFrame)],
            )
        };

        let mut blocks = callee.blocks.clone();
        for block in &mut blocks {
            let mut terminator = block.instructions.pop().unwrap();

            // HACK(eddyb) strip trailing debuginfo (as it can't impact terminators).
            while let Some(last) = block.instructions.last() {
                let can_remove = match last.class.opcode {
                    Op::Line | Op::NoLine => true,
                    Op::ExtInst => {
                        last.operands[0].unwrap_id_ref() == custom_ext_inst_set_import
                            && matches!(
                                CustomOp::decode_from_ext_inst(last),
                                CustomOp::SetDebugSrcLoc | CustomOp::ClearDebugSrcLoc
                            )
                    }
                    _ => false,
                };
                if can_remove {
                    block.instructions.pop();
                } else {
                    break;
                }
            }

            // Rewrite returns into branches to the `return_jump` merge point.
            if let Op::Return | Op::ReturnValue = terminator.class.opcode {
                if Op::ReturnValue == terminator.class.opcode {
                    // Record this return's value, and the block it comes from,
                    // as an operand pair of the call result `OpPhi`.
                    let return_value = terminator.operands[0].id_ref_any().unwrap();
                    let call_result_phi = maybe_call_result_phi.as_deref_mut().unwrap();
                    call_result_phi.operands.extend([
                        Operand::IdRef(return_value),
                        Operand::IdRef(block.label_id().unwrap()),
                    ]);
                } else {
                    // A plain `OpReturn` implies there is no call result to collect.
                    assert!(maybe_call_result_phi.is_none());
                }
                terminator =
                    Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
            }

            let num_phis = block
                .instructions
                .iter()
                .take_while(|inst| inst.class.opcode == Op::Phi)
                .count();

            // HACK(eddyb) avoid adding debuginfo to otherwise-empty blocks.
            if block.instructions.len() > num_phis {
                let (debuginfo_prefix, debuginfo_suffix) = mk_debuginfo_prefix_and_suffix();
                // Insert the prefix debuginfo instructions after `OpPhi`s,
                // which sadly can't be covered by them.
                block
                    .instructions
                    .splice(num_phis..num_phis, debuginfo_prefix);
                // Insert the suffix debuginfo instructions before the terminator,
                // which sadly can't be covered by them.
                block.instructions.extend(debuginfo_suffix);
            }

            block.instructions.push(terminator);
        }

        blocks
    }
1083
    /// Insert `insts` (presumably `OpVariable`s plus their debuginfo — TODO
    /// confirm against callers) into `block`, at a position that is *not*
    /// covered by any debuginfo: i.e. not inside a `PushInlinedCallFrame`, and
    /// not after an active `OpLine`/`SetDebugSrcLoc`, so the new instructions
    /// can't be misattributed to an unrelated source location or call frame.
    fn insert_opvariables(&self, block: &mut Block, insts: impl IntoIterator<Item = Instruction>) {
        // HACK(eddyb) this isn't as efficient as it could be in theory, but it's
        // very important to make sure to never insert new instructions in
        // the middle of debuginfo (as it would be affected by it).
        let mut inlined_frames_depth = 0usize;
        let mut outermost_has_debug_src_loc = false;
        let mut last_debugless_var_insertion_point_candidate = None;
        for (i, inst) in block.instructions.iter().enumerate() {
            // Position `i` is a viable insertion point only while no inlined
            // call frame is open, and no debug source location is active.
            last_debugless_var_insertion_point_candidate =
                (inlined_frames_depth == 0 && !outermost_has_debug_src_loc).then_some(i);

            // `true`/`false` = this instruction activates/deactivates a debug
            // source location; all other cases `continue` (for instructions
            // that don't affect it) or `break` (at the end of the scanned
            // `OpVariable`-and-debuginfo prefix of the block).
            let changed_has_debug_src_loc = match inst.class.opcode {
                Op::Line => true,
                Op::NoLine => false,
                Op::ExtInst
                    if inst.operands[0].unwrap_id_ref() == self.custom_ext_inst_set_import =>
                {
                    match CustomOp::decode_from_ext_inst(inst) {
                        CustomOp::SetDebugSrcLoc => true,
                        CustomOp::ClearDebugSrcLoc => false,
                        CustomOp::PushInlinedCallFrame => {
                            inlined_frames_depth += 1;
                            continue;
                        }
                        CustomOp::PopInlinedCallFrame => {
                            inlined_frames_depth = inlined_frames_depth.saturating_sub(1);
                            continue;
                        }
                        CustomOp::Abort => break,
                    }
                }
                Op::Variable => continue,
                _ => break,
            };

            // Only the outermost (non-inlined) frame's source location matters
            // for choosing an insertion point.
            if inlined_frames_depth == 0 {
                outermost_has_debug_src_loc = changed_has_debug_src_loc;
            }
        }

        // HACK(eddyb) fallback to inserting at the start, which should be correct.
        // FIXME(eddyb) some level of debuginfo repair could prevent needing this.
        let i = last_debugless_var_insertion_point_candidate.unwrap_or(0);
        block.instructions.splice(i..i, insts);
    }
1129}
1130
1131fn fuse_trivial_branches(function: &mut Function) {
1132    let all_preds = compute_preds(&function.blocks);
1133    'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
1134        // Don't fuse branches into blocks with `OpPhi`s.
1135        let any_phis = function.blocks[dest_block]
1136            .instructions
1137            .iter()
1138            .filter(|inst| {
1139                // These are the only instructions that are allowed before `OpPhi`.
1140                !matches!(inst.class.opcode, Op::Line | Op::NoLine)
1141            })
1142            .take_while(|inst| inst.class.opcode == Op::Phi)
1143            .next()
1144            .is_some();
1145        if any_phis {
1146            continue;
1147        }
1148
1149        // if there's two trivial branches in a row, the middle one might get inlined before the
1150        // last one, so when processing the last one, skip through to the first one.
1151        let pred = loop {
1152            if preds.len() != 1 || preds[0] == dest_block {
1153                continue 'outer;
1154            }
1155            let pred = preds[0];
1156            if !function.blocks[pred].instructions.is_empty() {
1157                break pred;
1158            }
1159            preds = &all_preds[pred];
1160        };
1161        let pred_insts = &function.blocks[pred].instructions;
1162        if pred_insts.last().unwrap().class.opcode == Op::Branch {
1163            let mut dest_insts = mem::take(&mut function.blocks[dest_block].instructions);
1164            let pred_insts = &mut function.blocks[pred].instructions;
1165            pred_insts.pop(); // pop the branch
1166            pred_insts.append(&mut dest_insts);
1167
1168            // Adjust any `OpPhi`s in the targets of the original block, to refer
1169            // to the sole predecessor (the new source of those CFG edges).
1170            rewrite_phi_sources(
1171                function.blocks[dest_block].label_id().unwrap(),
1172                &mut function.blocks,
1173                pred,
1174            );
1175        }
1176    }
1177    function.blocks.retain(|b| !b.instructions.is_empty());
1178}
1179
1180fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
1181    let mut result = vec![vec![]; blocks.len()];
1182    for (source_idx, source) in blocks.iter().enumerate() {
1183        for dest_id in outgoing_edges(source) {
1184            let dest_idx = blocks
1185                .iter()
1186                .position(|b| b.label_id().unwrap() == dest_id)
1187                .unwrap();
1188            result[dest_idx].push(source_idx);
1189        }
1190    }
1191    result
1192}
1193
1194/// Helper for adjusting `OpPhi` source label IDs, when the terminator of the
1195/// `original_label_id`-labeled block got moved to `blocks[original_block_idx]`.
1196fn rewrite_phi_sources(original_label_id: Word, blocks: &mut [Block], new_block_idx: usize) {
1197    let new_label_id = blocks[new_block_idx].label_id().unwrap();
1198
1199    // HACK(eddyb) can't keep `blocks` borrowed, the loop needs mutable access.
1200    let target_ids: SmallVec<[_; 4]> = outgoing_edges(&blocks[new_block_idx]).collect();
1201
1202    for target_id in target_ids {
1203        let target_block = blocks
1204            .iter_mut()
1205            .find(|b| b.label_id().unwrap() == target_id)
1206            .unwrap();
1207        let phis = target_block
1208            .instructions
1209            .iter_mut()
1210            .filter(|inst| {
1211                // These are the only instructions that are allowed before `OpPhi`.
1212                !matches!(inst.class.opcode, Op::Line | Op::NoLine)
1213            })
1214            .take_while(|inst| inst.class.opcode == Op::Phi);
1215        for phi in phis {
1216            for value_and_source_id in phi.operands.chunks_mut(2) {
1217                let source_id = value_and_source_id[1].id_ref_any_mut().unwrap();
1218                if *source_id == original_label_id {
1219                    *source_id = new_label_id;
1220                    break;
1221                }
1222            }
1223        }
1224    }
1225}