#[cfg(test)]
mod test;
mod dce;
mod destructure_composites;
mod duplicates;
mod entry_interface;
mod import_export_link;
mod inline;
mod ipo;
mod mem2reg;
mod param_weakening;
mod peephole_opts;
mod simple_passes;
mod specializer;
mod spirt_passes;
mod zombies;
use std::borrow::Cow;
use crate::codegen_cx::{ModuleOutputType, SpirvMetadata};
use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
use crate::custom_insts;
use either::Either;
use rspirv::binary::{Assemble, Consumer};
use rspirv::dr::{Block, Loader, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, StorageClass, Word};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::ErrorGuaranteed;
use rustc_session::Session;
use rustc_session::config::OutputFilenames;
use std::collections::BTreeMap;
use std::ffi::{OsStr, OsString};
use std::path::PathBuf;
/// Result type used throughout the linker. The error is an `ErrorGuaranteed`, i.e. proof
/// that the error was already reported through the session's diagnostics (e.g. the
/// `sess.dcx().err(...)` calls in `link`), so callers only need to propagate it.
pub type Result<T> = std::result::Result<T, ErrorGuaranteed>;
/// Configuration for the SPIR-V linker pipeline run by `link`.
#[derive(Default)]
pub struct Options {
/// Run `simple_passes::compact_ids` on every output module and use its result as the new ID bound.
pub compact_ids: bool,
/// Run dead-code elimination (`dce::dce`) at several points of the pipeline.
pub dce: bool,
/// Report and remove zombies right after `param_weakening::remove_unused_params`
/// (also skips the separate DCE run before storage-class specialization).
pub early_report_zombies: bool,
/// Specialize `StorageClass::Generic` operands via `specializer::specialize`,
/// falling back to `StorageClass::Function`.
pub infer_storage_classes: bool,
/// Run `spirt::passes::legalize::structurize_func_cfgs` on the SPIR-T module.
pub structurize: bool,
/// Names of extra SPIR-T function passes for `spirt_passes::run_func_passes` (skipped if empty).
pub spirt_passes: Vec<String>,
/// NOTE(review): not read directly in this file; presumably consumed by a pass that
/// receives `opts` (e.g. the custom-abort conversion) — confirm.
pub abort_strategy: Option<String>,
/// `ModuleOutputType::Multiple` splits the result into one module per entry point.
pub module_output_type: ModuleOutputType,
/// `SpirvMetadata::NameVariables` enables `simple_passes::name_variables_pass`.
pub spirv_metadata: SpirvMetadata,
/// NOTE(review): not read directly in this file; presumably used by `import_export_link`,
/// which receives `opts` — confirm.
pub keep_link_exports: bool,
/// Directory to dump `.spv`/`.spirt`/`.spirt.html` files to, right after merging the inputs.
pub dump_post_merge: Option<PathBuf>,
/// Directory to dump `.spv`/`.spirt`/`.spirt.html` files to, for each output module after splitting.
pub dump_post_split: Option<PathBuf>,
/// Directory to write a pretty-printed SPIR-T snapshot per pass to.
pub dump_spirt_passes: Option<PathBuf>,
/// In dumps, convert custom debuginfo to standard SPIR-V debuginfo before printing.
pub spirt_strip_custom_debuginfo_from_dumps: bool,
/// In dumps, keep source file contents (otherwise they are replaced with `⋯`).
pub spirt_keep_debug_sources_in_dumps: bool,
/// Also record the module right after lowering from SPIR-V (i.e. before structurization)
/// in per-pass dumps.
pub spirt_keep_unstructured_cfg_in_dumps: bool,
/// NOTE(review): not read directly in this file; presumably read by `specializer` — confirm.
pub specializer_debug: bool,
/// NOTE(review): not read directly in this file; presumably read by `specializer` — confirm.
pub specializer_dump_instances: Option<PathBuf>,
/// NOTE(review): not read directly in this file; presumably read by zombie reporting — confirm.
pub print_all_zombie: bool,
/// NOTE(review): not read directly in this file; presumably read by zombie reporting — confirm.
pub print_zombie: bool,
}
/// Output of `link`: either one module, or (for `ModuleOutputType::Multiple`) one
/// module per entry point.
pub enum LinkResult {
SingleModule(Box<Module>),
MultipleModules {
/// Keyed by a sanitized, disambiguated file stem derived from the entry-point name.
/// Each value holds the entry-point name and a module whose only entry point is that one.
file_stem_to_entry_name_and_module: BTreeMap<OsString, (String, Module)>,
},
}
/// Allocates a fresh SPIR-V ID for the module: returns the current ID bound and
/// bumps the bound past it.
fn id(header: &mut ModuleHeader) -> Word {
    let fresh_id = header.bound;
    header.bound = fresh_id + 1;
    fresh_id
}
/// Replaces IDs in `blocks` according to `rewrite_rules` (old ID -> new ID).
///
/// Every ID position is considered: block labels, result IDs, result types and ID operands.
/// Each ID is looked up once, so rules are not applied transitively.
fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Block]) {
    let rewrite = |id: &mut Word| {
        if let Some(&replacement) = rewrite_rules.get(id) {
            *id = replacement;
        }
    };
    for block in blocks.iter_mut() {
        let insts = block.label.iter_mut().chain(block.instructions.iter_mut());
        for inst in insts {
            if let Some(result_id) = inst.result_id.as_mut() {
                rewrite(result_id);
            }
            if let Some(result_type) = inst.result_type.as_mut() {
                rewrite(result_type);
            }
            for operand in inst.operands.iter_mut() {
                if let Some(id_ref) = operand.id_ref_any_mut() {
                    rewrite(id_ref);
                }
            }
        }
    }
}
fn get_names(module: &Module) -> FxHashMap<Word, &str> {
let entry_names = module
.entry_points
.iter()
.filter(|i| i.class.opcode == Op::EntryPoint)
.map(|i| {
(
i.operands[1].unwrap_id_ref(),
i.operands[2].unwrap_literal_string(),
)
});
let debug_names = module
.debug_names
.iter()
.filter(|i| i.class.opcode == Op::Name)
.map(|i| {
(
i.operands[0].unwrap_id_ref(),
i.operands[1].unwrap_literal_string(),
)
});
entry_names.chain(debug_names).collect()
}
/// Looks up `id` in `names`, falling back to a generated `Unnamed function ID %N`
/// placeholder. Known names are borrowed; only the fallback allocates.
fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
    match names.get(&id) {
        Some(&name) => Cow::Borrowed(name),
        None => Cow::Owned(format!("Unnamed function ID %{id}")),
    }
}
impl Options {
    /// Prepares a SPIR-T module for pretty-printing into a dump, per the `spirt_*_dumps` options.
    ///
    /// - With `spirt_strip_custom_debuginfo_from_dumps`, custom debuginfo is first converted
    ///   to standard SPIR-V debuginfo (`spirt_passes::debuginfo`).
    /// - Unless `spirt_keep_debug_sources_in_dumps`, the contents of every source file are
    ///   replaced with `⋯`. A pseudo-file named `⋯` is added that explains how to keep them.
    fn spirt_cleanup_for_dumping(&self, module: &mut spirt::Module) {
        if self.spirt_strip_custom_debuginfo_from_dumps {
            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
        }
        if self.spirt_keep_debug_sources_in_dumps {
            return;
        }

        const PLACEHOLDER: &str = "⋯";
        // Interned up front, before `module.debug_info` is borrowed mutably below.
        let placeholder_file_name = module.cx().intern(PLACEHOLDER);
        let spirt::ModuleDebugInfo::Spv(spv_debug_info) = &mut module.debug_info;
        for lang_sources in spv_debug_info.source_languages.values_mut() {
            for contents in lang_sources.file_contents.values_mut() {
                *contents = PLACEHOLDER.into();
            }
            lang_sources.file_contents.insert(
                placeholder_file_name,
                "sources hidden, to show them use \
                 `RUSTGPU_CODEGEN_ARGS=--spirt-keep-debug-sources-in-dumps`"
                    .into(),
            );
        }
    }
}
/// Links `inputs` into a single SPIR-V module and runs the whole linker pipeline on it.
///
/// Broadly, in order:
/// 1. merge all inputs, shifting IDs (`simple_passes::shift_ids`) so they don't collide;
///    all inputs must share one SPIR-V version,
/// 2. deduplication, `import_export_link`, `simple_passes::check_fragment_insts`,
///    `param_weakening`, and optionally zombie reporting / storage-class specialization,
/// 3. block ordering + `mem2reg` + `destructure_composites`, both before and after
///    `inline`, with optional DCE in between,
/// 4. lowering to SPIR-T, SPIR-T passes (structurization, `opts.spirt_passes`,
///    diagnostics, debuginfo conversion), and lifting back to SPIR-V,
/// 5. peephole optimizations, interface-variable gathering, optional variable naming,
///    and global sorting,
/// 6. an optional split into one module per entry point, then per-module DCE, debuginfo
///    deduplication, ID compaction and removal of custom decorations.
///
/// When the `opts.dump_*` directories are set, intermediate `.spv`/`.spirt`/`.spirt.html`
/// dumps are written along the way.
///
/// # Errors
/// Returns `Err` (with the error already reported through `sess`) when:
/// - the inputs have different SPIR-V versions,
/// - a fallible pass reports an error,
/// - lowering to SPIR-T fails, or
/// - a custom `OpExtInstImport` is still present after the SPIR-T passes.
///
/// # Panics
/// Panics if `inputs` is empty, if an input has no header, or if writing a
/// `dump_post_merge`/`dump_post_split` dump fails.
pub fn link(
    sess: &Session,
    mut inputs: Vec<Module>,
    opts: &Options,
    outputs: &OutputFilenames,
    disambiguated_crate_name_for_dumps: &OsStr,
) -> Result<LinkResult> {
    // Assembles `spv_module` and lowers it to SPIR-T. The lowering timer is returned still
    // running, so callers decide when "lowering" ends (see `after_pass("lower_from_spv", ..)`).
    let spv_module_to_spv_words_and_spirt_module = |spv_module: &Module| {
        let spv_words;
        let spv_bytes = {
            let _timer = sess.timer("assemble-to-spv_bytes-for-spirt");
            spv_words = spv_module.assemble();
            spirv_tools::binary::from_binary(&spv_words).to_vec()
        };

        let lower_from_spv_timer = sess.timer("spirt::Module::lower_from_spv_file");
        let cx = std::rc::Rc::new(spirt::Context::new());
        crate::custom_insts::register_to_spirt_context(&cx);
        (
            spv_words,
            spirt::Module::lower_from_spv_bytes(cx, spv_bytes),
            lower_from_spv_timer,
        )
    };

    // Writes `<stem>.spv`, plus `<stem>.spirt` and `<stem>.spirt.html` if lowering succeeds.
    // The extensions are appended rather than set via `Path::with_extension`, because stems
    // may contain `.` (e.g. `crate.entry_name` after splitting). `with_extension` would
    // replace that suffix, and all per-entry-point dumps would collide.
    let dump_spv_and_spirt = |spv_module: &Module, dump_file_path_stem: PathBuf| {
        let (spv_words, spirt_module_or_err, _) =
            spv_module_to_spv_words_and_spirt_module(spv_module);
        std::fs::write(
            path_with_appended_extension(&dump_file_path_stem, "spv"),
            spirv_tools::binary::from_binary(&spv_words),
        )
        .unwrap();

        // A lowering failure only skips the SPIR-T dumps; the `.spv` above is still written.
        if let Ok(mut module) = spirt_module_or_err {
            // Restrict exports to entry points before printing.
            spirt::passes::link::minimize_exports(&mut module, |export_key| {
                matches!(export_key, spirt::ExportKey::SpvEntryPoint { .. })
            });

            opts.spirt_cleanup_for_dumping(&mut module);

            let pretty = spirt::print::Plan::for_module(&module).pretty_print();
            std::fs::write(
                path_with_appended_extension(&dump_file_path_stem, "spirt"),
                pretty.to_string(),
            )
            .unwrap();
            std::fs::write(
                path_with_appended_extension(&dump_file_path_stem, "spirt.html"),
                pretty
                    .render_to_html()
                    .with_dark_mode_support()
                    .to_html_doc(),
            )
            .unwrap();
        }
    };

    let mut output = {
        let _timer = sess.timer("link_merge");
        // Shift the IDs of every module after the first, so that all IDs stay unique once
        // the modules are concatenated.
        let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
        let version = inputs[0].header.as_ref().unwrap().version();

        for module in inputs.iter_mut().skip(1) {
            simple_passes::shift_ids(module, bound);
            bound += module.header.as_ref().unwrap().bound - 1;
            let this_version = module.header.as_ref().unwrap().version();
            if version != this_version {
                return Err(sess.dcx().err(format!(
                    "cannot link two modules with different SPIR-V versions: v{}.{} and v{}.{}",
                    version.0, version.1, this_version.0, this_version.1
                )));
            }
        }

        // Merge by re-feeding every instruction into a single loader.
        let mut loader = Loader::new();
        for module in inputs {
            module.all_inst_iter().for_each(|inst| {
                loader.consume_instruction(inst.clone());
            });
        }
        let mut output = loader.module();
        let mut header = ModuleHeader::new(bound + 1);
        header.set_version(version.0, version.1);
        // SPIR-V generator word: tool ID in the high 16 bits (presumably rust-gpu's
        // registered ID, 0x1B).
        header.generator = 0x001B_0000;
        output.header = Some(header);
        output
    };

    if let Some(dir) = &opts.dump_post_merge {
        dump_spv_and_spirt(&output, dir.join(disambiguated_crate_name_for_dumps));
    }

    // remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
    {
        let _timer = sess.timer("link_remove_duplicates");
        duplicates::remove_duplicate_extensions(&mut output);
        duplicates::remove_duplicate_capabilities(&mut output);
        duplicates::remove_duplicate_ext_inst_imports(&mut output);
        duplicates::remove_duplicate_types(&mut output);
    }

    {
        let _timer = sess.timer("link_find_pairs");
        import_export_link::run(opts, sess, &mut output)?;
    }

    {
        let _timer = sess.timer("link_fragment_inst_check");
        simple_passes::check_fragment_insts(sess, &output)?;
    }

    {
        let _timer = sess.timer("link_remove_unused_params");
        output = param_weakening::remove_unused_params(output);
    }

    if opts.early_report_zombies {
        {
            let _timer = sess.timer("link_block_ordering_pass-before-report_and_remove_zombies");
            for func in &mut output.functions {
                simple_passes::block_ordering_pass(func);
            }
        }
        let _timer = sess.timer("link_report_and_remove_zombies");
        zombies::report_and_remove_zombies(sess, opts, &mut output)?;
    }

    if opts.infer_storage_classes {
        // HACK(eddyb) this is not the best approach, but storage class inference
        // can still fail in entirely legitimate ways (i.e. mismatches in zombies).
        if !opts.early_report_zombies {
            let _timer = sess.timer("link_dce-before-specialize_generic_storage_class");
            dce::dce(&mut output);
        }

        let _timer = sess.timer("specialize_generic_storage_class");
        // HACK(eddyb) `specializer` requires functions' blocks to be in RPO order
        // (i.e. `block_ordering_pass`) - this could be relaxed by using RPO visit
        // inside `specializer`, but this is easier.
        for func in &mut output.functions {
            simple_passes::block_ordering_pass(func);
        }
        output = specializer::specialize(opts, output, specializer::SimpleSpecialization {
            specialize_operand: |operand| {
                matches!(operand, Operand::StorageClass(StorageClass::Generic))
            },

            // NOTE(eddyb) this can be anything that is guaranteed to pass
            // validation - there are no constraints so this is either some
            // unused pointer, or perhaps one created using `OpConstantNull`
            // and simply never mixed with pointers that have a storage class.
            // It would be nice to use `Generic` itself here so that we leave
            // some kind of indication of it being unconstrained, but `Generic`
            // requires additional capabilities, so we use `Function` instead.
            // TODO(eddyb) investigate whether this can end up in a pointer
            // type that's the value of a module-scoped variable, and whether
            // `Function` is actually invalid! (may need `Private`)
            concrete_fallback: Operand::StorageClass(StorageClass::Function),
        });
    }

    // NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too!
    if opts.dce {
        let _timer = sess.timer("link_dce-before-inlining");
        dce::dce(&mut output);
    }
    {
        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-before-inlining");
        block_ordering_and_mem2reg(&mut output);
    }
    if opts.dce {
        let _timer =
            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-before-inlining");
        dce::dce(&mut output);
        duplicates::remove_duplicate_debuginfo(&mut output);
    }

    {
        let _timer = sess.timer("link_inline");
        inline::inline(sess, &mut output)?;
    }

    if opts.dce {
        let _timer = sess.timer("link_dce-after-inlining");
        dce::dce(&mut output);
    }

    {
        let _timer = sess.timer("link_block_ordering_pass_and_mem2reg-after-inlining");
        block_ordering_and_mem2reg(&mut output);
    }
    if opts.dce {
        let _timer =
            sess.timer("link_dce-and-remove_duplicate_debuginfo-after-mem2reg-after-inlining");
        dce::dce(&mut output);
        duplicates::remove_duplicate_debuginfo(&mut output);
    }

    // NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
    {
        let (spv_words, module_or_err, lower_from_spv_timer) =
            spv_module_to_spv_words_and_spirt_module(&output);
        let module = &mut module_or_err.map_err(|e| {
            let spv_path = outputs.temp_path_ext("spirt-lower-from-spv-input.spv", None);

            let was_saved_msg =
                match std::fs::write(&spv_path, spirv_tools::binary::from_binary(&spv_words)) {
                    Ok(()) => format!("was saved to {}", spv_path.display()),
                    Err(e) => format!("could not be saved: {e}"),
                };

            sess.dcx()
                .struct_err(format!("{e}"))
                .with_note("while lowering SPIR-V module to SPIR-T (spirt::spv::lower)")
                .with_note(format!("input SPIR-V module {was_saved_msg}"))
                .emit()
        })?;

        // The guard writes per-pass dumps (and dumps on SPIR-T bugs) when dropped,
        // including when unwinding out of this block.
        let mut dump_guard = SpirtDumpGuard {
            sess,
            linker_options: opts,
            outputs,
            disambiguated_crate_name_for_dumps,

            module,
            per_pass_module_for_dumping: vec![],
            any_spirt_bugs: false,
        };
        let module = &mut *dump_guard.module;
        // FIXME(eddyb) set the name into `dump_guard` to be able to access it on panic.
        let before_pass = |pass| sess.timer(pass);
        let mut after_pass = |pass, module: &spirt::Module, timer| {
            drop(timer);
            if opts.dump_spirt_passes.is_some() {
                dump_guard
                    .per_pass_module_for_dumping
                    .push((pass, module.clone()));
            }
        };
        // HACK(eddyb) don't dump the unstructured state if not requested, as
        // after SPIR-T 0.4.0 it's extremely verbose (due to def-use hermeticity).
        if opts.spirt_keep_unstructured_cfg_in_dumps || !opts.structurize {
            after_pass("lower_from_spv", module, lower_from_spv_timer);
        } else {
            drop(lower_from_spv_timer);
        }

        // NOTE(eddyb) this *must* run on unstructured CFGs, to do its job.
        // FIXME(eddyb) no longer relying on structurization, try porting this
        // to replace custom aborts in `Block`s and inject `ExitInvocation`s
        // after them (truncating the `Block` and/or parent region if necessary).
        {
            let _timer = before_pass(
                "spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points",
            );
            spirt_passes::controlflow::convert_custom_aborts_to_unstructured_returns_in_entry_points(opts, module);
        }

        if opts.structurize {
            let timer = before_pass("spirt::legalize::structurize_func_cfgs");
            spirt::passes::legalize::structurize_func_cfgs(module);
            after_pass("structurize_func_cfgs", module, timer);
        }

        if !opts.spirt_passes.is_empty() {
            // FIXME(eddyb) why does this focus on functions, it could just be module passes??
            spirt_passes::run_func_passes(
                module,
                &opts.spirt_passes,
                |name, _module| before_pass(name),
                after_pass,
            );
        }

        {
            let _timer = before_pass("spirt_passes::diagnostics::report_diagnostics");
            spirt_passes::diagnostics::report_diagnostics(sess, opts, module).map_err(
                |spirt_passes::diagnostics::ReportedDiagnostics {
                     rustc_errors_guarantee,
                     any_errors_were_spirt_bugs,
                 }| {
                    dump_guard.any_spirt_bugs |= any_errors_were_spirt_bugs;
                    rustc_errors_guarantee
                },
            )?;
        }

        // Replace our custom debuginfo instructions just before lifting to SPIR-V.
        {
            let _timer = before_pass("spirt_passes::debuginfo::convert_custom_debuginfo_to_spv");
            spirt_passes::debuginfo::convert_custom_debuginfo_to_spv(module);
        }

        let spv_words = {
            let _timer = before_pass("spirt::Module::lift_to_spv_module_emitter");
            module.lift_to_spv_module_emitter().unwrap().words
        };
        // FIXME(eddyb) consider returning `spirt::Module` from this function instead.
        output = {
            let _timer = sess.timer("parse-spv_words-from-spirt");
            let mut loader = Loader::new();
            rspirv::binary::parse_words(&spv_words, &mut loader).unwrap();
            loader.module()
        };
    }

    // Ensure that no references remain, to our custom "extended instruction set".
    for inst in &output.ext_inst_imports {
        assert_eq!(inst.class.opcode, Op::ExtInstImport);
        let ext_inst_set = inst.operands[0].unwrap_literal_string();
        if ext_inst_set.starts_with(custom_insts::CUSTOM_EXT_INST_SET_PREFIX) {
            let expected = &custom_insts::CUSTOM_EXT_INST_SET[..];
            if ext_inst_set == expected {
                return Err(sess.dcx().err(format!(
                    "`OpExtInstImport {ext_inst_set:?}` should not have been \
                     left around after SPIR-T passes"
                )));
            } else {
                // The `\` continuation keeps the message on one line (previously a raw
                // newline and the source indentation leaked into the diagnostic).
                return Err(sess.dcx().err(format!(
                    "unsupported `OpExtInstImport {ext_inst_set:?}` \
                     (expected {expected:?} name - version mismatch?)"
                )));
            }
        }
    }

    // FIXME(eddyb) rewrite these passes to SPIR-T ones, so we don't have to
    // parse the output of `spirt::spv::lift` back into `rspirv` - also, for
    // multi-module, it's much simpler with SPIR-T, just replace `module.exports`
    // with a single-entry map, run `spirt::spv::lift` (or even `spirt::print`)
    // on `module`, then put back the full original `module.exports` map.
    {
        let _timer = sess.timer("peephole_opts");
        let types = peephole_opts::collect_types(&output);
        for func in &mut output.functions {
            peephole_opts::composite_construct(&types, func);
            peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
            peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
        }
    }

    {
        let _timer = sess.timer("link_gather_all_interface_vars_from_uses");
        entry_interface::gather_all_interface_vars_from_uses(&mut output);
    }

    if opts.spirv_metadata == SpirvMetadata::NameVariables {
        let _timer = sess.timer("link_name_variables");
        simple_passes::name_variables_pass(&mut output);
    }

    {
        let _timer = sess.timer("link_sort_globals");
        simple_passes::sort_globals(&mut output);
    }

    let mut output = if opts.module_output_type == ModuleOutputType::Multiple {
        let mut file_stem_to_entry_name_and_module = BTreeMap::new();
        for (i, entry) in output.entry_points.iter().enumerate() {
            // Each split module is a full copy whose only entry point is `entry`;
            // the per-module DCE below removes what that entry point doesn't use.
            let mut module = output.clone();
            module.entry_points.clear();
            module.entry_points.push(entry.clone());
            let entry_name = entry.operands[2].unwrap_literal_string().to_string();
            let mut file_stem = OsString::from(
                sanitize_filename::sanitize_with_options(&entry_name, sanitize_filename::Options {
                    replacement: "-",
                    ..Default::default()
                })
                .replace("--", "-"),
            );
            // It's always possible to find an unambiguous `file_stem`, but it
            // may take two tries (or more, in bizzare/adversarial cases):
            // the first collision appends `.{i}`, later ones append `.next`.
            let mut disambiguator = Some(i);
            loop {
                use std::collections::btree_map::Entry;
                match file_stem_to_entry_name_and_module.entry(file_stem) {
                    Entry::Vacant(entry) => {
                        entry.insert((entry_name, module));
                        break;
                    }
                    // `clone_from` can't be used: `file_stem` was moved into `entry()` above.
                    #[allow(clippy::assigning_clones)]
                    Entry::Occupied(entry) => {
                        // FIXME(eddyb) there's no way to access the owned key
                        // passed to `BTreeMap::entry` from `OccupiedEntry`.
                        file_stem = entry.key().clone();
                        file_stem.push(".");
                        match disambiguator.take() {
                            Some(d) => file_stem.push(d.to_string()),
                            None => file_stem.push("next"),
                        }
                    }
                }
            }
        }
        LinkResult::MultipleModules {
            file_stem_to_entry_name_and_module,
        }
    } else {
        LinkResult::SingleModule(Box::new(output))
    };

    let output_module_iter = match &mut output {
        LinkResult::SingleModule(m) => Either::Left(std::iter::once((None, &mut **m))),
        LinkResult::MultipleModules {
            file_stem_to_entry_name_and_module,
        } => Either::Right(
            file_stem_to_entry_name_and_module
                .iter_mut()
                .map(|(file_stem, (_, m))| (Some(file_stem), m)),
        ),
    };
    for (file_stem, output) in output_module_iter {
        if let Some(dir) = &opts.dump_post_split {
            let mut file_name = disambiguated_crate_name_for_dumps.to_os_string();
            if let Some(file_stem) = file_stem {
                file_name.push(".");
                file_name.push(file_stem);
            }

            dump_spv_and_spirt(output, dir.join(file_name));
        }
        // Run DCE again, even if module_output_type == ModuleOutputType::Multiple - the first DCE ran before
        // structurization and mem2reg (for perf reasons), and mem2reg may remove references to
        // invalid types, so we need to DCE again.
        if opts.dce {
            let _timer = sess.timer("link_dce_2");
            dce::dce(output);
        }

        {
            let _timer = sess.timer("link_remove_duplicate_debuginfo");
            duplicates::remove_duplicate_debuginfo(output);
        }

        if opts.compact_ids {
            let _timer = sess.timer("link_compact_ids");
            // compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43
            output.header.as_mut().unwrap().bound = simple_passes::compact_ids(output);
        }

        // FIXME(eddyb) convert these into actual `OpLine`s with a SPIR-T pass,
        // but that'd require keeping the modules in SPIR-T form (once lowered),
        // and never loading them back into `rspirv` once lifted back to SPIR-V.
        SrcLocDecoration::remove_all(output);

        // FIXME(eddyb) might make more sense to rewrite these away on SPIR-T.
        ZombieDecoration::remove_all(output);
    }

    Ok(output)
}

