rustc_codegen_spirv/linker/
mod.rs

1#[cfg(test)]
2mod test;
3
4pub(crate) mod dce;
5mod destructure_composites;
6mod duplicates;
7mod entry_interface;
8mod import_export_link;
9mod inline;
10mod ipo;
11mod mem2reg;
12mod param_weakening;
13mod peephole_opts;
14mod simple_passes;
15mod specializer;
16mod spirt_passes;
17mod zombies;
18
19use std::borrow::Cow;
20
21use crate::codegen_cx::{ModuleOutputType, SpirvMetadata};
22use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
23use crate::custom_insts;
24use either::Either;
25use rspirv::binary::Assemble;
26use rspirv::dr::{Block, Module, ModuleHeader, Operand};
27use rspirv::spirv::{Op, StorageClass, Word};
28use rustc_data_structures::fx::FxHashMap;
29use rustc_errors::ErrorGuaranteed;
30use rustc_session::Session;
31use rustc_session::config::OutputFilenames;
32use std::cell::Cell;
33use std::collections::BTreeMap;
34use std::ffi::{OsStr, OsString};
35use std::path::PathBuf;
36
37pub type Result<T> = std::result::Result<T, ErrorGuaranteed>;
38
/// Configuration for [`link`]: toggles for optional passes, the shape of the
/// output (one module vs. one per entry point), and debugging dumps.
#[derive(Default)]
pub struct Options {
    /// When set, `link` finishes each output module by running
    /// `simple_passes::compact_ids` and storing its result as the header's ID bound.
    pub compact_ids: bool,
    /// When set, `link` runs `zombies::report_zombies` early, right after
    /// unused function parameters have been removed.
    pub early_report_zombies: bool,
    /// When set, `link` runs the `specializer` to replace `StorageClass::Generic`
    /// operands with concrete storage classes.
    pub infer_storage_classes: bool,
    /// When set, `link` runs `spirt::passes::legalize::structurize_func_cfgs`
    /// on the SPIR-T module.
    pub structurize: bool,
    /// Names of extra SPIR-T passes, handed to `spirt_passes::run_func_passes`
    /// (skipped entirely when empty).
    pub spirt_passes: Vec<String>,

    // NOTE(review): not read directly by `link`; presumably consumed by the
    // custom-abort conversion pass (which receives `&Options`) - confirm there.
    pub abort_strategy: Option<String>,
    /// `ModuleOutputType::Multiple` makes `link` split the result into one
    /// module per entry point; anything else yields a single module.
    pub module_output_type: ModuleOutputType,

    /// `SpirvMetadata::NameVariables` makes `link` run
    /// `simple_passes::name_variables_pass` on the output.
    pub spirv_metadata: SpirvMetadata,

    /// Whether to preserve `LinkageAttributes "..." Export` decorations,
    /// even after resolving imports to exports.
    ///
    /// **Note**: currently only used for unit testing, and not exposed elsewhere.
    pub keep_link_exports: bool,

