rustc_codegen_spirv/linker/
mod.rs

1#[cfg(test)]
2mod test;
3
4pub(crate) mod dce;
5mod destructure_composites;
6mod duplicates;
7mod entry_interface;
8mod import_export_link;
9mod inline;
10mod ipo;
11mod mem2reg;
12mod param_weakening;
13mod peephole_opts;
14mod simple_passes;
15mod specializer;
16mod spirt_passes;
17mod zombies;
18
19use std::borrow::Cow;
20
21use crate::codegen_cx::{ModuleOutputType, SpirvMetadata};
22use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
23use crate::custom_insts;
24use either::Either;
25use rspirv::binary::Assemble;
26use rspirv::dr::{Block, Module, ModuleHeader, Operand};
27use rspirv::spirv::{Op, StorageClass, Word};
28use rustc_data_structures::fx::FxHashMap;
29use rustc_errors::ErrorGuaranteed;
30use rustc_session::Session;
31use rustc_session::config::OutputFilenames;
32use std::cell::Cell;
33use std::collections::BTreeMap;
34use std::ffi::{OsStr, OsString};
35use std::path::PathBuf;
36
/// Result type used throughout the linker: the error side is always
/// `ErrorGuaranteed` (imported from `rustc_errors`), which in `link` is only
/// ever produced via the session's diagnostic context (`sess.dcx()`).
pub type Result<T> = std::result::Result<T, ErrorGuaranteed>;
38
/// Configuration for `link`: selects which optional linker passes run, how the
/// final output is shaped, and which debugging dumps get written.
#[derive(Default)]
pub struct Options {
    /// When set, `link` runs `simple_passes::compact_ids` on every output
    /// module and stores the returned value as the module header's ID bound.
    pub compact_ids: bool,
    /// When set, `link` runs `zombies::report_zombies` early, right after
    /// `param_weakening::remove_unused_params`.
    pub early_report_zombies: bool,
    /// When set, `link` runs the `specializer` over the module, specializing
    /// `StorageClass::Generic` operands (falling back to `Function`).
    pub infer_storage_classes: bool,
    /// When set, `link` runs `spirt::passes::legalize::structurize_func_cfgs`
    /// as part of the SPIR-T pipeline.
    pub structurize: bool,
    /// Forwarded as-is to `entry_interface::gather_all_interface_vars_from_uses`.
    // NOTE(review): exact effect is defined by that pass, not visible here.
    pub preserve_bindings: bool,
    /// Names of extra SPIR-T passes, handed to `spirt_passes::run_func_passes`
    /// (which is only invoked when this list is non-empty).
    pub spirt_passes: Vec<String>,

    // NOTE(review): not read directly by `link`; presumably consumed by the
    // custom-abort conversion pass (which receives the whole `Options`) - confirm.
    pub abort_strategy: Option<String>,
    /// `ModuleOutputType::Multiple` makes `link` split its output into one
    /// module per entry point (see `LinkResult::MultipleModules`); any other
    /// value yields a single module.
    pub module_output_type: ModuleOutputType,

    /// `SpirvMetadata::NameVariables` makes `link` run
    /// `simple_passes::name_variables_pass` on the output.
    pub spirv_metadata: SpirvMetadata,

    /// Whether to preserve `LinkageAttributes "..." Export` decorations,
    /// even after resolving imports to exports.
    ///
    /// **Note**: currently only used for unit testing, and not exposed elsewhere.
    pub keep_link_exports: bool,