/// Collects the inputs `mem2reg` needs from the module's global types and values, then
/// runs `block_ordering_pass`, `mem2reg` and `destructure_composites` on every function.
///
/// The collected inputs are:
/// - a map from each `OpTypePointer` to its pointee type, and
/// - the values of `OpConstant`s whose type is the unsigned 32-bit `OpTypeInt`.
fn block_ordering_and_mem2reg(output: &mut Module) {
    let mut pointer_to_pointee = FxHashMap::default();
    let mut constants = FxHashMap::default();
    let mut u32_type = None;
    for inst in &output.types_global_values {
        match inst.class.opcode {
            Op::TypePointer => {
                pointer_to_pointee
                    .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref());
            }
            // Width 32, signedness 0 (unsigned).
            Op::TypeInt
                if inst.operands[0].unwrap_literal_bit32() == 32
                    && inst.operands[1].unwrap_literal_bit32() == 0 =>
            {
                // Expects at most one such type (types were deduplicated earlier by
                // `duplicates::remove_duplicate_types`).
                assert!(u32_type.is_none());
                u32_type = Some(inst.result_id.unwrap());
            }
            Op::Constant if u32_type.is_some() && inst.result_type == u32_type => {
                let value = inst.operands[0].unwrap_literal_bit32();
                constants.insert(inst.result_id.unwrap(), value);
            }
            _ => {}
        }
    }
    for func in &mut output.functions {
        simple_passes::block_ordering_pass(func);
        // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
        mem2reg::mem2reg(
            output.header.as_mut().unwrap(),
            &mut output.types_global_values,
            &pointer_to_pointee,
            &constants,
            func,
        );
        destructure_composites::destructure_composites(func);
    }
}