    // NOTE(eddyb) these are debugging options that used to be env vars
    // (for more information see `docs/src/codegen-args.md`).
    /// Directory to dump the module into (as `.spv`/`.spirt`/`.spirt.html`),
    /// right after all inputs have been merged.
    pub dump_post_merge: Option<PathBuf>,
    /// Directory to dump the module into, just before inlining.
    pub dump_pre_inline: Option<PathBuf>,
    /// Directory to dump the module into, after inlining (and the DCE following it).
    pub dump_post_inline: Option<PathBuf>,
    /// Directory to dump each final output module into, after the post-split DCE.
    pub dump_post_split: Option<PathBuf>,
    /// Directory for the per-pass SPIR-T dump; when set, `link` keeps a clone
    /// of the SPIR-T module after passes that report one.
    pub dump_spirt_passes: Option<PathBuf>,
    /// When set, dumps have custom debuginfo converted to plain SPIR-V debuginfo
    /// first (see `spirt_cleanup_for_dumping`).
    pub spirt_strip_custom_debuginfo_from_dumps: bool,
    /// Unless set, embedded debug source file contents are replaced with a
    /// placeholder in dumps (see `spirt_cleanup_for_dumping`).
    pub spirt_keep_debug_sources_in_dumps: bool,
    /// When set (or when `structurize` is off), the freshly lowered,
    /// still unstructured SPIR-T module is also kept for `dump_spirt_passes`.
    pub spirt_keep_unstructured_cfg_in_dumps: bool,
    // NOTE(review): not read directly by `link`; presumably consumed by
    // `specializer::specialize` (which receives `&Options`) - confirm there.
    pub specializer_dump_instances: Option<PathBuf>,
}
70
/// Successful output of [`link`]: either everything in one module, or one
/// module per entry point.
pub enum LinkResult {
    /// The whole linked program as a single module (produced whenever
    /// `Options::module_output_type` is not `ModuleOutputType::Multiple`).
    SingleModule(Box<Module>),
    /// One module per entry point; each starts as a copy of the linked module
    /// with its `entry_points` reduced to just that one entry.
    MultipleModules {
        /// The "file stem" key is computed from the "entry name" in the value
        /// (through `sanitize_filename`, replacing invalid chars with `-`),
        /// but it's used as the map key because it *has to* be unique, even if
        /// lossy sanitization could have erased distinctions between entry names.
        file_stem_to_entry_name_and_module: BTreeMap<OsString, (String, Module)>,
    },
}
81
82fn id(header: &mut ModuleHeader) -> Word {
83    let result = header.bound;
84    header.bound += 1;
85    result
86}
87
88fn apply_rewrite_rules<'a>(
89    rewrite_rules: &FxHashMap<Word, Word>,
90    blocks: impl IntoIterator<Item = &'a mut Block>,
91) {
92    let all_ids_mut = blocks
93        .into_iter()
94        .flat_map(|b| b.label.iter_mut().chain(b.instructions.iter_mut()))
95        .flat_map(|inst| {
96            inst.result_id
97                .iter_mut()
98                .chain(inst.result_type.iter_mut())
99                .chain(
100                    inst.operands
101                        .iter_mut()
102                        .filter_map(|op| op.id_ref_any_mut()),
103                )
104        });
105    for id in all_ids_mut {
106        if let Some(&rewrite) = rewrite_rules.get(id) {
107            *id = rewrite;
108        }
109    }
110}
111
112fn get_names(module: &Module) -> FxHashMap<Word, &str> {
113    let entry_names = module
114        .entry_points
115        .iter()
116        .filter(|i| i.class.opcode == Op::EntryPoint)
117        .map(|i| {
118            (
119                i.operands[1].unwrap_id_ref(),
120                i.operands[2].unwrap_literal_string(),
121            )
122        });
123    let debug_names = module
124        .debug_names
125        .iter()
126        .filter(|i| i.class.opcode == Op::Name)
127        .map(|i| {
128            (
129                i.operands[0].unwrap_id_ref(),
130                i.operands[1].unwrap_literal_string(),
131            )
132        });
133    // items later on take priority
134    entry_names.chain(debug_names).collect()
135}
136
137fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
138    names.get(&id).map_or_else(
139        || Cow::Owned(format!("Unnamed function ID %{id}")),
140        |&s| Cow::Borrowed(s),
141    )
142}
143
144impl Options {
145    // FIXME(eddyb) using a method on this type seems a bit sketchy.
146    fn spirt_cleanup_for_dumping(&self, module: &mut spirt::Module) {
147        if self.spirt_strip_custom_debuginfo_from_dumps {
148            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
149        }
150        if !self.spirt_keep_debug_sources_in_dumps {
151            const DOTS: &str = "⋯";
152            let dots_interned_str = module.cx().intern(DOTS);
153            let spirt::ModuleDebugInfo::Spv(debuginfo) = &mut module.debug_info;
154            for sources in debuginfo.source_languages.values_mut() {
155                for file in sources.file_contents.values_mut() {
156                    *file = DOTS.into();
157                }
158                sources.file_contents.insert(
159                    dots_interned_str,
160                    "sources hidden, to show them use \
161                     `RUSTGPU_CODEGEN_ARGS=--spirt-keep-debug-sources-in-dumps`"
162                        .into(),
163                );
164            }
165        }
166    }
167}
168
169pub fn link(
170    sess: &Session,
171    mut inputs: Vec<Module>,
172    opts: &Options,
173    outputs: &OutputFilenames,
174    disambiguated_crate_name_for_dumps: &OsStr,
175) -> Result<LinkResult> {
176    // HACK(eddyb) this is defined here to allow SPIR-T pretty-printing to apply
177    // to SPIR-V being dumped, outside of e.g. `--dump-spirt-passes`.
178    // FIXME(eddyb) this isn't used everywhere, sadly - to find those, search
179    // elsewhere for `.assemble()` and/or `spirv_tools::binary::from_binary`.
180    let spv_module_to_spv_words_and_spirt_module = |spv_module: &Module| {
181        let spv_words;
182        let spv_bytes = {
183            let _timer = sess.timer("assemble-to-spv_bytes-for-spirt");
184            spv_words = spv_module.assemble();
185            // FIXME(eddyb) this is wastefully cloning all the bytes, but also
186            // `spirt::Module` should have a method that takes `Vec<u32>`.
187            spirv_tools::binary::from_binary(&spv_words).to_vec()
188        };
189
190        // FIXME(eddyb) should've really been "spirt::Module::lower_from_spv_bytes".
191        let lower_from_spv_timer = sess.timer("spirt::Module::lower_from_spv_file");
192        let cx = std::rc::Rc::new(spirt::Context::new());
193        crate::custom_insts::register_to_spirt_context(&cx);
194        (
195            spv_words,
196            spirt::Module::lower_from_spv_bytes(cx, spv_bytes),
197            // HACK(eddyb) this is only returned for `SpirtDumpGuard`.
198            lower_from_spv_timer,
199        )
200    };
201
202    // FIXME(eddyb) deduplicate with `SpirtDumpGuard`.
203    let dump_spv_and_spirt = |spv_module: &Module, dump_file_path_stem: PathBuf| {
204        let (spv_words, spirt_module_or_err, _) =
205            spv_module_to_spv_words_and_spirt_module(spv_module);
206        std::fs::write(
207            dump_file_path_stem.with_extension("spv"),
208            spirv_tools::binary::from_binary(&spv_words),
209        )
210        .unwrap();
211
212        // FIXME(eddyb) reify SPIR-V -> SPIR-T errors so they're easier to debug.
213        if let Ok(mut module) = spirt_module_or_err {
214            // HACK(eddyb) avoid pretty-printing massive amounts of unused SPIR-T.
215            spirt::passes::link::minimize_exports(&mut module, |export_key| {
216                matches!(export_key, spirt::ExportKey::SpvEntryPoint { .. })
217            });
218
219            opts.spirt_cleanup_for_dumping(&mut module);
220
221            let pretty = spirt::print::Plan::for_module(&module).pretty_print();
222
223            // FIXME(eddyb) don't allocate whole `String`s here.
224            std::fs::write(
225                dump_file_path_stem.with_extension("spirt"),
226                pretty.to_string(),
227            )
228            .unwrap();
229            std::fs::write(
230                dump_file_path_stem.with_extension("spirt.html"),
231                pretty
232                    .render_to_html()
233                    .with_dark_mode_support()
234                    .to_html_doc(),
235            )
236            .unwrap();
237        }
238    };
239
240    let mut output = {
241        let _timer = sess.timer("link_merge");
242        // shift all the ids
243        let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
244        let version = inputs[0].header.as_ref().unwrap().version();
245
246        for module in inputs.iter_mut().skip(1) {
247            simple_passes::shift_ids(module, bound);
248            bound += module.header.as_ref().unwrap().bound - 1;
249            let this_version = module.header.as_ref().unwrap().version();
250            if version != this_version {
251                return Err(sess.dcx().err(format!(
252                    "cannot link two modules with different SPIR-V versions: v{}.{} and v{}.{}",
253                    version.0, version.1, this_version.0, this_version.1
254                )));
255            }
256        }
257
258        // merge the binaries
259        let mut output = crate::link::with_rspirv_loader(|loader| {
260            for module in inputs {
261                for inst in module.all_inst_iter() {
262                    use rspirv::binary::ParseAction;
263                    match loader.consume_instruction(inst.clone()) {
264                        ParseAction::Continue => {}
265                        ParseAction::Stop => unreachable!(),
266                        ParseAction::Error(err) => return Err(err),
267                    }
268                }
269            }
270            Ok(())
271        })
272        .unwrap();
273
274        let mut header = ModuleHeader::new(bound + 1);
275        header.set_version(version.0, version.1);
276        header.generator = 0x001B_0000;
277        output.header = Some(header);
278        output
279    };
280
281    if let Some(dir) = &opts.dump_post_merge {
282        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
283    }
284
285    // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
286    {
287        let _timer = sess.timer("link_remove_duplicates");
288        duplicates::remove_duplicate_extensions(&mut output);
289        duplicates::remove_duplicate_capabilities(&mut output);
290        duplicates::remove_duplicate_ext_inst_imports(&mut output);
291        duplicates::remove_duplicate_types(&mut output);
292        // jb-todo: strip identical OpDecoration / OpDecorationGroups
293    }
294
295    // find import / export pairs
296    {
297        let _timer = sess.timer("link_find_pairs");
298        import_export_link::run(opts, sess, &mut output)?;
299    }
300
301    {
302        let _timer = sess.timer("link_dce-post-link");
303        dce::dce(&mut output);
304    }
305
306    {
307        let _timer = sess.timer("link_fragment_inst_check");
308        simple_passes::check_fragment_insts(sess, &output)?;
309    }
310
311    // HACK(eddyb) this has to run before the `report_zombies` pass, so that
312    // any zombies that are passed as call arguments, but eventually unused,
313    // won't be (incorrectly) considered used.
314    {
315        let _timer = sess.timer("link_remove_unused_params");
316        output = param_weakening::remove_unused_params(output);
317    }
318
319    if opts.early_report_zombies {
320        let _timer = sess.timer("link_report_zombies");
321        zombies::report_zombies(sess, &output)?;
322    }
323
324    if opts.infer_storage_classes {
325        let _timer = sess.timer("specialize_generic_storage_class");
326        // HACK(eddyb) `specializer` requires functions' blocks to be in RPO order
327        // (i.e. `block_ordering_pass`) - this could be relaxed by using RPO visit
328        // inside `specializer`, but this is easier.
329        for func in &mut output.functions {
330            simple_passes::block_ordering_pass(func);
331        }
332        output = specializer::specialize(
333            opts,
334            output,
335            specializer::SimpleSpecialization {
336                specialize_operand: |operand| {
337                    matches!(operand, Operand::StorageClass(StorageClass::Generic))
338                },
339
340                // NOTE(eddyb) this can be anything that is guaranteed to pass
341                // validation - there are no constraints so this is either some
342                // unused pointer, or perhaps one created using `OpConstantNull`
343                // and simply never mixed with pointers that have a storage class.
344                // It would be nice to use `Generic` itself here so that we leave
345                // some kind of indication of it being unconstrained, but `Generic`
346                // requires additional capabilities, so we use `Function` instead.
347                // TODO(eddyb) investigate whether this can end up in a pointer
348                // type that's the value of a module-scoped variable, and whether
349                // `Function` is actually invalid! (may need `Private`)
350                concrete_fallback: Operand::StorageClass(StorageClass::Function),
351            },
352        );
353    }
354
355    // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too!
356    {
357        {
358            let _timer = sess.timer("link_dce-before-inlining");
359            dce::dce(&mut output);
360        }
361
362        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-before-inlining");
363        let mut pointer_to_pointee = FxHashMap::default();
364        let mut constants = FxHashMap::default();
365        let mut u32 = None;
366        for inst in &output.types_global_values {
367            match inst.class.opcode {
368                Op::TypePointer => {
369                    pointer_to_pointee
370                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
371                }
372                Op::TypeInt
373                    if inst.operands[0].unwrap_literal_bit32() == 32
374                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
375                {
376                    assert!(u32.is_none());
377                    u32 = Some(inst.result_id.unwrap());
378                }
379                Op::Constant if u32.is_some() && inst.result_type == u32 => {
380                    let value = inst.operands[0].unwrap_literal_bit32();
381                    constants.insert(inst.result_id.unwrap(), value);
382                }
383                _ => {}
384            }
385        }
386        for func in &mut output.functions {
387            simple_passes::block_ordering_pass(func);
388            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
389            mem2reg::mem2reg(
390                output.header.as_mut().unwrap(),
391                &mut output.types_global_values,
392                &pointer_to_pointee,
393                &constants,
394                func,
395            );
396            destructure_composites::destructure_composites(func);
397        }
398    }
399
400    {
401        let _timer =
402            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-before-inlining");
403        dce::dce(&mut output);
404        duplicates::remove_duplicate_debuginfo(&mut output);
405    }
406
407    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
408    if let Some(dir) = &opts.dump_pre_inline {
409        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
410    }
411
412    {
413        let _timer = sess.timer("link_inline");
414        inline::inline(sess, &mut output)?;
415    }
416
417    {
418        let _timer = sess.timer("link_dce-after-inlining");
419        dce::dce(&mut output);
420    }
421
422    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
423    if let Some(dir) = &opts.dump_post_inline {
424        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
425    }
426
427    {
428        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-after-inlining");
429        let mut pointer_to_pointee = FxHashMap::default();
430        let mut constants = FxHashMap::default();
431        let mut u32 = None;
432        for inst in &output.types_global_values {
433            match inst.class.opcode {
434                Op::TypePointer => {
435                    pointer_to_pointee
436                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
437                }
438                Op::TypeInt
439                    if inst.operands[0].unwrap_literal_bit32() == 32
440                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
441                {
442                    assert!(u32.is_none());
443                    u32 = Some(inst.result_id.unwrap());
444                }
445                Op::Constant if u32.is_some() && inst.result_type == u32 => {
446                    let value = inst.operands[0].unwrap_literal_bit32();
447                    constants.insert(inst.result_id.unwrap(), value);
448                }
449                _ => {}
450            }
451        }
452        for func in &mut output.functions {
453            simple_passes::block_ordering_pass(func);
454            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
455            mem2reg::mem2reg(
456                output.header.as_mut().unwrap(),
457                &mut output.types_global_values,
458                &pointer_to_pointee,
459                &constants,
460                func,
461            );
462            destructure_composites::destructure_composites(func);
463        }
464    }
465
466    {
467        let _timer =
468            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-after-inlining");
469        dce::dce(&mut output);
470        duplicates::remove_duplicate_debuginfo(&mut output);
471    }
472
473    {
474        let _timer = sess.timer("link_remove_non_uniform");
475        simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
476    }
477
478    // NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
479    {
480        let (spv_words, module_or_err, lower_from_spv_timer) =
481            spv_module_to_spv_words_and_spirt_module(&output);
482        let module = &mut module_or_err.map_err(|e| {
483            let spv_path = outputs.temp_path_for_diagnostic("spirt-lower-from-spv-input.spv");
484
485            let was_saved_msg =
486                match std::fs::write(&spv_path, spirv_tools::binary::from_binary(&spv_words)) {
487                    Ok(()) => format!("was saved to {}", spv_path.display()),
488                    Err(e) => format!("could not be saved: {e}"),
489                };
490
491            sess.dcx()
492                .struct_err(format!("{e}"))
493                .with_note("while lowering SPIR-V module to SPIR-T (spirt::spv::lower)")
494                .with_note(format!("input SPIR-V module {was_saved_msg}"))
495                .emit()
496        })?;
497
498        let mut dump_guard = SpirtDumpGuard {
499            sess,
500            linker_options: opts,
501            outputs,
502            disambiguated_crate_name_for_dumps,
503
504            module,
505            per_pass_module_for_dumping: vec![],
506            in_progress_pass_name: Cell::new(Some("lower_from_spv")),
507            any_spirt_bugs: false,
508        };
509        let module = &mut *dump_guard.module;
510        // FIXME(eddyb) consider returning a `Drop`-implementing type instead?
511        let before_pass = |pass_name| {
512            let outer_pass_name = dump_guard.in_progress_pass_name.replace(Some(pass_name));
513
514            // FIXME(eddyb) could it make sense to allow these to nest?
515            assert_eq!(outer_pass_name, None);
516
517            sess.timer(pass_name)
518        };
519        let mut after_pass = |module: Option<&spirt::Module>, timer| {
520            drop(timer);
521            let pass_name = dump_guard.in_progress_pass_name.take().unwrap();
522            if let Some(module) = module
523                && opts.dump_spirt_passes.is_some()
524            {
525                dump_guard
526                    .per_pass_module_for_dumping
527                    .push((pass_name.into(), module.clone()));
528            }
529        };
530        // HACK(eddyb) don't dump the unstructured state if not requested, as
531        // after SPIR-T 0.4.0 it's extremely verbose (due to def-use hermeticity).
532        after_pass(
533            (opts.spirt_keep_unstructured_cfg_in_dumps || !opts.structurize).then_some(&*module),
534            lower_from_spv_timer,
535        );
536
537        // NOTE(eddyb) this *must* run on unstructured CFGs, to do its job.
538        // FIXME(eddyb) no longer relying on structurization, try porting this
539        // to replace custom aborts in `Block`s and inject `ExitInvocation`s
540        // after them (truncating the `Block` and/or parent region if necessary).
541        {
542            let timer = before_pass(
543                "spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points",
544            );
545            spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points(opts, module);
546            after_pass(None, timer);
547        }
548
549        if opts.structurize {
550            let timer = before_pass("spirt::legalize::structurize_func_cfgs");
551            spirt::passes::legalize::structurize_func_cfgs(module);
552            after_pass(Some(module), timer);
553        }
554
555        if !opts.spirt_passes.is_empty() {
556            // FIXME(eddyb) why does this focus on functions, it could just be module passes??
557            spirt_passes::run_func_passes(
558                module,
559                &opts.spirt_passes,
560                |name, _module| before_pass(name),
561                &mut after_pass,
562            );
563        }
564
565        {
566            let timer = before_pass("spirt_passes::validate");
567            spirt_passes::validate::validate(module);
568            after_pass(Some(module), timer);
569        }
570
571        {
572            let timer = before_pass("spirt_passes::diagnostics::report_diagnostics");
573            spirt_passes::diagnostics::report_diagnostics(sess, opts, module).map_err(
574                |spirt_passes::diagnostics::ReportedDiagnostics {
575                     rustc_errors_guarantee,
576                     any_errors_were_spirt_bugs,
577                 }| {
578                    dump_guard.any_spirt_bugs |= any_errors_were_spirt_bugs;
579                    rustc_errors_guarantee
580                },
581            )?;
582            after_pass(None, timer);
583        }
584
585        // Replace our custom debuginfo instructions just before lifting to SPIR-V.
586        {
587            let timer = before_pass("spirt_passes::debuginfo::convert_custom_debuginfo_to_spv");
588            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
589            after_pass(None, timer);
590        }
591
592        let spv_words = {
593            let timer = before_pass("spirt::Module::lift_to_spv_module_emitter");
594            let spv_words = module.lift_to_spv_module_emitter().unwrap().words;
595            after_pass(None, timer);
596            spv_words
597        };
598        // FIXME(eddyb) dump both SPIR-T and `spv_words` if there's an error here.
599        output = {
600            let _timer = sess.timer("parse-spv_words-from-spirt");
601            crate::link::with_rspirv_loader(|loader| {
602                rspirv::binary::parse_words(&spv_words, loader)
603            })
604            .unwrap()
605        };
606    }
607
608    // Ensure that no references remain, to our custom "extended instruction set".
609    for inst in &output.ext_inst_imports {
610        assert_eq!(inst.class.opcode, Op::ExtInstImport);
611        let ext_inst_set = inst.operands[0].unwrap_literal_string();
612        if ext_inst_set.starts_with(custom_insts::CUSTOM_EXT_INST_SET_PREFIX) {
613            let expected = &custom_insts::CUSTOM_EXT_INST_SET[..];
614            if ext_inst_set == expected {
615                return Err(sess.dcx().err(format!(
616                    "`OpExtInstImport {ext_inst_set:?}` should not have been \
617                         left around after SPIR-T passes"
618                )));
619            } else {
620                return Err(sess.dcx().err(format!(
621                    "unsupported `OpExtInstImport {ext_inst_set:?}`
622                     (expected {expected:?} name - version mismatch?)"
623                )));
624            }
625        }
626    }
627
628    // FIXME(eddyb) rewrite these passes to SPIR-T ones, so we don't have to
629    // parse the output of `spirt::spv::lift` back into `rspirv` - also, for
630    // multi-module, it's much simpler with SPIR-T, just replace `module.exports`
631    // with a single-entry map, run `spirt::spv::lift` (or even `spirt::print`)
632    // on `module`, then put back the full original `module.exports` map.
633    {
634        let _timer = sess.timer("peephole_opts");
635        let types = peephole_opts::collect_types(&output);
636        for func in &mut output.functions {
637            peephole_opts::composite_construct(&types, func);
638            peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
639            peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
640        }
641    }
642
643    {
644        let _timer = sess.timer("link_remove_unused_type_capabilities");
645        simple_passes::remove_unused_type_capabilities(&mut output);
646    }
647
648    {
649        let _timer = sess.timer("link_gather_all_interface_vars_from_uses");
650        entry_interface::gather_all_interface_vars_from_uses(&mut output);
651    }
652
653    if opts.spirv_metadata == SpirvMetadata::NameVariables {
654        let _timer = sess.timer("link_name_variables");
655        simple_passes::name_variables_pass(&mut output);
656    }
657
658    {
659        let _timer = sess.timer("link_sort_globals");
660        simple_passes::sort_globals(&mut output);
661    }
662
663    let mut output = if opts.module_output_type == ModuleOutputType::Multiple {
664        let mut file_stem_to_entry_name_and_module = BTreeMap::new();
665        for (i, entry) in output.entry_points.iter().enumerate() {
666            let mut module = output.clone();
667            module.entry_points.clear();
668            module.entry_points.push(entry.clone());
669            let entry_name = entry.operands[2].unwrap_literal_string().to_string();
670            let mut file_stem = OsString::from(
671                sanitize_filename::sanitize_with_options(
672                    &entry_name,
673                    sanitize_filename::Options {
674                        replacement: "-",
675                        ..Default::default()
676                    },
677                )
678                .replace("--", "-"),
679            );
680            // It's always possible to find an unambiguous `file_stem`, but it
681            // may take two tries (or more, in bizzare/adversarial cases).
682            let mut disambiguator = Some(i);
683            loop {
684                use std::collections::btree_map::Entry;
685                match file_stem_to_entry_name_and_module.entry(file_stem) {
686                    Entry::Vacant(entry) => {
687                        entry.insert((entry_name, module));
688                        break;
689                    }
690                    // FIXME(eddyb) false positive: `file_stem` was moved out of,
691                    // so assigning it is necessary, but clippy doesn't know that.
692                    #[allow(clippy::assigning_clones)]
693                    Entry::Occupied(entry) => {
694                        // FIXME(eddyb) there's no way to access the owned key
695                        // passed to `BTreeMap::entry` from `OccupiedEntry`.
696                        file_stem = entry.key().clone();
697                        file_stem.push(".");
698                        match disambiguator.take() {
699                            Some(d) => file_stem.push(d.to_string()),
700                            None => file_stem.push("next"),
701                        }
702                    }
703                }
704            }
705        }
706        LinkResult::MultipleModules {
707            file_stem_to_entry_name_and_module,
708        }
709    } else {
710        LinkResult::SingleModule(Box::new(output))
711    };
712
713    let output_module_iter = match &mut output {
714        LinkResult::SingleModule(m) => Either::Left(std::iter::once((None, &mut **m))),
715        LinkResult::MultipleModules {
716            file_stem_to_entry_name_and_module,
717        } => Either::Right(
718            file_stem_to_entry_name_and_module
719                .iter_mut()
720                .map(|(file_stem, (_, m))| (Some(file_stem), m)),
721        ),
722    };
723    for (file_stem, output) in output_module_iter {
724        // Run DCE again, even if module_output_type == ModuleOutputType::Multiple - the first DCE ran before
725        // structurization and mem2reg (for perf reasons), and mem2reg may remove references to
726        // invalid types, so we need to DCE again.
727        {
728            let _timer = sess.timer("link_dce-post-split");
729            dce::dce(output);
730        }
731
732        // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
733        if let Some(dir) = &opts.dump_post_split {
734            let mut file_name = disambiguated_crate_name_for_dumps.to_os_string();
735            if let Some(file_stem) = file_stem {
736                file_name.push(".");
737                file_name.push(file_stem);
738            }
739
740            dump_spv_and_spirt(output, dir.join(file_name));
741        }
742
743        {
744            let _timer = sess.timer("link_remove_duplicate_debuginfo");
745            duplicates::remove_duplicate_debuginfo(output);
746        }
747
748        if opts.compact_ids {
749            let _timer = sess.timer("link_compact_ids");
750            // compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43
751            output.header.as_mut().unwrap().bound = simple_passes::compact_ids(output);
752        };
753
754        // FIXME(eddyb) convert these into actual `OpLine`s with a SPIR-T pass,
755        // but that'd require keeping the modules in SPIR-T form (once lowered),
756        // and never loading them back into `rspirv` once lifted back to SPIR-V.
757        SrcLocDecoration::remove_all(output);
758
759        // FIXME(eddyb) might make more sense to rewrite these away on SPIR-T.
760        ZombieDecoration::remove_all(output);
761    }
762
763    Ok(output)
764}
765
766/// Helper for dumping SPIR-T on drop, which allows panics to also dump,
767/// not just successful compilation (i.e. via `--dump-spirt-passes`).
768struct SpirtDumpGuard<'a> {
769    sess: &'a Session,
770    linker_options: &'a Options,
771    outputs: &'a OutputFilenames,
772    disambiguated_crate_name_for_dumps: &'a OsStr,
773
774    module: &'a mut spirt::Module,
775    per_pass_module_for_dumping: Vec<(Cow<'static, str>, spirt::Module)>,
776    in_progress_pass_name: Cell<Option<&'static str>>,
777    any_spirt_bugs: bool,
778}
779
780impl Drop for SpirtDumpGuard<'_> {
781    fn drop(&mut self) {
782        if std::thread::panicking() {
783            self.any_spirt_bugs = true;
784
785            // HACK(eddyb) the active pass panicked, make sure to include its
786            // (potentially corrupted) state, which will hopefully be printed
787            // later below (with protection against panicking during printing).
788            if let Some(pass_name) = self.in_progress_pass_name.get() {
789                self.per_pass_module_for_dumping.push((
790                    format!("{pass_name} [PANICKED]").into(),
791                    self.module.clone(),
792                ));
793            }
794        }
795
796        let mut dump_spirt_file_path =
797            self.linker_options
798                .dump_spirt_passes
799                .as_ref()
800                .map(|dump_dir| {
801                    dump_dir
802                        .join(self.disambiguated_crate_name_for_dumps)
803                        .with_extension("spirt")
804                });
805
806        // FIXME(eddyb) this won't allow seeing the individual passes, but it's
807        // better than nothing (theoretically the whole "SPIR-T pipeline" could
808        // be put in a loop so that everything is redone with per-pass tracking,
809        // but that requires keeping around e.g. the initial SPIR-V for longer,
810        // and probably invoking the "SPIR-T pipeline" here, as looping is hard).
811        if self.any_spirt_bugs && dump_spirt_file_path.is_none() {
812            if self.per_pass_module_for_dumping.is_empty() {
813                self.per_pass_module_for_dumping
814                    .push(("".into(), self.module.clone()));
815            }
816            dump_spirt_file_path = Some(self.outputs.temp_path_for_diagnostic("spirt"));
817        }
818
819        let Some(dump_spirt_file_path) = &dump_spirt_file_path else {
820            return;
821        };
822
823        for (_, module) in &mut self.per_pass_module_for_dumping {
824            // FIXME(eddyb) consider catching panics in this?
825            self.linker_options.spirt_cleanup_for_dumping(module);
826        }
827
828        let cx = self.module.cx();
829        let versions = self
830            .per_pass_module_for_dumping
831            .iter()
832            .map(|(pass_name, module)| (format!("after {pass_name}"), module));
833
834        let mut panicked_printing_after_passes = None;
835        for truncate_version_count in (1..=versions.len()).rev() {
836            // FIXME(eddyb) tell the user to use `--dump-spirt-passes` if that
837            // wasn't active but a panic happens - on top of that, it may need
838            // quieting the panic handler, likely controlled by a `thread_local!`
839            // (while the panic handler is global), and that would also be useful
840            // for collecting a panic message (assuming any of this is worth it).
841            // HACK(eddyb) for now, keeping the panic handler works out, as the
842            // panic messages would at least be seen by the user.
843            let printed_or_panicked =
844                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
845                    let pretty = spirt::print::Plan::for_versions(
846                        &cx,
847                        versions.clone().take(truncate_version_count),
848                    )
849                    .pretty_print();
850
851                    // FIXME(eddyb) don't allocate whole `String`s here.
852                    std::fs::write(dump_spirt_file_path, pretty.to_string()).unwrap();
853                    std::fs::write(
854                        dump_spirt_file_path.with_extension("spirt.html"),
855                        pretty
856                            .render_to_html()
857                            .with_dark_mode_support()
858                            .to_html_doc(),
859                    )
860                    .unwrap();
861                }));
862            match printed_or_panicked {
863                Ok(()) => {
864                    if truncate_version_count != versions.len() {
865                        panicked_printing_after_passes = Some(
866                            self.per_pass_module_for_dumping[truncate_version_count..]
867                                .iter()
868                                .map(|(pass_name, _)| format!("`{pass_name}`"))
869                                .collect::<Vec<_>>()
870                                .join(", "),
871                        );
872                    }
873                    break;
874                }
875                Err(panic) => {
876                    if truncate_version_count == 1 {
877                        std::panic::resume_unwind(panic);
878                    }
879                }
880            }
881        }
882        if self.any_spirt_bugs || panicked_printing_after_passes.is_some() {
883            let mut note = self.sess.dcx().struct_note("SPIR-T bugs were encountered");
884            if let Some(pass_names) = panicked_printing_after_passes {
885                note.warn(format!(
886                    "SPIR-T pretty-printing panicked after: {pass_names}"
887                ));
888            }
889            note.help(format!(
890                "pretty-printed SPIR-T was saved to {}.html",
891                dump_spirt_file_path.display()
892            ));
893            if self.linker_options.dump_spirt_passes.is_none() {
894                note.help("re-run with `RUSTGPU_CODEGEN_ARGS=\"--dump-spirt-passes=$PWD\"` for more details");
895            }
896            note.note("pretty-printed SPIR-T is preferred when reporting Rust-GPU issues");
897            note.emit();
898        }
899    }
900}