    // NOTE(eddyb) these are debugging options that used to be env vars
    // (for more information see `docs/src/codegen-args.md`).
    //
    // Each `dump_*` directory below (other than `dump_spirt_passes`) receives
    // `.spv`, `.spirt` and `.spirt.html` files named after the crate.
    /// Directory to dump the module into, right after merging all inputs.
    pub dump_post_merge: Option<PathBuf>,
    /// Directory to dump the module into, just before inlining.
    pub dump_pre_inline: Option<PathBuf>,
    /// Directory to dump the module into, after inlining (and the DCE following it).
    pub dump_post_inline: Option<PathBuf>,
    /// Directory to dump each output module into, after the (optional)
    /// per-entry-point split and the DCE following it.
    pub dump_post_split: Option<PathBuf>,
    /// Directory for per-pass SPIR-T dumps; when set, `link` keeps a snapshot
    /// of the SPIR-T module after each pass, for `SpirtDumpGuard` to write out.
    pub dump_spirt_passes: Option<PathBuf>,
    /// Convert custom debuginfo to its SPIR-V form before dumping SPIR-T
    /// (see `Options::spirt_cleanup_for_dumping`).
    pub spirt_strip_custom_debuginfo_from_dumps: bool,
    /// Keep embedded debug source file contents in SPIR-T dumps, instead of
    /// replacing them with a placeholder (see `Options::spirt_cleanup_for_dumping`).
    pub spirt_keep_debug_sources_in_dumps: bool,
    /// Also snapshot the freshly lowered (unstructured) SPIR-T module for
    /// per-pass dumps, even when `structurize` is enabled.
    pub spirt_keep_unstructured_cfg_in_dumps: bool,
    // NOTE(review): not read in this part of the file; presumably consumed by
    // `specializer` (which receives the whole `Options`) - confirm.
    pub specializer_dump_instances: Option<PathBuf>,
}
71
/// The output of `link`: either one module containing every entry point, or
/// one module per entry point (when `ModuleOutputType::Multiple` is requested).
pub enum LinkResult {
    /// A single linked module, holding all entry points.
    SingleModule(Box<Module>),
    /// One module per entry point, each being a copy of the linked module
    /// with its entry point list reduced to just that entry point.
    MultipleModules {
        /// The "file stem" key is computed from the "entry name" in the value
        /// (through `sanitize_filename`, replacing invalid chars with `-`),
        /// but it's used as the map key because it *has to* be unique, even if
        /// lossy sanitization could have erased distinctions between entry names.
        file_stem_to_entry_name_and_module: BTreeMap<OsString, (String, Module)>,
    },
}
82
83fn id(header: &mut ModuleHeader) -> Word {
84    let result = header.bound;
85    header.bound += 1;
86    result
87}
88
89fn apply_rewrite_rules<'a>(
90    rewrite_rules: &FxHashMap<Word, Word>,
91    blocks: impl IntoIterator<Item = &'a mut Block>,
92) {
93    let all_ids_mut = blocks
94        .into_iter()
95        .flat_map(|b| b.label.iter_mut().chain(b.instructions.iter_mut()))
96        .flat_map(|inst| {
97            inst.result_id
98                .iter_mut()
99                .chain(inst.result_type.iter_mut())
100                .chain(
101                    inst.operands
102                        .iter_mut()
103                        .filter_map(|op| op.id_ref_any_mut()),
104                )
105        });
106    for id in all_ids_mut {
107        if let Some(&rewrite) = rewrite_rules.get(id) {
108            *id = rewrite;
109        }
110    }
111}
112
113fn get_names(module: &Module) -> FxHashMap<Word, &str> {
114    let entry_names = module
115        .entry_points
116        .iter()
117        .filter(|i| i.class.opcode == Op::EntryPoint)
118        .map(|i| {
119            (
120                i.operands[1].unwrap_id_ref(),
121                i.operands[2].unwrap_literal_string(),
122            )
123        });
124    let debug_names = module
125        .debug_names
126        .iter()
127        .filter(|i| i.class.opcode == Op::Name)
128        .map(|i| {
129            (
130                i.operands[0].unwrap_id_ref(),
131                i.operands[1].unwrap_literal_string(),
132            )
133        });
134    // items later on take priority
135    entry_names.chain(debug_names).collect()
136}
137
138fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
139    names.get(&id).map_or_else(
140        || Cow::Owned(format!("Unnamed function ID %{id}")),
141        |&s| Cow::Borrowed(s),
142    )
143}
144
145impl Options {
146    // FIXME(eddyb) using a method on this type seems a bit sketchy.
147    fn spirt_cleanup_for_dumping(&self, module: &mut spirt::Module) {
148        if self.spirt_strip_custom_debuginfo_from_dumps {
149            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
150        }
151        if !self.spirt_keep_debug_sources_in_dumps {
152            const DOTS: &str = "⋯";
153            let dots_interned_str = module.cx().intern(DOTS);
154            let spirt::ModuleDebugInfo::Spv(debuginfo) = &mut module.debug_info;
155            for sources in debuginfo.source_languages.values_mut() {
156                for file in sources.file_contents.values_mut() {
157                    *file = DOTS.into();
158                }
159                sources.file_contents.insert(
160                    dots_interned_str,
161                    "sources hidden, to show them use \
162                     `RUSTGPU_CODEGEN_ARGS=--spirt-keep-debug-sources-in-dumps`"
163                        .into(),
164                );
165            }
166        }
167    }
168}
169
170pub fn link(
171    sess: &Session,
172    mut inputs: Vec<Module>,
173    opts: &Options,
174    outputs: &OutputFilenames,
175    disambiguated_crate_name_for_dumps: &OsStr,
176) -> Result<LinkResult> {
177    // HACK(eddyb) this is defined here to allow SPIR-T pretty-printing to apply
178    // to SPIR-V being dumped, outside of e.g. `--dump-spirt-passes`.
179    // FIXME(eddyb) this isn't used everywhere, sadly - to find those, search
180    // elsewhere for `.assemble()` and/or `spirv_tools::binary::from_binary`.
181    let spv_module_to_spv_words_and_spirt_module = |spv_module: &Module| {
182        let spv_words;
183        let spv_bytes = {
184            let _timer = sess.timer("assemble-to-spv_bytes-for-spirt");
185            spv_words = spv_module.assemble();
186            // FIXME(eddyb) this is wastefully cloning all the bytes, but also
187            // `spirt::Module` should have a method that takes `Vec<u32>`.
188            spirv_tools::binary::from_binary(&spv_words).to_vec()
189        };
190
191        // FIXME(eddyb) should've really been "spirt::Module::lower_from_spv_bytes".
192        let lower_from_spv_timer = sess.timer("spirt::Module::lower_from_spv_file");
193        let cx = std::rc::Rc::new(spirt::Context::new());
194        crate::custom_insts::register_to_spirt_context(&cx);
195        (
196            spv_words,
197            spirt::Module::lower_from_spv_bytes(cx, spv_bytes),
198            // HACK(eddyb) this is only returned for `SpirtDumpGuard`.
199            lower_from_spv_timer,
200        )
201    };
202
203    // FIXME(eddyb) deduplicate with `SpirtDumpGuard`.
204    let dump_spv_and_spirt = |spv_module: &Module, dump_file_path_stem: PathBuf| {
205        let (spv_words, spirt_module_or_err, _) =
206            spv_module_to_spv_words_and_spirt_module(spv_module);
207        std::fs::write(
208            dump_file_path_stem.with_extension("spv"),
209            spirv_tools::binary::from_binary(&spv_words),
210        )
211        .unwrap();
212
213        // FIXME(eddyb) reify SPIR-V -> SPIR-T errors so they're easier to debug.
214        if let Ok(mut module) = spirt_module_or_err {
215            // HACK(eddyb) avoid pretty-printing massive amounts of unused SPIR-T.
216            spirt::passes::link::minimize_exports(&mut module, |export_key| {
217                matches!(export_key, spirt::ExportKey::SpvEntryPoint { .. })
218            });
219
220            opts.spirt_cleanup_for_dumping(&mut module);
221
222            let pretty = spirt::print::Plan::for_module(&module).pretty_print();
223
224            // FIXME(eddyb) don't allocate whole `String`s here.
225            std::fs::write(
226                dump_file_path_stem.with_extension("spirt"),
227                pretty.to_string(),
228            )
229            .unwrap();
230            std::fs::write(
231                dump_file_path_stem.with_extension("spirt.html"),
232                pretty
233                    .render_to_html()
234                    .with_dark_mode_support()
235                    .to_html_doc(),
236            )
237            .unwrap();
238        }
239    };
240
241    let mut output = {
242        let _timer = sess.timer("link_merge");
243        // shift all the ids
244        let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
245        let version = inputs[0].header.as_ref().unwrap().version();
246
247        for module in inputs.iter_mut().skip(1) {
248            simple_passes::shift_ids(module, bound);
249            bound += module.header.as_ref().unwrap().bound - 1;
250            let this_version = module.header.as_ref().unwrap().version();
251            if version != this_version {
252                return Err(sess.dcx().err(format!(
253                    "cannot link two modules with different SPIR-V versions: v{}.{} and v{}.{}",
254                    version.0, version.1, this_version.0, this_version.1
255                )));
256            }
257        }
258
259        // merge the binaries
260        let mut output = crate::link::with_rspirv_loader(|loader| {
261            for module in inputs {
262                for inst in module.all_inst_iter() {
263                    use rspirv::binary::ParseAction;
264                    match loader.consume_instruction(inst.clone()) {
265                        ParseAction::Continue => {}
266                        ParseAction::Stop => unreachable!(),
267                        ParseAction::Error(err) => return Err(err),
268                    }
269                }
270            }
271            Ok(())
272        })
273        .unwrap();
274
275        let mut header = ModuleHeader::new(bound + 1);
276        header.set_version(version.0, version.1);
277        header.generator = 0x001B_0000;
278        output.header = Some(header);
279        output
280    };
281
282    if let Some(dir) = &opts.dump_post_merge {
283        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
284    }
285
286    // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
287    {
288        let _timer = sess.timer("link_remove_duplicates");
289        duplicates::remove_duplicate_extensions(&mut output);
290        duplicates::remove_duplicate_capabilities(&mut output);
291        duplicates::remove_duplicate_ext_inst_imports(&mut output);
292        duplicates::remove_duplicate_types(&mut output);
293        // jb-todo: strip identical OpDecoration / OpDecorationGroups
294    }
295
296    // find import / export pairs
297    {
298        let _timer = sess.timer("link_find_pairs");
299        import_export_link::run(opts, sess, &mut output)?;
300    }
301
302    {
303        let _timer = sess.timer("link_dce-post-link");
304        dce::dce(&mut output);
305    }
306
307    {
308        let _timer = sess.timer("link_fragment_inst_check");
309        simple_passes::check_fragment_insts(sess, &output)?;
310    }
311
312    // HACK(eddyb) this has to run before the `report_zombies` pass, so that
313    // any zombies that are passed as call arguments, but eventually unused,
314    // won't be (incorrectly) considered used.
315    {
316        let _timer = sess.timer("link_remove_unused_params");
317        output = param_weakening::remove_unused_params(output);
318    }
319
320    if opts.early_report_zombies {
321        let _timer = sess.timer("link_report_zombies");
322        zombies::report_zombies(sess, &output)?;
323    }
324
325    if opts.infer_storage_classes {
326        let _timer = sess.timer("specialize_generic_storage_class");
327        // HACK(eddyb) `specializer` requires functions' blocks to be in RPO order
328        // (i.e. `block_ordering_pass`) - this could be relaxed by using RPO visit
329        // inside `specializer`, but this is easier.
330        for func in &mut output.functions {
331            simple_passes::block_ordering_pass(func);
332        }
333        output = specializer::specialize(
334            opts,
335            output,
336            specializer::SimpleSpecialization {
337                specialize_operand: |operand| {
338                    matches!(operand, Operand::StorageClass(StorageClass::Generic))
339                },
340
341                // NOTE(eddyb) this can be anything that is guaranteed to pass
342                // validation - there are no constraints so this is either some
343                // unused pointer, or perhaps one created using `OpConstantNull`
344                // and simply never mixed with pointers that have a storage class.
345                // It would be nice to use `Generic` itself here so that we leave
346                // some kind of indication of it being unconstrained, but `Generic`
347                // requires additional capabilities, so we use `Function` instead.
348                // TODO(eddyb) investigate whether this can end up in a pointer
349                // type that's the value of a module-scoped variable, and whether
350                // `Function` is actually invalid! (may need `Private`)
351                concrete_fallback: Operand::StorageClass(StorageClass::Function),
352            },
353        );
354    }
355
356    // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too!
357    {
358        {
359            let _timer = sess.timer("link_dce-before-inlining");
360            dce::dce(&mut output);
361        }
362
363        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-before-inlining");
364        let mut pointer_to_pointee = FxHashMap::default();
365        let mut constants = FxHashMap::default();
366        let mut u32 = None;
367        for inst in &output.types_global_values {
368            match inst.class.opcode {
369                Op::TypePointer => {
370                    pointer_to_pointee
371                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
372                }
373                Op::TypeInt
374                    if inst.operands[0].unwrap_literal_bit32() == 32
375                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
376                {
377                    assert!(u32.is_none());
378                    u32 = Some(inst.result_id.unwrap());
379                }
380                Op::Constant if u32.is_some() && inst.result_type == u32 => {
381                    let value = inst.operands[0].unwrap_literal_bit32();
382                    constants.insert(inst.result_id.unwrap(), value);
383                }
384                _ => {}
385            }
386        }
387        for func in &mut output.functions {
388            simple_passes::block_ordering_pass(func);
389            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
390            mem2reg::mem2reg(
391                output.header.as_mut().unwrap(),
392                &mut output.types_global_values,
393                &pointer_to_pointee,
394                &constants,
395                func,
396            );
397            destructure_composites::destructure_composites(func);
398        }
399    }
400
401    {
402        let _timer =
403            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-before-inlining");
404        dce::dce(&mut output);
405        duplicates::remove_duplicate_debuginfo(&mut output);
406    }
407
408    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
409    if let Some(dir) = &opts.dump_pre_inline {
410        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
411    }
412
413    {
414        let _timer = sess.timer("link_inline");
415        inline::inline(sess, &mut output)?;
416    }
417
418    {
419        let _timer = sess.timer("link_dce-after-inlining");
420        dce::dce(&mut output);
421    }
422
423    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
424    if let Some(dir) = &opts.dump_post_inline {
425        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
426    }
427
428    {
429        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-after-inlining");
430        let mut pointer_to_pointee = FxHashMap::default();
431        let mut constants = FxHashMap::default();
432        let mut u32 = None;
433        for inst in &output.types_global_values {
434            match inst.class.opcode {
435                Op::TypePointer => {
436                    pointer_to_pointee
437                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
438                }
439                Op::TypeInt
440                    if inst.operands[0].unwrap_literal_bit32() == 32
441                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
442                {
443                    assert!(u32.is_none());
444                    u32 = Some(inst.result_id.unwrap());
445                }
446                Op::Constant if u32.is_some() && inst.result_type == u32 => {
447                    let value = inst.operands[0].unwrap_literal_bit32();
448                    constants.insert(inst.result_id.unwrap(), value);
449                }
450                _ => {}
451            }
452        }
453        for func in &mut output.functions {
454            simple_passes::block_ordering_pass(func);
455            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
456            mem2reg::mem2reg(
457                output.header.as_mut().unwrap(),
458                &mut output.types_global_values,
459                &pointer_to_pointee,
460                &constants,
461                func,
462            );
463            destructure_composites::destructure_composites(func);
464        }
465    }
466
467    {
468        let _timer =
469            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-after-inlining");
470        dce::dce(&mut output);
471        duplicates::remove_duplicate_debuginfo(&mut output);
472    }
473
474    {
475        let _timer = sess.timer("link_remove_non_uniform");
476        simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
477    }
478
479    // NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
480    {
481        let (spv_words, module_or_err, lower_from_spv_timer) =
482            spv_module_to_spv_words_and_spirt_module(&output);
483        let module = &mut module_or_err.map_err(|e| {
484            let spv_path = outputs.temp_path_for_diagnostic("spirt-lower-from-spv-input.spv");
485
486            let was_saved_msg =
487                match std::fs::write(&spv_path, spirv_tools::binary::from_binary(&spv_words)) {
488                    Ok(()) => format!("was saved to {}", spv_path.display()),
489                    Err(e) => format!("could not be saved: {e}"),
490                };
491
492            sess.dcx()
493                .struct_err(format!("{e}"))
494                .with_note("while lowering SPIR-V module to SPIR-T (spirt::spv::lower)")
495                .with_note(format!("input SPIR-V module {was_saved_msg}"))
496                .emit()
497        })?;
498
499        let mut dump_guard = SpirtDumpGuard {
500            sess,
501            linker_options: opts,
502            outputs,
503            disambiguated_crate_name_for_dumps,
504
505            module,
506            per_pass_module_for_dumping: vec![],
507            in_progress_pass_name: Cell::new(Some("lower_from_spv")),
508            any_spirt_bugs: false,
509        };
510        let module = &mut *dump_guard.module;
511        // FIXME(eddyb) consider returning a `Drop`-implementing type instead?
512        let before_pass = |pass_name| {
513            let outer_pass_name = dump_guard.in_progress_pass_name.replace(Some(pass_name));
514
515            // FIXME(eddyb) could it make sense to allow these to nest?
516            assert_eq!(outer_pass_name, None);
517
518            sess.timer(pass_name)
519        };
520        let mut after_pass = |module: Option<&spirt::Module>, timer| {
521            drop(timer);
522            let pass_name = dump_guard.in_progress_pass_name.take().unwrap();
523            if let Some(module) = module
524                && opts.dump_spirt_passes.is_some()
525            {
526                dump_guard
527                    .per_pass_module_for_dumping
528                    .push((pass_name.into(), module.clone()));
529            }
530        };
531        // HACK(eddyb) don't dump the unstructured state if not requested, as
532        // after SPIR-T 0.4.0 it's extremely verbose (due to def-use hermeticity).
533        after_pass(
534            (opts.spirt_keep_unstructured_cfg_in_dumps || !opts.structurize).then_some(&*module),
535            lower_from_spv_timer,
536        );
537
538        // NOTE(eddyb) this *must* run on unstructured CFGs, to do its job.
539        // FIXME(eddyb) no longer relying on structurization, try porting this
540        // to replace custom aborts in `Block`s and inject `ExitInvocation`s
541        // after them (truncating the `Block` and/or parent region if necessary).
542        {
543            let timer = before_pass(
544                "spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points",
545            );
546            spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points(opts, module);
547            after_pass(None, timer);
548        }
549
550        if opts.structurize {
551            let timer = before_pass("spirt::legalize::structurize_func_cfgs");
552            spirt::passes::legalize::structurize_func_cfgs(module);
553            after_pass(Some(module), timer);
554        }
555
556        if !opts.spirt_passes.is_empty() {
557            // FIXME(eddyb) why does this focus on functions, it could just be module passes??
558            spirt_passes::run_func_passes(
559                module,
560                &opts.spirt_passes,
561                |name, _module| before_pass(name),
562                &mut after_pass,
563            );
564        }
565
566        {
567            let timer = before_pass("spirt_passes::explicit_layout::erase_when_invalid");
568            spirt_passes::explicit_layout::erase_when_invalid(module);
569            after_pass(Some(module), timer);
570        }
571
572        {
573            let timer = before_pass("spirt_passes::validate");
574            spirt_passes::validate::validate(module);
575            after_pass(Some(module), timer);
576        }
577
578        {
579            let timer = before_pass("spirt_passes::diagnostics::report_diagnostics");
580            spirt_passes::diagnostics::report_diagnostics(sess, opts, module).map_err(
581                |spirt_passes::diagnostics::ReportedDiagnostics {
582                     rustc_errors_guarantee,
583                     any_errors_were_spirt_bugs,
584                 }| {
585                    dump_guard.any_spirt_bugs |= any_errors_were_spirt_bugs;
586                    rustc_errors_guarantee
587                },
588            )?;
589            after_pass(None, timer);
590        }
591
592        // Replace our custom debuginfo instructions just before lifting to SPIR-V.
593        {
594            let timer = before_pass("spirt_passes::debuginfo::convert_custom_debuginfo_to_spv");
595            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
596            after_pass(None, timer);
597        }
598
599        let spv_words = {
600            let timer = before_pass("spirt::Module::lift_to_spv_module_emitter");
601            let spv_words = module.lift_to_spv_module_emitter().unwrap().words;
602            after_pass(None, timer);
603            spv_words
604        };
605        // FIXME(eddyb) dump both SPIR-T and `spv_words` if there's an error here.
606        output = {
607            let _timer = sess.timer("parse-spv_words-from-spirt");
608            crate::link::with_rspirv_loader(|loader| {
609                rspirv::binary::parse_words(&spv_words, loader)
610            })
611            .unwrap()
612        };
613    }
614
615    // Ensure that no references remain, to our custom "extended instruction set".
616    for inst in &output.ext_inst_imports {
617        assert_eq!(inst.class.opcode, Op::ExtInstImport);
618        let ext_inst_set = inst.operands[0].unwrap_literal_string();
619        if ext_inst_set.starts_with(custom_insts::CUSTOM_EXT_INST_SET_PREFIX) {
620            let expected = &custom_insts::CUSTOM_EXT_INST_SET[..];
621            if ext_inst_set == expected {
622                return Err(sess.dcx().err(format!(
623                    "`OpExtInstImport {ext_inst_set:?}` should not have been \
624                         left around after SPIR-T passes"
625                )));
626            } else {
627                return Err(sess.dcx().err(format!(
628                    "unsupported `OpExtInstImport {ext_inst_set:?}`
629                     (expected {expected:?} name - version mismatch?)"
630                )));
631            }
632        }
633    }
634
635    // FIXME(eddyb) rewrite these passes to SPIR-T ones, so we don't have to
636    // parse the output of `spirt::spv::lift` back into `rspirv` - also, for
637    // multi-module, it's much simpler with SPIR-T, just replace `module.exports`
638    // with a single-entry map, run `spirt::spv::lift` (or even `spirt::print`)
639    // on `module`, then put back the full original `module.exports` map.
640    {
641        let _timer = sess.timer("peephole_opts");
642        let types = peephole_opts::collect_types(&output);
643        for func in &mut output.functions {
644            peephole_opts::composite_construct(&types, func);
645            peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
646            peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
647        }
648    }
649
650    {
651        let _timer = sess.timer("link_remove_unused_type_capabilities");
652        simple_passes::remove_unused_type_capabilities(&mut output);
653    }
654
655    {
656        let _timer = sess.timer("link_gather_all_interface_vars_from_uses");
657        entry_interface::gather_all_interface_vars_from_uses(&mut output, opts.preserve_bindings);
658    }
659
660    if opts.spirv_metadata == SpirvMetadata::NameVariables {
661        let _timer = sess.timer("link_name_variables");
662        simple_passes::name_variables_pass(&mut output);
663    }
664
665    {
666        let _timer = sess.timer("link_sort_globals");
667        simple_passes::sort_globals(&mut output);
668    }
669
670    let mut output = if opts.module_output_type == ModuleOutputType::Multiple {
671        let mut file_stem_to_entry_name_and_module = BTreeMap::new();
672        for (i, entry) in output.entry_points.iter().enumerate() {
673            let mut module = output.clone();
674            module.entry_points.clear();
675            module.entry_points.push(entry.clone());
676            let entry_name = entry.operands[2].unwrap_literal_string().to_string();
677            let mut file_stem = OsString::from(
678                sanitize_filename::sanitize_with_options(
679                    &entry_name,
680                    sanitize_filename::Options {
681                        replacement: "-",
682                        ..Default::default()
683                    },
684                )
685                .replace("--", "-"),
686            );
687            // It's always possible to find an unambiguous `file_stem`, but it
688            // may take two tries (or more, in bizzare/adversarial cases).
689            let mut disambiguator = Some(i);
690            loop {
691                use std::collections::btree_map::Entry;
692                match file_stem_to_entry_name_and_module.entry(file_stem) {
693                    Entry::Vacant(entry) => {
694                        entry.insert((entry_name, module));
695                        break;
696                    }
697                    // FIXME(eddyb) false positive: `file_stem` was moved out of,
698                    // so assigning it is necessary, but clippy doesn't know that.
699                    #[allow(clippy::assigning_clones)]
700                    Entry::Occupied(entry) => {
701                        // FIXME(eddyb) there's no way to access the owned key
702                        // passed to `BTreeMap::entry` from `OccupiedEntry`.
703                        file_stem = entry.key().clone();
704                        file_stem.push(".");
705                        match disambiguator.take() {
706                            Some(d) => file_stem.push(d.to_string()),
707                            None => file_stem.push("next"),
708                        }
709                    }
710                }
711            }
712        }
713        LinkResult::MultipleModules {
714            file_stem_to_entry_name_and_module,
715        }
716    } else {
717        LinkResult::SingleModule(Box::new(output))
718    };
719
720    let output_module_iter = match &mut output {
721        LinkResult::SingleModule(m) => Either::Left(std::iter::once((None, &mut **m))),
722        LinkResult::MultipleModules {
723            file_stem_to_entry_name_and_module,
724        } => Either::Right(
725            file_stem_to_entry_name_and_module
726                .iter_mut()
727                .map(|(file_stem, (_, m))| (Some(file_stem), m)),
728        ),
729    };
730    for (file_stem, output) in output_module_iter {
731        // Run DCE again, even if module_output_type == ModuleOutputType::Multiple - the first DCE ran before
732        // structurization and mem2reg (for perf reasons), and mem2reg may remove references to
733        // invalid types, so we need to DCE again.
734        {
735            let _timer = sess.timer("link_dce-post-split");
736            dce::dce(output);
737        }
738
739        // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
740        if let Some(dir) = &opts.dump_post_split {
741            let mut file_name = disambiguated_crate_name_for_dumps.to_os_string();
742            if let Some(file_stem) = file_stem {
743                file_name.push(".");
744                file_name.push(file_stem);
745            }
746
747            dump_spv_and_spirt(output, dir.join(file_name));
748        }
749
750        {
751            let _timer = sess.timer("link_remove_duplicate_debuginfo");
752            duplicates::remove_duplicate_debuginfo(output);
753        }
754
755        if opts.compact_ids {
756            let _timer = sess.timer("link_compact_ids");
757            // compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43
758            output.header.as_mut().unwrap().bound = simple_passes::compact_ids(output);
759        };
760
761        // FIXME(eddyb) convert these into actual `OpLine`s with a SPIR-T pass,
762        // but that'd require keeping the modules in SPIR-T form (once lowered),
763        // and never loading them back into `rspirv` once lifted back to SPIR-V.
764        SrcLocDecoration::remove_all(output);
765
766        // FIXME(eddyb) might make more sense to rewrite these away on SPIR-T.
767        ZombieDecoration::remove_all(output);
768    }
769
770    Ok(output)
771}
772
773/// Helper for dumping SPIR-T on drop, which allows panics to also dump,
774/// not just successful compilation (i.e. via `--dump-spirt-passes`).
775struct SpirtDumpGuard<'a> {
776    sess: &'a Session,
777    linker_options: &'a Options,
778    outputs: &'a OutputFilenames,
779    disambiguated_crate_name_for_dumps: &'a OsStr,
780
781    module: &'a mut spirt::Module,
782    per_pass_module_for_dumping: Vec<(Cow<'static, str>, spirt::Module)>,
783    in_progress_pass_name: Cell<Option<&'static str>>,
784    any_spirt_bugs: bool,
785}
786
787impl Drop for SpirtDumpGuard<'_> {
788    fn drop(&mut self) {
789        if std::thread::panicking() {
790            self.any_spirt_bugs = true;
791
792            // HACK(eddyb) the active pass panicked, make sure to include its
793            // (potentially corrupted) state, which will hopefully be printed
794            // later below (with protection against panicking during printing).
795            if let Some(pass_name) = self.in_progress_pass_name.get() {
796                self.per_pass_module_for_dumping.push((
797                    format!("{pass_name} [PANICKED]").into(),
798                    self.module.clone(),
799                ));
800            }
801        }
802
803        let mut dump_spirt_file_path =
804            self.linker_options
805                .dump_spirt_passes
806                .as_ref()
807                .map(|dump_dir| {
808                    dump_dir
809                        .join(self.disambiguated_crate_name_for_dumps)
810                        .with_extension("spirt")
811                });
812
813        // FIXME(eddyb) this won't allow seeing the individual passes, but it's
814        // better than nothing (theoretically the whole "SPIR-T pipeline" could
815        // be put in a loop so that everything is redone with per-pass tracking,
816        // but that requires keeping around e.g. the initial SPIR-V for longer,
817        // and probably invoking the "SPIR-T pipeline" here, as looping is hard).
818        if self.any_spirt_bugs && dump_spirt_file_path.is_none() {
819            if self.per_pass_module_for_dumping.is_empty() {
820                self.per_pass_module_for_dumping
821                    .push(("".into(), self.module.clone()));
822            }
823            dump_spirt_file_path = Some(self.outputs.temp_path_for_diagnostic("spirt"));
824        }
825
826        let Some(dump_spirt_file_path) = &dump_spirt_file_path else {
827            return;
828        };
829
830        for (_, module) in &mut self.per_pass_module_for_dumping {
831            // FIXME(eddyb) consider catching panics in this?
832            self.linker_options.spirt_cleanup_for_dumping(module);
833        }
834
835        let cx = self.module.cx();
836        let versions = self
837            .per_pass_module_for_dumping
838            .iter()
839            .map(|(pass_name, module)| (format!("after {pass_name}"), module));
840
841        let mut panicked_printing_after_passes = None;
842        for truncate_version_count in (1..=versions.len()).rev() {
843            // FIXME(eddyb) tell the user to use `--dump-spirt-passes` if that
844            // wasn't active but a panic happens - on top of that, it may need
845            // quieting the panic handler, likely controlled by a `thread_local!`
846            // (while the panic handler is global), and that would also be useful
847            // for collecting a panic message (assuming any of this is worth it).
848            // HACK(eddyb) for now, keeping the panic handler works out, as the
849            // panic messages would at least be seen by the user.
850            let printed_or_panicked =
851                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
852                    let pretty = spirt::print::Plan::for_versions(
853                        &cx,
854                        versions.clone().take(truncate_version_count),
855                    )
856                    .pretty_print();
857
858                    // FIXME(eddyb) don't allocate whole `String`s here.
859                    std::fs::write(dump_spirt_file_path, pretty.to_string()).unwrap();
860                    std::fs::write(
861                        dump_spirt_file_path.with_extension("spirt.html"),
862                        pretty
863                            .render_to_html()
864                            .with_dark_mode_support()
865                            .to_html_doc(),
866                    )
867                    .unwrap();
868                }));
869            match printed_or_panicked {
870                Ok(()) => {
871                    if truncate_version_count != versions.len() {
872                        panicked_printing_after_passes = Some(
873                            self.per_pass_module_for_dumping[truncate_version_count..]
874                                .iter()
875                                .map(|(pass_name, _)| format!("`{pass_name}`"))
876                                .collect::<Vec<_>>()
877                                .join(", "),
878                        );
879                    }
880                    break;
881                }
882                Err(panic) => {
883                    if truncate_version_count == 1 {
884                        std::panic::resume_unwind(panic);
885                    }
886                }
887            }
888        }
889        if self.any_spirt_bugs || panicked_printing_after_passes.is_some() {
890            let mut note = self.sess.dcx().struct_note("SPIR-T bugs were encountered");
891            if let Some(pass_names) = panicked_printing_after_passes {
892                note.warn(format!(
893                    "SPIR-T pretty-printing panicked after: {pass_names}"
894                ));
895            }
896            note.help(format!(
897                "pretty-printed SPIR-T was saved to {}.html",
898                dump_spirt_file_path.display()
899            ));
900            if self.linker_options.dump_spirt_passes.is_none() {
901                note.help("re-run with `RUSTGPU_CODEGEN_ARGS=\"--dump-spirt-passes=$PWD\"` for more details");
902            }
903            note.note("pretty-printed SPIR-T is preferred when reporting Rust-GPU issues");
904            note.emit();
905        }
906    }
907}