/// Returns `path` with `.{extension}` appended to its file name.
///
/// Unlike `Path::with_extension`, this never replaces an existing `.`-suffix, so
/// dotted stems such as `crate.entry_name` are preserved.
fn path_with_appended_extension(path: &std::path::Path, extension: &str) -> PathBuf {
    let mut file_path = path.as_os_str().to_os_string();
    file_path.push(".");
    file_path.push(extension);
    PathBuf::from(file_path)
}
/// State kept alive while `link` runs the SPIR-T passes. Its `Drop` impl writes the
/// collected per-pass SPIR-T dumps. On SPIR-T bugs (or when dropped while panicking),
/// it also ensures a dump exists and emits a note pointing at it.
struct SpirtDumpGuard<'a> {
sess: &'a Session,
linker_options: &'a Options,
outputs: &'a OutputFilenames,
disambiguated_crate_name_for_dumps: &'a OsStr,
/// The SPIR-T module being transformed. On SPIR-T bugs it is cloned into the dump if no
/// per-pass snapshots were collected and `--dump-spirt-passes` wasn't set.
module: &'a mut spirt::Module,
/// `(pass name, module snapshot after that pass)`. `link` only pushes here when
/// `dump_spirt_passes` is set.
per_pass_module_for_dumping: Vec<(&'static str, spirt::Module)>,
/// Set when diagnostics reported SPIR-T bugs. `Drop` also forces it on when panicking.
any_spirt_bugs: bool,
}
impl Drop for SpirtDumpGuard<'_> {
    /// Writes the pretty-printed SPIR-T dumps, if requested or if SPIR-T bugs occurred.
    /// On SPIR-T bugs, it also emits a note pointing at the dump.
    ///
    /// This can run while the thread is already unwinding, where a second panic would abort
    /// the process. Dump-write failures are therefore reported as warnings instead of panicking.
    fn drop(&mut self) {
        // Unwinding out of the SPIR-T pipeline is treated as a SPIR-T bug.
        self.any_spirt_bugs |= std::thread::panicking();

        // `.spirt` is appended (not set via `Path::with_extension`), so a crate name
        // containing `.` is not truncated.
        let mut dump_spirt_file_path =
            self.linker_options
                .dump_spirt_passes
                .as_ref()
                .map(|dump_dir| {
                    let mut file_name = self.disambiguated_crate_name_for_dumps.to_os_string();
                    file_name.push(".spirt");
                    dump_dir.join(file_name)
                });

        // FIXME(eddyb) this won't allow seeing the individual passes, but it's
        // better than nothing (theoretically the whole "SPIR-T pipeline" could
        // be put in a loop so that everything is redone with each new function).
        // Without `--dump-spirt-passes`, SPIR-T bugs still get (at least) the current
        // module dumped to a temporary path.
        if self.any_spirt_bugs && dump_spirt_file_path.is_none() {
            if self.per_pass_module_for_dumping.is_empty() {
                self.per_pass_module_for_dumping
                    .push(("", self.module.clone()));
            }
            dump_spirt_file_path = Some(self.outputs.temp_path_ext("spirt", None));
        }

        if let Some(dump_spirt_file_path) = &dump_spirt_file_path {
            for (_, module) in &mut self.per_pass_module_for_dumping {
                self.linker_options.spirt_cleanup_for_dumping(module);
            }

            // FIXME(eddyb) catch panics during pretty-printing itself, and
            // tell the user to use `--dump-spirt-passes` (and resolve the
            // second FIXME below so it does anything) - also, that may need
            // quieting the panic handler, likely controlled by a `thread_local!`
            // (while the panic handler is global), but that would be useful
            // for collecting a panic message (assuming any of this is worth it).
            // FIXME(eddyb) when the dump is requested, but pretty-printing
            // failed, print the module with only the panicking pass skipped.
            let plan = spirt::print::Plan::for_versions(
                self.module.cx_ref(),
                self.per_pass_module_for_dumping
                    .iter()
                    .map(|(pass, module)| (format!("after {pass}"), module)),
            );
            let pretty = plan.pretty_print();

            let html_path = dump_spirt_file_path.with_extension("spirt.html");
            let text_result = std::fs::write(dump_spirt_file_path, pretty.to_string());
            let html_result = std::fs::write(
                &html_path,
                pretty
                    .render_to_html()
                    .with_dark_mode_support()
                    .to_html_doc(),
            );

            let mut all_written = true;
            for (path, result) in [(dump_spirt_file_path, text_result), (&html_path, html_result)]
            {
                if let Err(err) = result {
                    all_written = false;
                    self.sess
                        .dcx()
                        .struct_warn(format!(
                            "failed to write pretty-printed SPIR-T to {}: {err}",
                            path.display()
                        ))
                        .emit();
                }
            }

            if self.any_spirt_bugs {
                let mut note = self.sess.dcx().struct_note("SPIR-T bugs were encountered");
                // Only point at the dump if it actually exists.
                if all_written {
                    note.help(format!(
                        "pretty-printed SPIR-T was saved to {}",
                        html_path.display()
                    ));
                }
                if self.linker_options.dump_spirt_passes.is_none() {
                    note.help("re-run with `RUSTGPU_CODEGEN_ARGS=\"--dump-spirt-passes=$PWD\"` for more details");
                }
                note.note("pretty-printed SPIR-T is preferred when reporting Rust-GPU issues");
                note.emit();
            }
        }
    }
}