Skip to main content

rustc_codegen_spirv/linker/
mod.rs

1#[cfg(test)]
2mod test;
3
4pub(crate) mod dce;
5mod destructure_composites;
6mod duplicates;
7mod entry_interface;
8mod import_export_link;
9mod inline;
10mod ipo;
11mod mem2reg;
12mod param_weakening;
13mod peephole_opts;
14mod simple_passes;
15mod specializer;
16mod spirt_passes;
17mod zombies;
18
19use std::borrow::Cow;
20
21use crate::codegen_cx::{ModuleOutputType, SpirvMetadata};
22use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
23use crate::custom_insts;
24use either::Either;
25use rspirv::binary::Assemble;
26use rspirv::dr::{Block, Module, ModuleHeader, Operand};
27use rspirv::spirv::{Op, StorageClass, Word};
28use rustc_data_structures::fx::FxHashMap;
29use rustc_errors::ErrorGuaranteed;
30use rustc_session::Session;
31use rustc_session::config::OutputFilenames;
32use std::cell::Cell;
33use std::collections::BTreeMap;
34use std::ffi::{OsStr, OsString};
35use std::path::PathBuf;
36
/// Result type used by the linker: failures carry an `ErrorGuaranteed`, i.e.
/// the value handed back by `sess.dcx().err(..)` / `.emit()` once a diagnostic
/// has been reported (see the error paths in `link`).
37pub type Result<T> = std::result::Result<T, ErrorGuaranteed>;
38
/// Configuration consumed by `link`: which optional passes run, how the
/// linked result is split into modules, and which debugging dumps get written.
39#[derive(Default)]
40pub struct Options {
    /// When set, `link` renumbers ids with `simple_passes::compact_ids` and
    /// stores the returned value as the module header's id bound.
41    pub compact_ids: bool,
    /// When set, `link` calls `zombies::report_zombies` right after the
    /// unused-parameter removal pass.
42    pub early_report_zombies: bool,
    /// When set, `link` runs `specializer::specialize` over operands that are
    /// `StorageClass::Generic`.
43    pub infer_storage_classes: bool,
    /// When set, `link` runs `spirt::passes::legalize::structurize_func_cfgs`.
44    pub structurize: bool,
    /// Forwarded as-is to `entry_interface::gather_all_interface_vars_from_uses`.
45    pub preserve_bindings: bool,
    /// Names of extra passes handed to `spirt_passes::run_func_passes`
    /// (that call is skipped entirely when this is empty).
46    pub spirt_passes: Vec<String>,
47
    // NOTE(review): not read directly by `link`; presumably consumed by the
    // custom-abort SPIR-T pass, which receives the whole `Options` - confirm.
48    pub abort_strategy: Option<String>,
    /// `ModuleOutputType::Multiple` makes `link` produce one module per entry
    /// point (`LinkResult::MultipleModules`) instead of a single module.
49    pub module_output_type: ModuleOutputType,
50
    /// `SpirvMetadata::NameVariables` makes `link` run
    /// `simple_passes::name_variables_pass`.
51    pub spirv_metadata: SpirvMetadata,
52
53    /// Whether to preserve `LinkageAttributes "..." Export` decorations,
54    /// even after resolving imports to exports.
55    ///
56    /// **Note**: currently only used for unit testing, and not exposed elsewhere.
57    pub keep_link_exports: bool,
58
59    // NOTE(eddyb) these are debugging options that used to be env vars
60    // (for more information see `docs/src/codegen-args.md`).
    /// Directory receiving `.spv`/`.spirt`/`.spirt.html` dumps of the module
    /// right after the inputs were merged.
61    pub dump_post_merge: Option<PathBuf>,
    /// Directory receiving the same kind of dumps just before inlining.
62    pub dump_pre_inline: Option<PathBuf>,
    /// Directory receiving the same kind of dumps after inlining (and the DCE
    /// that follows it).
63    pub dump_post_inline: Option<PathBuf>,
    /// Directory receiving the same kind of dumps for each output module after
    /// the final per-module DCE.
64    pub dump_post_split: Option<PathBuf>,
    /// Path and mode for SPIR-T dumps; with `DumpSpirtMode::AllPasses`, `link`
    /// additionally keeps a snapshot of the module after individual passes.
65    pub dump_spirt: Option<(PathBuf, DumpSpirtMode)>,
    /// Makes `spirt_cleanup_for_dumping` first convert custom debuginfo via
    /// `spirt_passes::debuginfo::convert_custom_debuginfo_to_spv`.
66    pub spirt_strip_custom_debuginfo_from_dumps: bool,
    /// Unless set, `spirt_cleanup_for_dumping` replaces every embedded source
    /// file's contents with a placeholder.
67    pub spirt_keep_debug_sources_in_dumps: bool,
    /// Keeps the freshly lowered (not yet structurized) module among the
    /// per-pass snapshots even when `structurize` is set.
68    pub spirt_keep_unstructured_cfg_in_dumps: bool,
    // NOTE(review): not read by `link` itself; presumably used inside
    // `specializer` (which is given the whole `Options`) - confirm.
69    pub specializer_dump_instances: Option<PathBuf>,
70}
71
/// How much SPIR-T state gets dumped when `Options::dump_spirt` is set.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DumpSpirtMode {
    /// `link` keeps a clone of the module after individual SPIR-T passes, so
    /// that each intermediate state can be dumped.
    AllPasses,
    /// No per-pass snapshots are recorded by `link`.
    // NOTE(review): presumably only the final module is dumped in this mode -
    // the code doing the actual dumping is outside this view, confirm there.
    OnlyFinal,
}
77
/// What `link` hands back: either one module holding everything, or one
/// module per entry point (when `ModuleOutputType::Multiple` was requested).
78pub enum LinkResult {
    /// The whole linked program as a single module.
79    SingleModule(Box<Module>),
    /// One module per entry point; each is a clone of the linked module with
    /// its entry point list reduced to that single entry.
80    MultipleModules {
81        /// The "file stem" key is computed from the "entry name" in the value
82        /// (through `sanitize_filename`, replacing invalid chars with `-`),
83        /// but it's used as the map key because it *has to* be unique, even if
84        /// lossy sanitization could have erased distinctions between entry names.
85        file_stem_to_entry_name_and_module: BTreeMap<OsString, (String, Module)>,
86    },
87}
88
89fn id(header: &mut ModuleHeader) -> Word {
90    let result = header.bound;
91    header.bound += 1;
92    result
93}
94
95fn apply_rewrite_rules<'a>(
96    rewrite_rules: &FxHashMap<Word, Word>,
97    blocks: impl IntoIterator<Item = &'a mut Block>,
98) {
99    let all_ids_mut = blocks
100        .into_iter()
101        .flat_map(|b| b.label.iter_mut().chain(b.instructions.iter_mut()))
102        .flat_map(|inst| {
103            inst.result_id
104                .iter_mut()
105                .chain(inst.result_type.iter_mut())
106                .chain(
107                    inst.operands
108                        .iter_mut()
109                        .filter_map(|op| op.id_ref_any_mut()),
110                )
111        });
112    for id in all_ids_mut {
113        while let Some(&rewrite) = rewrite_rules.get(id) {
114            *id = rewrite;
115        }
116    }
117}
118
119fn get_names(module: &Module) -> FxHashMap<Word, &str> {
120    let entry_names = module
121        .entry_points
122        .iter()
123        .filter(|i| i.class.opcode == Op::EntryPoint)
124        .map(|i| {
125            (
126                i.operands[1].unwrap_id_ref(),
127                i.operands[2].unwrap_literal_string(),
128            )
129        });
130    let debug_names = module
131        .debug_names
132        .iter()
133        .filter(|i| i.class.opcode == Op::Name)
134        .map(|i| {
135            (
136                i.operands[0].unwrap_id_ref(),
137                i.operands[1].unwrap_literal_string(),
138            )
139        });
140    // items later on take priority
141    entry_names.chain(debug_names).collect()
142}
143
144fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
145    names.get(&id).map_or_else(
146        || Cow::Owned(format!("Unnamed function ID %{id}")),
147        |&s| Cow::Borrowed(s),
148    )
149}
150
151impl Options {
152    // FIXME(eddyb) using a method on this type seems a bit sketchy.
153    fn spirt_cleanup_for_dumping(&self, module: &mut spirt::Module) {
154        if self.spirt_strip_custom_debuginfo_from_dumps {
155            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
156        }
157        if !self.spirt_keep_debug_sources_in_dumps {
158            const DOTS: &str = "⋯";
159            let dots_interned_str = module.cx().intern(DOTS);
160            let spirt::ModuleDebugInfo::Spv(debuginfo) = &mut module.debug_info;
161            for sources in debuginfo.source_languages.values_mut() {
162                for file in sources.file_contents.values_mut() {
163                    *file = DOTS.into();
164                }
165                sources.file_contents.insert(
166                    dots_interned_str,
167                    "sources hidden, to show them use \
168                     `RUSTGPU_CODEGEN_ARGS=--spirt-keep-debug-sources-in-dumps`"
169                        .into(),
170                );
171            }
172        }
173    }
174}
175
176pub fn link(
177    sess: &Session,
178    mut inputs: Vec<Module>,
179    opts: &Options,
180    outputs: &OutputFilenames,
181    disambiguated_crate_name_for_dumps: &OsStr,
182) -> Result<LinkResult> {
183    // HACK(eddyb) this is defined here to allow SPIR-T pretty-printing to apply
184    // to SPIR-V being dumped, outside of e.g. `--dump-spirt-passes`.
185    // FIXME(eddyb) this isn't used everywhere, sadly - to find those, search
186    // elsewhere for `.assemble()` and/or `spirv_tools::binary::from_binary`.
187    let spv_module_to_spv_words_and_spirt_module = |spv_module: &Module| {
188        let spv_words;
189        let spv_bytes = {
190            let _timer = sess.timer("assemble-to-spv_bytes-for-spirt");
191            spv_words = spv_module.assemble();
192            // FIXME(eddyb) this is wastefully cloning all the bytes, but also
193            // `spirt::Module` should have a method that takes `Vec<u32>`.
194            spirv_tools::binary::from_binary(&spv_words).to_vec()
195        };
196
197        // FIXME(eddyb) should've really been "spirt::Module::lower_from_spv_bytes".
198        let lower_from_spv_timer = sess.timer("spirt::Module::lower_from_spv_file");
199        let cx = std::rc::Rc::new(spirt::Context::new());
200        crate::custom_insts::register_to_spirt_context(&cx);
201        (
202            spv_words,
203            spirt::Module::lower_from_spv_bytes(cx, spv_bytes),
204            // HACK(eddyb) this is only returned for `SpirtDumpGuard`.
205            lower_from_spv_timer,
206        )
207    };
208
209    // FIXME(eddyb) deduplicate with `SpirtDumpGuard`.
210    let dump_spv_and_spirt = |spv_module: &Module, dump_file_path_stem: PathBuf| {
211        let (spv_words, spirt_module_or_err, _) =
212            spv_module_to_spv_words_and_spirt_module(spv_module);
213        std::fs::write(
214            dump_file_path_stem.with_extension("spv"),
215            spirv_tools::binary::from_binary(&spv_words),
216        )
217        .unwrap();
218
219        // FIXME(eddyb) reify SPIR-V -> SPIR-T errors so they're easier to debug.
220        if let Ok(mut module) = spirt_module_or_err {
221            // HACK(eddyb) avoid pretty-printing massive amounts of unused SPIR-T.
222            spirt::passes::link::minimize_exports(&mut module, |export_key| {
223                matches!(export_key, spirt::ExportKey::SpvEntryPoint { .. })
224            });
225
226            opts.spirt_cleanup_for_dumping(&mut module);
227
228            let pretty = spirt::print::Plan::for_module(&module).pretty_print();
229
230            // FIXME(eddyb) don't allocate whole `String`s here.
231            std::fs::write(
232                dump_file_path_stem.with_extension("spirt"),
233                pretty.to_string(),
234            )
235            .unwrap();
236            std::fs::write(
237                dump_file_path_stem.with_extension("spirt.html"),
238                pretty
239                    .render_to_html()
240                    .with_dark_mode_support()
241                    .to_html_doc(),
242            )
243            .unwrap();
244        }
245    };
246
247    let mut output = {
248        let _timer = sess.timer("link_merge");
249        // shift all the ids
250        let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
251        let version = inputs[0].header.as_ref().unwrap().version();
252
253        for module in inputs.iter_mut().skip(1) {
254            simple_passes::shift_ids(module, bound);
255            bound += module.header.as_ref().unwrap().bound - 1;
256            let this_version = module.header.as_ref().unwrap().version();
257            if version != this_version {
258                return Err(sess.dcx().err(format!(
259                    "cannot link two modules with different SPIR-V versions: v{}.{} and v{}.{}",
260                    version.0, version.1, this_version.0, this_version.1
261                )));
262            }
263        }
264
265        // merge the binaries
266        let mut output = crate::link::with_rspirv_loader(|loader| {
267            for module in inputs {
268                for inst in module.all_inst_iter() {
269                    use rspirv::binary::ParseAction;
270                    match loader.consume_instruction(inst.clone()) {
271                        ParseAction::Continue => {}
272                        ParseAction::Stop => unreachable!(),
273                        ParseAction::Error(err) => return Err(err),
274                    }
275                }
276            }
277            Ok(())
278        })
279        .unwrap();
280
281        let mut header = ModuleHeader::new(bound + 1);
282        header.set_version(version.0, version.1);
283        header.generator = 0x001B_0000;
284        output.header = Some(header);
285        output
286    };
287
288    if let Some(dir) = &opts.dump_post_merge {
289        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
290    }
291
292    // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
293    {
294        let _timer = sess.timer("link_remove_duplicates");
295        duplicates::remove_duplicate_extensions(&mut output);
296        duplicates::remove_duplicate_capabilities(&mut output);
297        duplicates::remove_duplicate_ext_inst_imports(&mut output);
298        duplicates::remove_duplicate_types(&mut output);
299        // jb-todo: strip identical OpDecoration / OpDecorationGroups
300    }
301
302    // find import / export pairs
303    {
304        let _timer = sess.timer("link_find_pairs");
305        import_export_link::run(opts, sess, &mut output)?;
306    }
307
308    {
309        let _timer = sess.timer("link_dce-post-link");
310        dce::dce(&mut output);
311    }
312
313    {
314        let _timer = sess.timer("link_fragment_inst_check");
315        simple_passes::check_fragment_insts(sess, &output)?;
316    }
317
318    // HACK(eddyb) this has to run before the `report_zombies` pass, so that
319    // any zombies that are passed as call arguments, but eventually unused,
320    // won't be (incorrectly) considered used.
321    {
322        let _timer = sess.timer("link_remove_unused_params");
323        output = param_weakening::remove_unused_params(output);
324    }
325
326    if opts.early_report_zombies {
327        let _timer = sess.timer("link_report_zombies");
328        zombies::report_zombies(sess, &output)?;
329    }
330
331    if opts.infer_storage_classes {
332        let _timer = sess.timer("specialize_generic_storage_class");
333        // HACK(eddyb) `specializer` requires functions' blocks to be in RPO order
334        // (i.e. `block_ordering_pass`) - this could be relaxed by using RPO visit
335        // inside `specializer`, but this is easier.
336        for func in &mut output.functions {
337            simple_passes::block_ordering_pass(func);
338        }
339        output = specializer::specialize(
340            opts,
341            output,
342            specializer::SimpleSpecialization {
343                specialize_operand: |operand| {
344                    matches!(operand, Operand::StorageClass(StorageClass::Generic))
345                },
346
347                // NOTE(eddyb) this can be anything that is guaranteed to pass
348                // validation - there are no constraints so this is either some
349                // unused pointer, or perhaps one created using `OpConstantNull`
350                // and simply never mixed with pointers that have a storage class.
351                // It would be nice to use `Generic` itself here so that we leave
352                // some kind of indication of it being unconstrained, but `Generic`
353                // requires additional capabilities, so we use `Function` instead.
354                // TODO(eddyb) investigate whether this can end up in a pointer
355                // type that's the value of a module-scoped variable, and whether
356                // `Function` is actually invalid! (may need `Private`)
357                concrete_fallback: Operand::StorageClass(StorageClass::Function),
358            },
359        );
360    }
361
362    // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too!
363    {
364        {
365            let _timer = sess.timer("link_dce-before-inlining");
366            dce::dce(&mut output);
367        }
368
369        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-before-inlining");
370        let mut pointer_to_pointee = FxHashMap::default();
371        let mut constants = FxHashMap::default();
372        let mut u32 = None;
373        for inst in &output.types_global_values {
374            match inst.class.opcode {
375                Op::TypePointer => {
376                    pointer_to_pointee
377                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
378                }
379                Op::TypeInt
380                    if inst.operands[0].unwrap_literal_bit32() == 32
381                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
382                {
383                    assert!(u32.is_none());
384                    u32 = Some(inst.result_id.unwrap());
385                }
386                Op::Constant if u32.is_some() && inst.result_type == u32 => {
387                    let value = inst.operands[0].unwrap_literal_bit32();
388                    constants.insert(inst.result_id.unwrap(), value);
389                }
390                _ => {}
391            }
392        }
393        for func in &mut output.functions {
394            simple_passes::block_ordering_pass(func);
395            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
396            mem2reg::mem2reg(
397                output.header.as_mut().unwrap(),
398                &mut output.types_global_values,
399                &pointer_to_pointee,
400                &constants,
401                func,
402            );
403            destructure_composites::destructure_composites(func);
404        }
405    }
406
407    {
408        let _timer =
409            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-before-inlining");
410        dce::dce(&mut output);
411        duplicates::remove_duplicate_debuginfo(&mut output);
412    }
413
414    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
415    if let Some(dir) = &opts.dump_pre_inline {
416        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
417    }
418
419    {
420        let _timer = sess.timer("link_inline");
421        inline::inline(sess, &mut output)?;
422    }
423
424    // Fold OpLoad from Private/Function variables with constant initializers.
425    // This must run after inlining (which may expose such patterns) and before DCE
426    // (so that the now-unused variables can be removed).
427    // This is critical for pointer-to-pointer patterns like `&&123` which would
428    // otherwise generate invalid SPIR-V in Logical addressing mode.
429    {
430        let _timer = sess.timer("link_fold_load_from_constant_variable");
431        peephole_opts::fold_load_from_constant_variable(&mut output);
432    }
433
434    {
435        let _timer = sess.timer("link_dce-after-inlining");
436        dce::dce(&mut output);
437    }
438
439    // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
440    if let Some(dir) = &opts.dump_post_inline {
441        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
442    }
443
444    {
445        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-after-inlining");
446        let mut pointer_to_pointee = FxHashMap::default();
447        let mut constants = FxHashMap::default();
448        let mut u32 = None;
449        for inst in &output.types_global_values {
450            match inst.class.opcode {
451                Op::TypePointer => {
452                    pointer_to_pointee
453                        .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
454                }
455                Op::TypeInt
456                    if inst.operands[0].unwrap_literal_bit32() == 32
457                        && inst.operands[1].unwrap_literal_bit32() == 0 =>
458                {
459                    assert!(u32.is_none());
460                    u32 = Some(inst.result_id.unwrap());
461                }
462                Op::Constant if u32.is_some() && inst.result_type == u32 => {
463                    let value = inst.operands[0].unwrap_literal_bit32();
464                    constants.insert(inst.result_id.unwrap(), value);
465                }
466                _ => {}
467            }
468        }
469        for func in &mut output.functions {
470            simple_passes::block_ordering_pass(func);
471            // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
472            mem2reg::mem2reg(
473                output.header.as_mut().unwrap(),
474                &mut output.types_global_values,
475                &pointer_to_pointee,
476                &constants,
477                func,
478            );
479            destructure_composites::destructure_composites(func);
480        }
481    }
482
483    {
484        let _timer =
485            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-after-inlining");
486        dce::dce(&mut output);
487        duplicates::remove_duplicate_debuginfo(&mut output);
488    }
489
490    {
491        let _timer = sess.timer("link_remove_non_uniform");
492        simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
493    }
494
495    // NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
496    {
497        let (spv_words, module_or_err, lower_from_spv_timer) =
498            spv_module_to_spv_words_and_spirt_module(&output);
499        let module = &mut module_or_err.map_err(|e| {
500            let spv_path = outputs.temp_path_for_diagnostic("spirt-lower-from-spv-input.spv");
501
502            let was_saved_msg =
503                match std::fs::write(&spv_path, spirv_tools::binary::from_binary(&spv_words)) {
504                    Ok(()) => format!("was saved to {}", spv_path.display()),
505                    Err(e) => format!("could not be saved: {e}"),
506                };
507
508            sess.dcx()
509                .struct_err(format!("{e}"))
510                .with_note("while lowering SPIR-V module to SPIR-T (spirt::spv::lower)")
511                .with_note(format!("input SPIR-V module {was_saved_msg}"))
512                .emit()
513        })?;
514
515        let mut dump_guard = SpirtDumpGuard {
516            sess,
517            linker_options: opts,
518            outputs,
519            disambiguated_crate_name_for_dumps,
520
521            module,
522            per_pass_module_for_dumping: vec![],
523            in_progress_pass_name: Cell::new(Some("lower_from_spv")),
524            any_spirt_bugs: false,
525        };
526        let module = &mut *dump_guard.module;
527        // FIXME(eddyb) consider returning a `Drop`-implementing type instead?
528        let before_pass = |pass_name| {
529            let outer_pass_name = dump_guard.in_progress_pass_name.replace(Some(pass_name));
530
531            // FIXME(eddyb) could it make sense to allow these to nest?
532            assert_eq!(outer_pass_name, None);
533
534            sess.timer(pass_name)
535        };
536        let mut after_pass = |module: Option<&spirt::Module>, timer| {
537            drop(timer);
538            let pass_name = dump_guard.in_progress_pass_name.take().unwrap();
539            if let Some(module) = module
540                && matches!(opts.dump_spirt, Some((_, DumpSpirtMode::AllPasses)))
541            {
542                dump_guard
543                    .per_pass_module_for_dumping
544                    .push((pass_name.into(), module.clone()));
545            }
546        };
547        // HACK(eddyb) don't dump the unstructured state if not requested, as
548        // after SPIR-T 0.4.0 it's extremely verbose (due to def-use hermeticity).
549        after_pass(
550            (opts.spirt_keep_unstructured_cfg_in_dumps || !opts.structurize).then_some(&*module),
551            lower_from_spv_timer,
552        );
553
554        // NOTE(eddyb) this *must* run on unstructured CFGs, to do its job.
555        // FIXME(eddyb) no longer relying on structurization, try porting this
556        // to replace custom aborts in `Block`s and inject `ExitInvocation`s
557        // after them (truncating the `Block` and/or parent region if necessary).
558        {
559            let timer = before_pass(
560                "spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points",
561            );
562            spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points(opts, module);
563            after_pass(None, timer);
564        }
565
566        if opts.structurize {
567            let timer = before_pass("spirt::legalize::structurize_func_cfgs");
568            spirt::passes::legalize::structurize_func_cfgs(module);
569            after_pass(Some(module), timer);
570        }
571
572        if !opts.spirt_passes.is_empty() {
573            // FIXME(eddyb) why does this focus on functions, it could just be module passes??
574            spirt_passes::run_func_passes(
575                module,
576                &opts.spirt_passes,
577                |name, _module| before_pass(name),
578                &mut after_pass,
579            );
580        }
581
582        {
583            let timer = before_pass("spirt_passes::explicit_layout::erase_when_invalid");
584            spirt_passes::explicit_layout::erase_when_invalid(module);
585            after_pass(Some(module), timer);
586        }
587
588        {
589            let timer = before_pass("spirt_passes::validate");
590            spirt_passes::validate::validate(module);
591            after_pass(Some(module), timer);
592        }
593
594        {
595            let timer = before_pass("spirt_passes::diagnostics::report_diagnostics");
596            spirt_passes::diagnostics::report_diagnostics(sess, opts, module).map_err(
597                |spirt_passes::diagnostics::ReportedDiagnostics {
598                     rustc_errors_guarantee,
599                     any_errors_were_spirt_bugs,
600                 }| {
601                    dump_guard.any_spirt_bugs |= any_errors_were_spirt_bugs;
602                    rustc_errors_guarantee
603                },
604            )?;
605            after_pass(None, timer);
606        }
607
608        // Replace our custom debuginfo instructions just before lifting to SPIR-V.
609        {
610            let timer = before_pass("spirt_passes::debuginfo::convert_custom_debuginfo_to_spv");
611            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
612            after_pass(None, timer);
613        }
614
615        let spv_words = {
616            let timer = before_pass("spirt::Module::lift_to_spv_module_emitter");
617            let spv_words = module.lift_to_spv_module_emitter().unwrap().words;
618            after_pass(None, timer);
619            spv_words
620        };
621        // FIXME(eddyb) dump both SPIR-T and `spv_words` if there's an error here.
622        output = {
623            let _timer = sess.timer("parse-spv_words-from-spirt");
624            crate::link::with_rspirv_loader(|loader| {
625                rspirv::binary::parse_words(&spv_words, loader)
626            })
627            .unwrap()
628        };
629    }
630
631    // Ensure that no references remain, to our custom "extended instruction set".
632    for inst in &output.ext_inst_imports {
633        assert_eq!(inst.class.opcode, Op::ExtInstImport);
634        let ext_inst_set = inst.operands[0].unwrap_literal_string();
635        if ext_inst_set.starts_with(custom_insts::CUSTOM_EXT_INST_SET_PREFIX) {
636            let expected = &custom_insts::CUSTOM_EXT_INST_SET[..];
637            if ext_inst_set == expected {
638                return Err(sess.dcx().err(format!(
639                    "`OpExtInstImport {ext_inst_set:?}` should not have been \
640                         left around after SPIR-T passes"
641                )));
642            } else {
643                return Err(sess.dcx().err(format!(
644                    "unsupported `OpExtInstImport {ext_inst_set:?}`
645                     (expected {expected:?} name - version mismatch?)"
646                )));
647            }
648        }
649    }
650
651    // FIXME(eddyb) rewrite these passes to SPIR-T ones, so we don't have to
652    // parse the output of `spirt::spv::lift` back into `rspirv` - also, for
653    // multi-module, it's much simpler with SPIR-T, just replace `module.exports`
654    // with a single-entry map, run `spirt::spv::lift` (or even `spirt::print`)
655    // on `module`, then put back the full original `module.exports` map.
656    {
657        let _timer = sess.timer("peephole_opts");
658        let types = peephole_opts::collect_types(&output);
659        for func in &mut output.functions {
660            peephole_opts::composite_construct(&types, func);
661            peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
662            peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
663        }
664    }
665
666    {
667        let _timer = sess.timer("link_remove_unused_type_capabilities");
668        simple_passes::remove_unused_type_capabilities(&mut output);
669    }
670
671    {
672        let _timer = sess.timer("link_gather_all_interface_vars_from_uses");
673        entry_interface::gather_all_interface_vars_from_uses(&mut output, opts.preserve_bindings);
674    }
675
676    if opts.spirv_metadata == SpirvMetadata::NameVariables {
677        let _timer = sess.timer("link_name_variables");
678        simple_passes::name_variables_pass(&mut output);
679    }
680
681    {
682        let _timer = sess.timer("link_sort_globals");
683        simple_passes::sort_globals(&mut output);
684    }
685
686    let mut output = if opts.module_output_type == ModuleOutputType::Multiple {
687        let mut file_stem_to_entry_name_and_module = BTreeMap::new();
688        for (i, entry) in output.entry_points.iter().enumerate() {
689            let mut module = output.clone();
690            module.entry_points.clear();
691            module.entry_points.push(entry.clone());
692            let entry_name = entry.operands[2].unwrap_literal_string().to_string();
693            let mut file_stem = OsString::from(
694                sanitize_filename::sanitize_with_options(
695                    &entry_name,
696                    sanitize_filename::Options {
697                        replacement: "-",
698                        ..Default::default()
699                    },
700                )
701                .replace("--", "-"),
702            );
703            // It's always possible to find an unambiguous `file_stem`, but it
704            // may take two tries (or more, in bizzare/adversarial cases).
705            let mut disambiguator = Some(i);
706            loop {
707                use std::collections::btree_map::Entry;
708                match file_stem_to_entry_name_and_module.entry(file_stem) {
709                    Entry::Vacant(entry) => {
710                        entry.insert((entry_name, module));
711                        break;
712                    }
713                    // FIXME(eddyb) false positive: `file_stem` was moved out of,
714                    // so assigning it is necessary, but clippy doesn't know that.
715                    #[allow(clippy::assigning_clones)]
716                    Entry::Occupied(entry) => {
717                        // FIXME(eddyb) there's no way to access the owned key
718                        // passed to `BTreeMap::entry` from `OccupiedEntry`.
719                        file_stem = entry.key().clone();
720                        file_stem.push(".");
721                        match disambiguator.take() {
722                            Some(d) => file_stem.push(d.to_string()),
723                            None => file_stem.push("next"),
724                        }
725                    }
726                }
727            }
728        }
729        LinkResult::MultipleModules {
730            file_stem_to_entry_name_and_module,
731        }
732    } else {
733        LinkResult::SingleModule(Box::new(output))
734    };
735
736    let output_module_iter = match &mut output {
737        LinkResult::SingleModule(m) => Either::Left(std::iter::once((None, &mut **m))),
738        LinkResult::MultipleModules {
739            file_stem_to_entry_name_and_module,
740        } => Either::Right(
741            file_stem_to_entry_name_and_module
742                .iter_mut()
743                .map(|(file_stem, (_, m))| (Some(file_stem), m)),
744        ),
745    };
746    for (file_stem, output) in output_module_iter {
747        // Run DCE again, even if module_output_type == ModuleOutputType::Multiple - the first DCE ran before
748        // structurization and mem2reg (for perf reasons), and mem2reg may remove references to
749        // invalid types, so we need to DCE again.
750        {
751            let _timer = sess.timer("link_dce-post-split");
752            dce::dce(output);
753        }
754
755        // HACK(eddyb) this has to be after DCE, to not break SPIR-T w/ dead decorations.
756        if let Some(dir) = &opts.dump_post_split {
757            let mut file_name = disambiguated_crate_name_for_dumps.to_os_string();
758            if let Some(file_stem) = file_stem {
759                file_name.push(".");
760                file_name.push(file_stem);
761            }
762
763            dump_spv_and_spirt(output, dir.join(file_name));
764        }
765
766        {
767            let _timer = sess.timer("link_remove_duplicate_debuginfo");
768            duplicates::remove_duplicate_debuginfo(output);
769        }
770
771        if opts.compact_ids {
772            let _timer = sess.timer("link_compact_ids");
773            // compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43
774            output.header.as_mut().unwrap().bound = simple_passes::compact_ids(output);
775        };
776
777        // FIXME(eddyb) convert these into actual `OpLine`s with a SPIR-T pass,
778        // but that'd require keeping the modules in SPIR-T form (once lowered),
779        // and never loading them back into `rspirv` once lifted back to SPIR-V.
780        SrcLocDecoration::remove_all(output);
781
782        // FIXME(eddyb) might make more sense to rewrite these away on SPIR-T.
783        ZombieDecoration::remove_all(output);
784    }
785
786    Ok(output)
787}
788
/// Helper for dumping SPIR-T on drop, which allows panics to also dump,
/// not just successful compilation (i.e. via `--dump-spirt-passes`).
///
/// All of the dumping logic lives in the `Drop` impl: it decides where to
/// write, pretty-prints every recorded module version to a `.spirt` file
/// (plus a `.spirt.html` sibling), and emits a diagnostic note if any
/// SPIR-T bugs were detected along the way.
struct SpirtDumpGuard<'a> {
    /// Compiler session, used on drop to emit the "SPIR-T bugs were
    /// encountered" note.
    sess: &'a Session,
    /// Linker options; on drop, `dump_spirt` is read to pick the dump
    /// directory/mode, and `spirt_cleanup_for_dumping` is applied to every
    /// module version before printing.
    linker_options: &'a Options,
    /// Output filenames, used to obtain a fallback diagnostic path when
    /// bugs were encountered but no dump directory was configured.
    outputs: &'a OutputFilenames,
    /// Name joined onto the dump directory to form the dump file name
    /// (its extension is then set to `spirt`).
    disambiguated_crate_name_for_dumps: &'a OsStr,

    /// The SPIR-T module being guarded; cloned on drop when a pass
    /// panicked, or when no per-pass versions were recorded.
    module: &'a mut spirt::Module,
    /// `(pass name, module snapshot)` pairs, pretty-printed in order as
    /// "after {pass name}" versions (an empty name gets an empty label).
    // NOTE(review): presumably appended to after each pass by the code
    // driving the passes (outside this excerpt) - confirm against callers.
    per_pass_module_for_dumping: Vec<(Cow<'static, str>, spirt::Module)>,
    /// Name of the pass currently running, if any; consulted on drop during
    /// a panic so that the panicking pass's state is also dumped.
    // NOTE(review): expected to be set/cleared around each pass by the
    // caller - not visible here, confirm.
    in_progress_pass_name: Cell<Option<&'static str>>,
    /// Whether any SPIR-T bugs were observed; forced to `true` on drop while
    /// panicking, and (when set) both forces a dump and triggers the note.
    any_spirt_bugs: bool,
}
802
803impl Drop for SpirtDumpGuard<'_> {
804    fn drop(&mut self) {
805        if std::thread::panicking() {
806            self.any_spirt_bugs = true;
807
808            // HACK(eddyb) the active pass panicked, make sure to include its
809            // (potentially corrupted) state, which will hopefully be printed
810            // later below (with protection against panicking during printing).
811            if let Some(pass_name) = self.in_progress_pass_name.get() {
812                self.per_pass_module_for_dumping.push((
813                    format!("{pass_name} [PANICKED]").into(),
814                    self.module.clone(),
815                ));
816            }
817        }
818
819        let mut dump_spirt_file_path =
820            self.linker_options
821                .dump_spirt
822                .as_ref()
823                .map(|(dump_dir, _)| {
824                    dump_dir
825                        .join(self.disambiguated_crate_name_for_dumps)
826                        .with_extension("spirt")
827                });
828
829        // FIXME(eddyb) this won't allow seeing the individual passes, but it's
830        // better than nothing (theoretically the whole "SPIR-T pipeline" could
831        // be put in a loop so that everything is redone with per-pass tracking,
832        // but that requires keeping around e.g. the initial SPIR-V for longer,
833        // and probably invoking the "SPIR-T pipeline" here, as looping is hard).
834        if self.any_spirt_bugs && dump_spirt_file_path.is_none() {
835            dump_spirt_file_path = Some(self.outputs.temp_path_for_diagnostic("spirt"));
836        }
837
838        let Some(dump_spirt_file_path) = &dump_spirt_file_path else {
839            return;
840        };
841
842        if self.per_pass_module_for_dumping.is_empty() {
843            self.per_pass_module_for_dumping
844                .push(("".into(), self.module.clone()));
845        }
846
847        for (_, module) in &mut self.per_pass_module_for_dumping {
848            // FIXME(eddyb) consider catching panics in this?
849            self.linker_options.spirt_cleanup_for_dumping(module);
850        }
851
852        let cx = self.module.cx();
853        let versions = self
854            .per_pass_module_for_dumping
855            .iter()
856            .map(|(pass_name, module)| {
857                (
858                    if pass_name.is_empty() {
859                        "".into()
860                    } else {
861                        format!("after {pass_name}")
862                    },
863                    module,
864                )
865            });
866
867        let mut panicked_printing_after_passes = None;
868        for truncate_version_count in (1..=versions.len()).rev() {
869            // FIXME(eddyb) tell the user to use `--dump-spirt-passes` if that
870            // wasn't active but a panic happens - on top of that, it may need
871            // quieting the panic handler, likely controlled by a `thread_local!`
872            // (while the panic handler is global), and that would also be useful
873            // for collecting a panic message (assuming any of this is worth it).
874            // HACK(eddyb) for now, keeping the panic handler works out, as the
875            // panic messages would at least be seen by the user.
876            let printed_or_panicked =
877                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
878                    let pretty = spirt::print::Plan::for_versions(
879                        &cx,
880                        versions.clone().take(truncate_version_count),
881                    )
882                    .pretty_print();
883
884                    // FIXME(eddyb) don't allocate whole `String`s here.
885                    std::fs::write(dump_spirt_file_path, pretty.to_string()).unwrap();
886                    std::fs::write(
887                        dump_spirt_file_path.with_extension("spirt.html"),
888                        pretty
889                            .render_to_html()
890                            .with_dark_mode_support()
891                            .to_html_doc(),
892                    )
893                    .unwrap();
894                }));
895            match printed_or_panicked {
896                Ok(()) => {
897                    if truncate_version_count != versions.len() {
898                        panicked_printing_after_passes = Some(
899                            self.per_pass_module_for_dumping[truncate_version_count..]
900                                .iter()
901                                .map(|(pass_name, _)| format!("`{pass_name}`"))
902                                .collect::<Vec<_>>()
903                                .join(", "),
904                        );
905                    }
906                    break;
907                }
908                Err(panic) => {
909                    if truncate_version_count == 1 {
910                        std::panic::resume_unwind(panic);
911                    }
912                }
913            }
914        }
915        if self.any_spirt_bugs || panicked_printing_after_passes.is_some() {
916            let mut note = self.sess.dcx().struct_note("SPIR-T bugs were encountered");
917            if let Some(pass_names) = panicked_printing_after_passes {
918                note.warn(format!(
919                    "SPIR-T pretty-printing panicked after: {pass_names}"
920                ));
921            }
922            note.help(format!(
923                "pretty-printed SPIR-T was saved to {}.html",
924                dump_spirt_file_path.display()
925            ));
926            let is_dumping_spirt_passes = matches!(
927                self.linker_options.dump_spirt,
928                Some((_, DumpSpirtMode::AllPasses))
929            );
930            if !is_dumping_spirt_passes {
931                note.help("re-run with `RUSTGPU_CODEGEN_ARGS=\"--dump-spirt-passes=$PWD\"` for more details");
932            }
933            note.note("pretty-printed SPIR-T is preferred when reporting Rust-GPU issues");
934            note.emit();
935        }
936    }
937}