rustc_codegen_spirv/linker/
mod.rs

1#[cfg(test)]
2mod test;
3
4pub(crate) mod dce;
5mod destructure_composites;
6mod duplicates;
7mod entry_interface;
8mod import_export_link;
9mod inline;
10mod ipo;
11mod mem2reg;
12mod param_weakening;
13mod peephole_opts;
14mod simple_passes;
15mod specializer;
16mod spirt_passes;
17mod zombies;
18
19use std::borrow::Cow;
20
21use crate::codegen_cx::{ModuleOutputType, SpirvMetadata};
22use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
23use crate::custom_insts;
24use either::Either;
25use rspirv::binary::Assemble;
26use rspirv::dr::{Block, Module, ModuleHeader, Operand};
27use rspirv::spirv::{Op, StorageClass, Word};
28use rustc_data_structures::fx::FxHashMap;
29use rustc_errors::ErrorGuaranteed;
30use rustc_session::Session;
31use rustc_session::config::OutputFilenames;
32use std::cell::Cell;
33use std::collections::BTreeMap;
34use std::ffi::{OsStr, OsString};
35use std::path::PathBuf;
36
/// Result type used by the linker: the error side is `ErrorGuaranteed`.
pub type Result<T> = std::result::Result<T, ErrorGuaranteed>;
38
/// Configuration read by `link`: which optional passes run, what shape the
/// output takes, and which debugging dumps get written.
#[derive(Default)]
pub struct Options {
    /// When set, `link` runs `simple_passes::compact_ids` on each output
    /// module and stores the returned value as the module header's ID bound.
    pub compact_ids: bool,
    /// When set, `link` runs `zombies::report_zombies` right after removing
    /// unused function parameters.
    pub early_report_zombies: bool,
    /// When set, `link` runs `specializer::specialize` over operands equal to
    /// `StorageClass::Generic`, with `StorageClass::Function` as the fallback.
    pub infer_storage_classes: bool,
    /// When set, `link` runs `spirt::passes::legalize::structurize_func_cfgs`
    /// on the SPIR-T module.
    pub structurize: bool,
    /// Pass names handed to `spirt_passes::run_func_passes` (which is skipped
    /// entirely when this is empty).
    pub spirt_passes: Vec<String>,

    // NOTE(review): not read directly in this file's visible code; presumably
    // consumed by the custom-abort conversion pass, which receives the whole
    // `Options` - confirm.
    pub abort_strategy: Option<String>,
    /// `ModuleOutputType::Multiple` makes `link` produce one module per entry
    /// point (`LinkResult::MultipleModules`); any other value produces a
    /// `LinkResult::SingleModule`.
    pub module_output_type: ModuleOutputType,

    /// `SpirvMetadata::NameVariables` makes `link` run
    /// `simple_passes::name_variables_pass`.
    pub spirv_metadata: SpirvMetadata,

    /// Whether to preserve `LinkageAttributes "..." Export` decorations,
    /// even after resolving imports to exports.
    ///
    /// **Note**: currently only used for unit testing, and not exposed elsewhere.
    pub keep_link_exports: bool,

    // NOTE(eddyb) these are debugging options that used to be env vars
    // (for more information see `docs/src/codegen-args.md`).
    /// Directory to dump the module into (as `.spv`, plus `.spirt` and
    /// `.spirt.html` when lowering to SPIR-T succeeds) right after merging.
    pub dump_post_merge: Option<PathBuf>,
    /// Same kind of dump, taken just before the inlining pass.
    pub dump_pre_inline: Option<PathBuf>,
    /// Same kind of dump, taken after inlining and the DCE that follows it.
    pub dump_post_inline: Option<PathBuf>,
    /// Same kind of dump, taken per output module after the post-split DCE.
    pub dump_post_split: Option<PathBuf>,
    /// Directory for SPIR-T per-pass dumps; when set, `link` also keeps a
    /// clone of the SPIR-T module after each pass that reports one.
    pub dump_spirt_passes: Option<PathBuf>,
    /// When set, SPIR-T dumps first have custom debuginfo converted via
    /// `spirt_passes::debuginfo::convert_custom_debuginfo_to_spv`.
    pub spirt_strip_custom_debuginfo_from_dumps: bool,
    /// When unset, SPIR-T dumps have their debug source file contents
    /// replaced by a placeholder.
    pub spirt_keep_debug_sources_in_dumps: bool,
    /// When set (or when `structurize` is unset), the SPIR-T module state
    /// right after lowering from SPIR-V is recorded for per-pass dumping.
    pub spirt_keep_unstructured_cfg_in_dumps: bool,
    // NOTE(review): not read in this file's visible code; presumably used by
    // the `specializer` module - confirm.
    pub specializer_dump_instances: Option<PathBuf>,
}
70
/// What `link` produces: either one module holding every entry point, or one
/// module per entry point.
pub enum LinkResult {
    /// A single linked module (returned whenever the requested output type is
    /// not `ModuleOutputType::Multiple`).
    SingleModule(Box<Module>),
    /// One module per entry point, each a clone of the linked module whose
    /// entry point list was reduced to that single entry.
    MultipleModules {
        /// The "file stem" key is computed from the "entry name" in the value
        /// (through `sanitize_filename`, replacing invalid chars with `-`),
        /// but it's used as the map key because it *has to* be unique, even if
        /// lossy sanitization could have erased distinctions between entry names.
        file_stem_to_entry_name_and_module: BTreeMap<OsString, (String, Module)>,
    },
}
81
82fn id(header: &mut ModuleHeader) -> Word {
83    let result = header.bound;
84    header.bound += 1;
85    result
86}
87
88fn apply_rewrite_rules<'a>(
89    rewrite_rules: &FxHashMap<Word, Word>,
90    blocks: impl IntoIterator<Item = &'a mut Block>,
91) {
92    let all_ids_mut = blocks
93        .into_iter()
94        .flat_map(|b| b.label.iter_mut().chain(b.instructions.iter_mut()))
95        .flat_map(|inst| {
96            inst.result_id
97                .iter_mut()
98                .chain(inst.result_type.iter_mut())
99                .chain(
100                    inst.operands
101                        .iter_mut()
102                        .filter_map(|op| op.id_ref_any_mut()),
103                )
104        });
105    for id in all_ids_mut {
106        if let Some(&rewrite) = rewrite_rules.get(id) {
107            *id = rewrite;
108        }
109    }
110}
111
112fn get_names(module: &Module) -> FxHashMap<Word, &str> {
113    let entry_names = module
114        .entry_points
115        .iter()
116        .filter(|i| i.class.opcode == Op::EntryPoint)
117        .map(|i| {
118            (
119                i.operands[1].unwrap_id_ref(),
120                i.operands[2].unwrap_literal_string(),
121            )
122        });
123    let debug_names = module
124        .debug_names
125        .iter()
126        .filter(|i| i.class.opcode == Op::Name)
127        .map(|i| {
128            (
129                i.operands[0].unwrap_id_ref(),
130                i.operands[1].unwrap_literal_string(),
131            )
132        });
133    // items later on take priority
134    entry_names.chain(debug_names).collect()
135}
136
137fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
138    names.get(&id).map_or_else(
139        || Cow::Owned(format!("Unnamed function ID %{id}")),
140        |&s| Cow::Borrowed(s),
141    )
142}
143
144impl Options {
145    // FIXME(eddyb) using a method on this type seems a bit sketchy.
146    fn spirt_cleanup_for_dumping(&self, module: &mut spirt::Module) {
147        if self.spirt_strip_custom_debuginfo_from_dumps {
148            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
149        }
150        if !self.spirt_keep_debug_sources_in_dumps {
151            const DOTS: &str = "⋯";
152            let dots_interned_str = module.cx().intern(DOTS);
153            let spirt::ModuleDebugInfo::Spv(debuginfo) = &mut module.debug_info;
154            for sources in debuginfo.source_languages.values_mut() {
155                for file in sources.file_contents.values_mut() {
156                    *file = DOTS.into();
157                }
158                sources.file_contents.insert(
159                    dots_interned_str,
160                    "sources hidden, to show them use \
161                     `RUSTGPU_CODEGEN_ARGS=--spirt-keep-debug-sources-in-dumps`"
162                        .into(),
163                );
164            }
165        }
166    }
167}
168
169pub fn link(
170    sess: &Session,
171    mut inputs: Vec<Module>,
172    opts: &Options,
173    outputs: &OutputFilenames,
174    disambiguated_crate_name_for_dumps: &OsStr,
175) -> Result<LinkResult> {
176    // HACK(eddyb) this is defined here to allow SPIR-T pretty-printing to apply
177    // to SPIR-V being dumped, outside of e.g. `--dump-spirt-passes`.
178    // FIXME(eddyb) this isn't used everywhere, sadly - to find those, search
179    // elsewhere for `.assemble()` and/or `spirv_tools::binary::from_binary`.
180    let spv_module_to_spv_words_and_spirt_module = |spv_module: &Module| {
181        let spv_words;
182        let spv_bytes = {
183            let _timer = sess.timer("assemble-to-spv_bytes-for-spirt");
184            spv_words = spv_module.assemble();
185            // FIXME(eddyb) this is wastefully cloning all the bytes, but also
186            // `spirt::Module` should have a method that takes `Vec<u32>`.
187            spirv_tools::binary::from_binary(&spv_words).to_vec()
188        };
189
190        // FIXME(eddyb) should've really been "spirt::Module::lower_from_spv_bytes".
191        let lower_from_spv_timer = sess.timer("spirt::Module::lower_from_spv_file");
192        let cx = std::rc::Rc::new(spirt::Context::new());
193        crate::custom_insts::register_to_spirt_context(&cx);
194        (
195            spv_words,
196            spirt::Module::lower_from_spv_bytes(cx, spv_bytes),
197            // HACK(eddyb) this is only returned for `SpirtDumpGuard`.
198            lower_from_spv_timer,
199        )
200    };
201
202    // FIXME(eddyb) deduplicate with `SpirtDumpGuard`.
203    let dump_spv_and_spirt = |spv_module: &Module, dump_file_path_stem: PathBuf| {
204        let (spv_words, spirt_module_or_err, _) =
205            spv_module_to_spv_words_and_spirt_module(spv_module);
206        std::fs::write(
207            dump_file_path_stem.with_extension("spv"),
208            spirv_tools::binary::from_binary(&spv_words),
209        )
210        .unwrap();
211
212        // FIXME(eddyb) reify SPIR-V -> SPIR-T errors so they're easier to debug.
213        if let Ok(mut module) = spirt_module_or_err {
214            // HACK(eddyb) avoid pretty-printing massive amounts of unused SPIR-T.
215            spirt::passes::link::minimize_exports(&mut module, |export_key| {
216                matches!(export_key, spirt::ExportKey::SpvEntryPoint { .. })
217            });
218
219            opts.spirt_cleanup_for_dumping(&mut module);
220
221            let pretty = spirt::print::Plan::for_module(&module).pretty_print();
222
223            // FIXME(eddyb) don't allocate whole `String`s here.
224            std::fs::write(
225                dump_file_path_stem.with_extension("spirt"),
226                pretty.to_string(),
227            )
228            .unwrap();
229            std::fs::write(
230                dump_file_path_stem.with_extension("spirt.html"),
231                pretty
232                    .render_to_html()
233                    .with_dark_mode_support()
234                    .to_html_doc(),
235            )
236            .unwrap();
237        }
238    };
239
240    let mut output = {
241        let _timer = sess.timer("link_merge");
242        // shift all the ids
243        let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
244        let version = inputs[0].header.as_ref().unwrap().version();
245
246        for module in inputs.iter_mut().skip(1) {
247            simple_passes::shift_ids(module, bound);
248            bound += module.header.as_ref().unwrap().bound - 1;
249            let this_version = module.header.as_ref().unwrap().version();
250            if version != this_version {
251                return Err(sess.dcx().err(format!(
252                    "cannot link two modules with different SPIR-V versions: v{}.{} and v{}.{}",
253                    version.0, version.1, this_version.0, this_version.1
254                )));
255            }
256        }
257
258        // merge the binaries
259        let mut output = crate::link::with_rspirv_loader(|loader| {
260            for module in inputs {
261                for inst in module.all_inst_iter() {
262                    use rspirv::binary::ParseAction;
263                    match loader.consume_instruction(inst.clone()) {
264                        ParseAction::Continue => {}
265                        ParseAction::Stop => unreachable!(),
266                        ParseAction::Error(err) => return Err(err),
267                    }
268                }
269            }
270            Ok(())
271        })
272        .unwrap();
273
274        let mut header = ModuleHeader::new(bound + 1);
275        header.set_version(version.0, version.1);
276        header.generator = 0x001B_0000;
277        output.header = Some(header);
278        output
279    };
280
281    if let Some(dir) = &opts.dump_post_merge {
282        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
283    }
284
285    // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
286    {
287        let _timer = sess.timer("link_remove_duplicates");
288        duplicates::remove_duplicate_extensions(&mut output);
289        duplicates::remove_duplicate_capabilities(&mut output);
290        duplicates::remove_duplicate_ext_inst_imports(&mut output);
291        duplicates::remove_duplicate_types(&mut output);
292        // jb-todo: strip identical OpDecoration / OpDecorationGroups
293    }
294
295    // find import / export pairs
296    {
297        let _timer = sess.timer("link_find_pairs");
298        import_export_link::run(opts, sess, &mut output)?;
299    }
300
301    {
302        let _timer = sess.timer("link_dce-post-link");
303        dce::dce(&mut output);
304    }
305
306    {
307        let _timer = sess.timer("link_fragment_inst_check");
308        simple_passes::check_fragment_insts(sess, &output)?;
309    }
310
311    // HACK(eddyb) this has to run before the `report_zombies` pass, so that
312    // any zombies that are passed as call arguments, but eventually unused,
313    // won't be (incorrectly) considered used.
314    {
315        let _timer = sess.timer("link_remove_unused_params");
316        output = param_weakening::remove_unused_params(output);
317    }
318
319    if opts.early_report_zombies {
320        let _timer = sess.timer("link_report_zombies");
321        zombies::report_zombies(sess, &output)?;
322    }
323
324    if opts.infer_storage_classes {
325        let _timer = sess.timer("specialize_generic_storage_class");
326        // HACK(eddyb) `specializer` requires functions' blocks to be in RPO order
327        // (i.e. `block_ordering_pass`) - this could be relaxed by using RPO visit
328        // inside `specializer`, but this is easier.
329        for func in &mut output.functions {
330            simple_passes::block_ordering_pass(func);
331        }
332        output = specializer::specialize(
333            opts,
334            output,
335            specializer::SimpleSpecialization {
336                specialize_operand: |operand| {
337                    matches!(operand, Operand::StorageClass(StorageClass::Generic))
338                },
339
340                // NOTE(eddyb) this can be anything that is guaranteed to pass
341                // validation - there are no constraints so this is either some
342                // unused pointer, or perhaps one created using `OpConstantNull`
343                // and simply never mixed with pointers that have a storage class.
344                // It would be nice to use `Generic` itself here so that we leave
345                // some kind of indication of it being unconstrained, but `Generic`
346                // requires additional capabilities, so we use `Function` instead.
347                // TODO(eddyb) investigate whether this can end up in a pointer
348                // type that's the value of a module-scoped variable, and whether
349                // `Function` is actually invalid! (may need `Private`)
350                concrete_fallback: Operand::StorageClass(StorageClass::Function),
351            },
352        );
353    }
354
355    // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too!
356    {
357        {
358            let _timer = sess.timer("link_dce-before-inlining");
359            dce::dce(&mut output);
360        }
361
362        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-before-inlining");
363        let mut pointer_to_pointee = FxHashMap::default();
364        let mut constants = FxHashMap::default();
365        let mut u32 = None;
366        for inst in &output.types_global_values {
367            match inst.class.opcode {
368                Op::TypePointer => {
369                    pointer_to_pointee
370                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
371                }
372                Op::TypeInt
373                    if inst.operands[0].unwrap_literal_bit32() == 32
374                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
375                {
376                    assert!(u32.is_none());
377                    u32 = Some(inst.result_id.unwrap());
378                }
379                Op::Constant if u32.is_some() && inst.result_type == u32 => {
380                    let value = inst.operands[0].unwrap_literal_bit32();
381                    constants.insert(inst.result_id.unwrap(), value);
382                }
383                _ => {}
384            }
385        }
386        for func in &mut output.functions {
387            simple_passes::block_ordering_pass(func);
388            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
389            mem2reg::mem2reg(
390                output.header.as_mut().unwrap(),
391                &mut output.types_global_values,
392                &pointer_to_pointee,
393                &constants,
394                func,
395            );
396            destructure_composites::destructure_composites(func);
397        }
398    }
399
400    {
401        let _timer =
402            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-before-inlining");
403        dce::dce(&mut output);
404        duplicates::remove_duplicate_debuginfo(&mut output);
405    }
406
407    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
408    if let Some(dir) = &opts.dump_pre_inline {
409        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
410    }
411
412    {
413        let _timer = sess.timer("link_inline");
414        inline::inline(sess, &mut output)?;
415    }
416
417    {
418        let _timer = sess.timer("link_dce-after-inlining");
419        dce::dce(&mut output);
420    }
421
422    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
423    if let Some(dir) = &opts.dump_post_inline {
424        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
425    }
426
427    {
428        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-after-inlining");
429        let mut pointer_to_pointee = FxHashMap::default();
430        let mut constants = FxHashMap::default();
431        let mut u32 = None;
432        for inst in &output.types_global_values {
433            match inst.class.opcode {
434                Op::TypePointer => {
435                    pointer_to_pointee
436                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
437                }
438                Op::TypeInt
439                    if inst.operands[0].unwrap_literal_bit32() == 32
440                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
441                {
442                    assert!(u32.is_none());
443                    u32 = Some(inst.result_id.unwrap());
444                }
445                Op::Constant if u32.is_some() && inst.result_type == u32 => {
446                    let value = inst.operands[0].unwrap_literal_bit32();
447                    constants.insert(inst.result_id.unwrap(), value);
448                }
449                _ => {}
450            }
451        }
452        for func in &mut output.functions {
453            simple_passes::block_ordering_pass(func);
454            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
455            mem2reg::mem2reg(
456                output.header.as_mut().unwrap(),
457                &mut output.types_global_values,
458                &pointer_to_pointee,
459                &constants,
460                func,
461            );
462            destructure_composites::destructure_composites(func);
463        }
464    }
465
466    {
467        let _timer =
468            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-after-inlining");
469        dce::dce(&mut output);
470        duplicates::remove_duplicate_debuginfo(&mut output);
471    }
472
473    {
474        let _timer = sess.timer("link_remove_non_uniform");
475        simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
476    }
477
478    // NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
479    {
480        let (spv_words, module_or_err, lower_from_spv_timer) =
481            spv_module_to_spv_words_and_spirt_module(&output);
482        let module = &mut module_or_err.map_err(|e| {
483            let spv_path = outputs.temp_path_for_diagnostic("spirt-lower-from-spv-input.spv");
484
485            let was_saved_msg =
486                match std::fs::write(&spv_path, spirv_tools::binary::from_binary(&spv_words)) {
487                    Ok(()) => format!("was saved to {}", spv_path.display()),
488                    Err(e) => format!("could not be saved: {e}"),
489                };
490
491            sess.dcx()
492                .struct_err(format!("{e}"))
493                .with_note("while lowering SPIR-V module to SPIR-T (spirt::spv::lower)")
494                .with_note(format!("input SPIR-V module {was_saved_msg}"))
495                .emit()
496        })?;
497
498        let mut dump_guard = SpirtDumpGuard {
499            sess,
500            linker_options: opts,
501            outputs,
502            disambiguated_crate_name_for_dumps,
503
504            module,
505            per_pass_module_for_dumping: vec![],
506            in_progress_pass_name: Cell::new(Some("lower_from_spv")),
507            any_spirt_bugs: false,
508        };
509        let module = &mut *dump_guard.module;
510        // FIXME(eddyb) consider returning a `Drop`-implementing type instead?
511        let before_pass = |pass_name| {
512            let outer_pass_name = dump_guard.in_progress_pass_name.replace(Some(pass_name));
513
514            // FIXME(eddyb) could it make sense to allow these to nest?
515            assert_eq!(outer_pass_name, None);
516
517            sess.timer(pass_name)
518        };
519        let mut after_pass = |module: Option<&spirt::Module>, timer| {
520            drop(timer);
521            let pass_name = dump_guard.in_progress_pass_name.take().unwrap();
522            if let Some(module) = module
523                && opts.dump_spirt_passes.is_some()
524            {
525                dump_guard
526                    .per_pass_module_for_dumping
527                    .push((pass_name.into(), module.clone()));
528            }
529        };
530        // HACK(eddyb) don't dump the unstructured state if not requested, as
531        // after SPIR-T 0.4.0 it's extremely verbose (due to def-use hermeticity).
532        after_pass(
533            (opts.spirt_keep_unstructured_cfg_in_dumps || !opts.structurize).then_some(&*module),
534            lower_from_spv_timer,
535        );
536
537        // NOTE(eddyb) this *must* run on unstructured CFGs, to do its job.
538        // FIXME(eddyb) no longer relying on structurization, try porting this
539        // to replace custom aborts in `Block`s and inject `ExitInvocation`s
540        // after them (truncating the `Block` and/or parent region if necessary).
541        {
542            let timer = before_pass(
543                "spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points",
544            );
545            spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points(opts, module);
546            after_pass(None, timer);
547        }
548
549        if opts.structurize {
550            let timer = before_pass("spirt::legalize::structurize_func_cfgs");
551            spirt::passes::legalize::structurize_func_cfgs(module);
552            after_pass(Some(module), timer);
553        }
554
555        if !opts.spirt_passes.is_empty() {
556            // FIXME(eddyb) why does this focus on functions, it could just be module passes??
557            spirt_passes::run_func_passes(
558                module,
559                &opts.spirt_passes,
560                |name, _module| before_pass(name),
561                &mut after_pass,
562            );
563        }
564
565        {
566            let timer = before_pass("spirt_passes::explicit_layout::erase_when_invalid");
567            spirt_passes::explicit_layout::erase_when_invalid(module);
568            after_pass(Some(module), timer);
569        }
570
571        {
572            let timer = before_pass("spirt_passes::validate");
573            spirt_passes::validate::validate(module);
574            after_pass(Some(module), timer);
575        }
576
577        {
578            let timer = before_pass("spirt_passes::diagnostics::report_diagnostics");
579            spirt_passes::diagnostics::report_diagnostics(sess, opts, module).map_err(
580                |spirt_passes::diagnostics::ReportedDiagnostics {
581                     rustc_errors_guarantee,
582                     any_errors_were_spirt_bugs,
583                 }| {
584                    dump_guard.any_spirt_bugs |= any_errors_were_spirt_bugs;
585                    rustc_errors_guarantee
586                },
587            )?;
588            after_pass(None, timer);
589        }
590
591        // Replace our custom debuginfo instructions just before lifting to SPIR-V.
592        {
593            let timer = before_pass("spirt_passes::debuginfo::convert_custom_debuginfo_to_spv");
594            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
595            after_pass(None, timer);
596        }
597
598        let spv_words = {
599            let timer = before_pass("spirt::Module::lift_to_spv_module_emitter");
600            let spv_words = module.lift_to_spv_module_emitter().unwrap().words;
601            after_pass(None, timer);
602            spv_words
603        };
604        // FIXME(eddyb) dump both SPIR-T and `spv_words` if there's an error here.
605        output = {
606            let _timer = sess.timer("parse-spv_words-from-spirt");
607            crate::link::with_rspirv_loader(|loader| {
608                rspirv::binary::parse_words(&spv_words, loader)
609            })
610            .unwrap()
611        };
612    }
613
614    // Ensure that no references remain, to our custom "extended instruction set".
615    for inst in &output.ext_inst_imports {
616        assert_eq!(inst.class.opcode, Op::ExtInstImport);
617        let ext_inst_set = inst.operands[0].unwrap_literal_string();
618        if ext_inst_set.starts_with(custom_insts::CUSTOM_EXT_INST_SET_PREFIX) {
619            let expected = &custom_insts::CUSTOM_EXT_INST_SET[..];
620            if ext_inst_set == expected {
621                return Err(sess.dcx().err(format!(
622                    "`OpExtInstImport {ext_inst_set:?}` should not have been \
623                         left around after SPIR-T passes"
624                )));
625            } else {
626                return Err(sess.dcx().err(format!(
627                    "unsupported `OpExtInstImport {ext_inst_set:?}`
628                     (expected {expected:?} name - version mismatch?)"
629                )));
630            }
631        }
632    }
633
634    // FIXME(eddyb) rewrite these passes to SPIR-T ones, so we don't have to
635    // parse the output of `spirt::spv::lift` back into `rspirv` - also, for
636    // multi-module, it's much simpler with SPIR-T, just replace `module.exports`
637    // with a single-entry map, run `spirt::spv::lift` (or even `spirt::print`)
638    // on `module`, then put back the full original `module.exports` map.
639    {
640        let _timer = sess.timer("peephole_opts");
641        let types = peephole_opts::collect_types(&output);
642        for func in &mut output.functions {
643            peephole_opts::composite_construct(&types, func);
644            peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
645            peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
646        }
647    }
648
649    {
650        let _timer = sess.timer("link_remove_unused_type_capabilities");
651        simple_passes::remove_unused_type_capabilities(&mut output);
652    }
653
654    {
655        let _timer = sess.timer("link_gather_all_interface_vars_from_uses");
656        entry_interface::gather_all_interface_vars_from_uses(&mut output);
657    }
658
659    if opts.spirv_metadata == SpirvMetadata::NameVariables {
660        let _timer = sess.timer("link_name_variables");
661        simple_passes::name_variables_pass(&mut output);
662    }
663
664    {
665        let _timer = sess.timer("link_sort_globals");
666        simple_passes::sort_globals(&mut output);
667    }
668
669    let mut output = if opts.module_output_type == ModuleOutputType::Multiple {
670        let mut file_stem_to_entry_name_and_module = BTreeMap::new();
671        for (i, entry) in output.entry_points.iter().enumerate() {
672            let mut module = output.clone();
673            module.entry_points.clear();
674            module.entry_points.push(entry.clone());
675            let entry_name = entry.operands[2].unwrap_literal_string().to_string();
676            let mut file_stem = OsString::from(
677                sanitize_filename::sanitize_with_options(
678                    &entry_name,
679                    sanitize_filename::Options {
680                        replacement: "-",
681                        ..Default::default()
682                    },
683                )
684                .replace("--", "-"),
685            );
686            // It's always possible to find an unambiguous `file_stem`, but it
687            // may take two tries (or more, in bizzare/adversarial cases).
688            let mut disambiguator = Some(i);
689            loop {
690                use std::collections::btree_map::Entry;
691                match file_stem_to_entry_name_and_module.entry(file_stem) {
692                    Entry::Vacant(entry) => {
693                        entry.insert((entry_name, module));
694                        break;
695                    }
696                    // FIXME(eddyb) false positive: `file_stem` was moved out of,
697                    // so assigning it is necessary, but clippy doesn't know that.
698                    #[allow(clippy::assigning_clones)]
699                    Entry::Occupied(entry) => {
700                        // FIXME(eddyb) there's no way to access the owned key
701                        // passed to `BTreeMap::entry` from `OccupiedEntry`.
702                        file_stem = entry.key().clone();
703                        file_stem.push(".");
704                        match disambiguator.take() {
705                            Some(d) => file_stem.push(d.to_string()),
706                            None => file_stem.push("next"),
707                        }
708                    }
709                }
710            }
711        }
712        LinkResult::MultipleModules {
713            file_stem_to_entry_name_and_module,
714        }
715    } else {
716        LinkResult::SingleModule(Box::new(output))
717    };
718
719    let output_module_iter = match &mut output {
720        LinkResult::SingleModule(m) => Either::Left(std::iter::once((None, &mut **m))),
721        LinkResult::MultipleModules {
722            file_stem_to_entry_name_and_module,
723        } => Either::Right(
724            file_stem_to_entry_name_and_module
725                .iter_mut()
726                .map(|(file_stem, (_, m))| (Some(file_stem), m)),
727        ),
728    };
729    for (file_stem, output) in output_module_iter {
730        // Run DCE again, even if module_output_type == ModuleOutputType::Multiple - the first DCE ran before
731        // structurization and mem2reg (for perf reasons), and mem2reg may remove references to
732        // invalid types, so we need to DCE again.
733        {
734            let _timer = sess.timer("link_dce-post-split");
735            dce::dce(output);
736        }
737
738        // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
739        if let Some(dir) = &opts.dump_post_split {
740            let mut file_name = disambiguated_crate_name_for_dumps.to_os_string();
741            if let Some(file_stem) = file_stem {
742                file_name.push(".");
743                file_name.push(file_stem);
744            }
745
746            dump_spv_and_spirt(output, dir.join(file_name));
747        }
748
749        {
750            let _timer = sess.timer("link_remove_duplicate_debuginfo");
751            duplicates::remove_duplicate_debuginfo(output);
752        }
753
754        if opts.compact_ids {
755            let _timer = sess.timer("link_compact_ids");
756            // compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43
757            output.header.as_mut().unwrap().bound = simple_passes::compact_ids(output);
758        };
759
760        // FIXME(eddyb) convert these into actual `OpLine`s with a SPIR-T pass,
761        // but that'd require keeping the modules in SPIR-T form (once lowered),
762        // and never loading them back into `rspirv` once lifted back to SPIR-V.
763        SrcLocDecoration::remove_all(output);
764
765        // FIXME(eddyb) might make more sense to rewrite these away on SPIR-T.
766        ZombieDecoration::remove_all(output);
767    }
768
769    Ok(output)
770}
771
772/// Helper for dumping SPIR-T on drop, which allows panics to also dump,
773/// not just successful compilation (i.e. via `--dump-spirt-passes`).
774struct SpirtDumpGuard<'a> {
775    sess: &'a Session,
776    linker_options: &'a Options,
777    outputs: &'a OutputFilenames,
778    disambiguated_crate_name_for_dumps: &'a OsStr,
779
780    module: &'a mut spirt::Module,
781    per_pass_module_for_dumping: Vec<(Cow<'static, str>, spirt::Module)>,
782    in_progress_pass_name: Cell<Option<&'static str>>,
783    any_spirt_bugs: bool,
784}
785
786impl Drop for SpirtDumpGuard<'_> {
787    fn drop(&mut self) {
788        if std::thread::panicking() {
789            self.any_spirt_bugs = true;
790
791            // HACK(eddyb) the active pass panicked, make sure to include its
792            // (potentially corrupted) state, which will hopefully be printed
793            // later below (with protection against panicking during printing).
794            if let Some(pass_name) = self.in_progress_pass_name.get() {
795                self.per_pass_module_for_dumping.push((
796                    format!("{pass_name} [PANICKED]").into(),
797                    self.module.clone(),
798                ));
799            }
800        }
801
802        let mut dump_spirt_file_path =
803            self.linker_options
804                .dump_spirt_passes
805                .as_ref()
806                .map(|dump_dir| {
807                    dump_dir
808                        .join(self.disambiguated_crate_name_for_dumps)
809                        .with_extension("spirt")
810                });
811
812        // FIXME(eddyb) this won't allow seeing the individual passes, but it's
813        // better than nothing (theoretically the whole "SPIR-T pipeline" could
814        // be put in a loop so that everything is redone with per-pass tracking,
815        // but that requires keeping around e.g. the initial SPIR-V for longer,
816        // and probably invoking the "SPIR-T pipeline" here, as looping is hard).
817        if self.any_spirt_bugs && dump_spirt_file_path.is_none() {
818            if self.per_pass_module_for_dumping.is_empty() {
819                self.per_pass_module_for_dumping
820                    .push(("".into(), self.module.clone()));
821            }
822            dump_spirt_file_path = Some(self.outputs.temp_path_for_diagnostic("spirt"));
823        }
824
825        let Some(dump_spirt_file_path) = &dump_spirt_file_path else {
826            return;
827        };
828
829        for (_, module) in &mut self.per_pass_module_for_dumping {
830            // FIXME(eddyb) consider catching panics in this?
831            self.linker_options.spirt_cleanup_for_dumping(module);
832        }
833
834        let cx = self.module.cx();
835        let versions = self
836            .per_pass_module_for_dumping
837            .iter()
838            .map(|(pass_name, module)| (format!("after {pass_name}"), module));
839
840        let mut panicked_printing_after_passes = None;
841        for truncate_version_count in (1..=versions.len()).rev() {
842            // FIXME(eddyb) tell the user to use `--dump-spirt-passes` if that
843            // wasn't active but a panic happens - on top of that, it may need
844            // quieting the panic handler, likely controlled by a `thread_local!`
845            // (while the panic handler is global), and that would also be useful
846            // for collecting a panic message (assuming any of this is worth it).
847            // HACK(eddyb) for now, keeping the panic handler works out, as the
848            // panic messages would at least be seen by the user.
849            let printed_or_panicked =
850                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
851                    let pretty = spirt::print::Plan::for_versions(
852                        &cx,
853                        versions.clone().take(truncate_version_count),
854                    )
855                    .pretty_print();
856
857                    // FIXME(eddyb) don't allocate whole `String`s here.
858                    std::fs::write(dump_spirt_file_path, pretty.to_string()).unwrap();
859                    std::fs::write(
860                        dump_spirt_file_path.with_extension("spirt.html"),
861                        pretty
862                            .render_to_html()
863                            .with_dark_mode_support()
864                            .to_html_doc(),
865                    )
866                    .unwrap();
867                }));
868            match printed_or_panicked {
869                Ok(()) => {
870                    if truncate_version_count != versions.len() {
871                        panicked_printing_after_passes = Some(
872                            self.per_pass_module_for_dumping[truncate_version_count..]
873                                .iter()
874                                .map(|(pass_name, _)| format!("`{pass_name}`"))
875                                .collect::<Vec<_>>()
876                                .join(", "),
877                        );
878                    }
879                    break;
880                }
881                Err(panic) => {
882                    if truncate_version_count == 1 {
883                        std::panic::resume_unwind(panic);
884                    }
885                }
886            }
887        }
888        if self.any_spirt_bugs || panicked_printing_after_passes.is_some() {
889            let mut note = self.sess.dcx().struct_note("SPIR-T bugs were encountered");
890            if let Some(pass_names) = panicked_printing_after_passes {
891                note.warn(format!(
892                    "SPIR-T pretty-printing panicked after: {pass_names}"
893                ));
894            }
895            note.help(format!(
896                "pretty-printed SPIR-T was saved to {}.html",
897                dump_spirt_file_path.display()
898            ));
899            if self.linker_options.dump_spirt_passes.is_none() {
900                note.help("re-run with `RUSTGPU_CODEGEN_ARGS=\"--dump-spirt-passes=$PWD\"` for more details");
901            }
902            note.note("pretty-printed SPIR-T is preferred when reporting Rust-GPU issues");
903            note.emit();
904        }
905    }
906}