rustc_codegen_spirv/linker/spirt_passes/
mod.rs

1//! SPIR-T pass infrastructure and supporting utilities.
2
3pub(crate) mod controlflow;
4pub(crate) mod debuginfo;
5pub(crate) mod diagnostics;
6mod fuse_selects;
7mod reduce;
8pub(crate) mod validate;
9
10use lazy_static::lazy_static;
11use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
12use spirt::func_at::FuncAt;
13use spirt::transform::InnerInPlaceTransform;
14use spirt::visit::{InnerVisit, Visitor};
15use spirt::{
16    AttrSet, Const, Context, ControlNode, ControlNodeKind, ControlRegion, DataInstDef,
17    DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityOrientedDenseMap, Func,
18    FuncDefBody, GlobalVar, Module, Type, Value, spv,
19};
20use std::collections::VecDeque;
21use std::iter;
22
// HACK(eddyb) `spv::spec::Spec` with extra `WellKnown`s (that should be upstreamed).
//
// The input is a comma-separated list of `group: Type = [Entry, ...]` items,
// and the expansion defines:
// - `SpvWellKnownWithExtras`: one `pub` field (of type `Type`) per `Entry`,
//   plus a `Deref` impl exposing the base `spv::spec::WellKnown`
// - `SpvSpecWithExtras`: the base `spv::spec::Spec` (exposed through `Deref`)
//   paired with a `well_known: SpvWellKnownWithExtras` field
// - `SpvSpecWithExtras::get()`: access to a lazily-initialized `'static` instance
//
// The value of each `Entry` is produced by the lookup closure named after its
// group, applied to the stringified `Entry` identifier. Only the `opcode`,
// `operand_kind` and `decoration` closures are defined in the expansion, so
// those are the group names an invocation has to use.
macro_rules! def_spv_spec_with_extra_well_known {
    ($($group:ident: $ty:ty = [$($entry:ident),+ $(,)?]),+ $(,)?) => {
        // Wrapper around the base spec: its own `well_known` field is the one
        // found by direct field access, everything else goes through `Deref`.
        struct SpvSpecWithExtras {
            __base_spec: &'static spv::spec::Spec,

            well_known: SpvWellKnownWithExtras,
        }

        // Field names are the entries' identifiers verbatim (e.g. `OpBitcast`),
        // hence the `non_snake_case` allowance.
        #[allow(non_snake_case)]
        pub struct SpvWellKnownWithExtras {
            __base_well_known: &'static spv::spec::WellKnown,

            $($(pub $entry: $ty,)+)+
        }

        impl std::ops::Deref for SpvSpecWithExtras {
            type Target = spv::spec::Spec;
            fn deref(&self) -> &Self::Target {
                self.__base_spec
            }
        }

        impl std::ops::Deref for SpvWellKnownWithExtras {
            type Target = spv::spec::WellKnown;
            fn deref(&self) -> &Self::Target {
                self.__base_well_known
            }
        }

        impl SpvSpecWithExtras {
            /// Returns the shared instance, which is built by the `lazy_static!`
            /// below (any name that fails to be looked up results in a panic
            /// from the `unwrap`/`unreachable!` calls in that initializer).
            #[inline(always)]
            #[must_use]
            pub fn get() -> &'static SpvSpecWithExtras {
                lazy_static! {
                    static ref SPEC: SpvSpecWithExtras = {
                        // One field per group, each holding that group's
                        // lookup closure (the type parameters reuse the group
                        // names, hence the `non_camel_case_types` allowance).
                        #[allow(non_camel_case_types)]
                        struct PerWellKnownGroup<$($group),+> {
                            $($group: $group),+
                        }

                        let spv_spec = spv::spec::Spec::get();
                        let wk = &spv_spec.well_known;

                        // Decorations are looked up among the variants of the
                        // `Decoration` operand kind, which must be a value enum.
                        let decorations = match wk.Decoration.def() {
                            spv::spec::OperandKindDef::ValueEnum { variants } => variants,
                            _ => unreachable!(),
                        };

                        let lookup_fns = PerWellKnownGroup {
                            opcode: |name| spv_spec.instructions.lookup(name).unwrap(),
                            operand_kind: |name| spv_spec.operand_kinds.lookup(name).unwrap(),
                            decoration: |name| decorations.lookup(name).unwrap().into(),
                        };

                        SpvSpecWithExtras {
                            __base_spec: spv_spec,

                            well_known: SpvWellKnownWithExtras {
                                __base_well_known: &spv_spec.well_known,

                                // Resolve every entry by name, using the
                                // lookup closure of the group it was listed in.
                                $($($entry: (lookup_fns.$group)(stringify!($entry)),)+)+
                            },
                        }
                    };
                }
                &SPEC
            }
        }
    };
}
// Defines `SpvSpecWithExtras`/`SpvWellKnownWithExtras`, with the extra entries
// listed below, each resolved by name (grouped by the kind of lookup needed).
def_spv_spec_with_extra_well_known! {
    // Instructions, looked up by opcode name.
    opcode: spv::spec::Opcode = [
        OpTypeVoid,

        OpConstantComposite,

        OpBitcast,
        OpCompositeInsert,
        OpCompositeExtract,
    ],
    // Operand kinds, looked up by name.
    operand_kind: spv::spec::OperandKind = [
        Capability,
        ExecutionModel,
        ImageFormat,
    ],
    // Variants of the `Decoration` value enum, stored as `u32`.
    decoration: u32 = [
        UserTypeGOOGLE,
    ],
}
113
114/// Run intra-function passes on all `Func` definitions in the `Module`.
115//
116// FIXME(eddyb) introduce a proper "pass manager".
117// FIXME(eddyb) why does this focus on functions, it could just be module passes??
118pub(super) fn run_func_passes<P>(
119    module: &mut Module,
120    passes: &[impl AsRef<str>],
121    // FIXME(eddyb) this is a very poor approximation of a "profiler" abstraction.
122    mut before_pass: impl FnMut(&'static str, &Module) -> P,
123    mut after_pass: impl FnMut(Option<&Module>, P),
124) {
125    let cx = &module.cx();
126
127    // FIXME(eddyb) reuse this collection work in some kind of "pass manager".
128    let all_funcs = {
129        let mut collector = ReachableUseCollector {
130            cx,
131            module,
132
133            seen_types: FxIndexSet::default(),
134            seen_consts: FxIndexSet::default(),
135            seen_data_inst_forms: FxIndexSet::default(),
136            seen_global_vars: FxIndexSet::default(),
137            seen_funcs: FxIndexSet::default(),
138        };
139        for (export_key, &exportee) in &module.exports {
140            export_key.inner_visit_with(&mut collector);
141            exportee.inner_visit_with(&mut collector);
142        }
143        collector.seen_funcs
144    };
145
146    for name in passes {
147        let name = name.as_ref();
148
149        // HACK(eddyb) not really a function pass.
150        if name == "qptr" {
151            let layout_config = &spirt::qptr::LayoutConfig {
152                abstract_bool_size_align: (1, 1),
153                logical_ptr_size_align: (4, 4),
154                ..spirt::qptr::LayoutConfig::VULKAN_SCALAR_LAYOUT
155            };
156
157            let profiler = before_pass("qptr::lower_from_spv_ptrs", module);
158            spirt::passes::qptr::lower_from_spv_ptrs(module, layout_config);
159            after_pass(Some(module), profiler);
160
161            let profiler = before_pass("qptr::analyze_uses", module);
162            spirt::passes::qptr::analyze_uses(module, layout_config);
163            after_pass(Some(module), profiler);
164
165            let profiler = before_pass("qptr::lift_to_spv_ptrs", module);
166            spirt::passes::qptr::lift_to_spv_ptrs(module, layout_config);
167            after_pass(Some(module), profiler);
168
169            continue;
170        }
171
172        let (full_name, pass_fn): (_, fn(_, &mut _)) = match name {
173            "reduce" => ("spirt_passes::reduce", reduce::reduce_in_func),
174            "fuse_selects" => (
175                "spirt_passes::fuse_selects",
176                fuse_selects::fuse_selects_in_func,
177            ),
178            _ => panic!("unknown `--spirt-passes={name}`"),
179        };
180
181        let profiler = before_pass(full_name, module);
182        for &func in &all_funcs {
183            if let DeclDef::Present(func_def_body) = &mut module.funcs[func].def {
184                pass_fn(cx, func_def_body);
185
186                // FIXME(eddyb) avoid doing this except where changes occurred.
187                remove_unused_values_in_func(cx, func_def_body);
188            }
189        }
190        after_pass(Some(module), profiler);
191    }
192}
193
194// FIXME(eddyb) this is just copy-pasted from `spirt` and should be reusable.
195struct ReachableUseCollector<'a> {
196    cx: &'a Context,
197    module: &'a Module,
198
199    // FIXME(eddyb) build some automation to avoid ever repeating these.
200    seen_types: FxIndexSet<Type>,
201    seen_consts: FxIndexSet<Const>,
202    seen_data_inst_forms: FxIndexSet<DataInstForm>,
203    seen_global_vars: FxIndexSet<GlobalVar>,
204    seen_funcs: FxIndexSet<Func>,
205}
206
207impl Visitor<'_> for ReachableUseCollector<'_> {
208    // FIXME(eddyb) build some automation to avoid ever repeating these.
209    fn visit_attr_set_use(&mut self, _attrs: AttrSet) {}
210    fn visit_type_use(&mut self, ty: Type) {
211        if self.seen_types.insert(ty) {
212            self.visit_type_def(&self.cx[ty]);
213        }
214    }
215    fn visit_const_use(&mut self, ct: Const) {
216        if self.seen_consts.insert(ct) {
217            self.visit_const_def(&self.cx[ct]);
218        }
219    }
220    fn visit_data_inst_form_use(&mut self, data_inst_form: DataInstForm) {
221        if self.seen_data_inst_forms.insert(data_inst_form) {
222            self.visit_data_inst_form_def(&self.cx[data_inst_form]);
223        }
224    }
225
226    fn visit_global_var_use(&mut self, gv: GlobalVar) {
227        if self.seen_global_vars.insert(gv) {
228            self.visit_global_var_decl(&self.module.global_vars[gv]);
229        }
230    }
231    fn visit_func_use(&mut self, func: Func) {
232        if self.seen_funcs.insert(func) {
233            self.visit_func_decl(&self.module.funcs[func]);
234        }
235    }
236}
237
238// FIXME(eddyb) maybe this should be provided by `spirt::visit`.
239struct VisitAllControlRegionsAndNodes<S, VCR, VCN> {
240    state: S,
241    visit_control_region: VCR,
242    visit_control_node: VCN,
243}
244const _: () = {
245    use spirt::{func_at::*, visit::*, *};
246
247    impl<
248        'a,
249        S,
250        VCR: FnMut(&mut S, FuncAt<'a, ControlRegion>),
251        VCN: FnMut(&mut S, FuncAt<'a, ControlNode>),
252    > Visitor<'a> for VisitAllControlRegionsAndNodes<S, VCR, VCN>
253    {
254        // FIXME(eddyb) this is excessive, maybe different kinds of
255        // visitors should exist for module-level and func-level?
256        fn visit_attr_set_use(&mut self, _: AttrSet) {}
257        fn visit_type_use(&mut self, _: Type) {}
258        fn visit_const_use(&mut self, _: Const) {}
259        fn visit_data_inst_form_use(&mut self, _: DataInstForm) {}
260        fn visit_global_var_use(&mut self, _: GlobalVar) {}
261        fn visit_func_use(&mut self, _: Func) {}
262
263        fn visit_control_region_def(&mut self, func_at_control_region: FuncAt<'a, ControlRegion>) {
264            (self.visit_control_region)(&mut self.state, func_at_control_region);
265            func_at_control_region.inner_visit_with(self);
266        }
267        fn visit_control_node_def(&mut self, func_at_control_node: FuncAt<'a, ControlNode>) {
268            (self.visit_control_node)(&mut self.state, func_at_control_node);
269            func_at_control_node.inner_visit_with(self);
270        }
271    }
272};
273
274// FIXME(eddyb) maybe this should be provided by `spirt::transform`.
275struct ReplaceValueWith<F>(F);
276const _: () = {
277    use spirt::{transform::*, *};
278
279    impl<F: FnMut(Value) -> Option<Value>> Transformer for ReplaceValueWith<F> {
280        fn transform_value_use(&mut self, v: &Value) -> Transformed<Value> {
281            self.0(*v).map_or(Transformed::Unchanged, Transformed::Changed)
282        }
283    }
284};
285
286/// Clean up after a pass by removing unused (pure) `Value` definitions from
287/// a function body (both `DataInst`s and `ControlRegion` inputs/outputs).
288//
289// FIXME(eddyb) should this be a dedicated pass?
290fn remove_unused_values_in_func(cx: &Context, func_def_body: &mut FuncDefBody) {
291    // Avoid having to support unstructured control-flow.
292    if func_def_body.unstructured_cfg.is_some() {
293        return;
294    }
295
296    let wk = &SpvSpecWithExtras::get().well_known;
297
298    struct Propagator {
299        func_body_region: ControlRegion,
300
301        // FIXME(eddyb) maybe this kind of "parent map" should be provided by SPIR-T?
302        loop_body_to_loop: EntityOrientedDenseMap<ControlRegion, ControlNode>,
303
304        // FIXME(eddyb) entity-keyed dense sets might be better for performance,
305        // but would require separate sets/maps for separate `Value` cases.
306        used: FxHashSet<Value>,
307
308        queue: VecDeque<Value>,
309    }
310    impl Propagator {
311        fn mark_used(&mut self, v: Value) {
312            if let Value::Const(_) = v {
313                return;
314            }
315            if let Value::ControlRegionInput {
316                region,
317                input_idx: _,
318            } = v
319                && region == self.func_body_region
320            {
321                return;
322            }
323            if self.used.insert(v) {
324                self.queue.push_back(v);
325            }
326        }
327        fn propagate_used(&mut self, func: FuncAt<'_, ()>) {
328            while let Some(v) = self.queue.pop_front() {
329                match v {
330                    Value::Const(_) => unreachable!(),
331                    Value::ControlRegionInput { region, input_idx } => {
332                        let loop_node = self.loop_body_to_loop[region];
333                        let initial_inputs = match &func.at(loop_node).def().kind {
334                            ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs,
335                            // NOTE(eddyb) only `Loop`s' bodies can have inputs right now.
336                            _ => unreachable!(),
337                        };
338                        self.mark_used(initial_inputs[input_idx as usize]);
339                        self.mark_used(func.at(region).def().outputs[input_idx as usize]);
340                    }
341                    Value::ControlNodeOutput {
342                        control_node,
343                        output_idx,
344                    } => {
345                        let cases = match &func.at(control_node).def().kind {
346                            ControlNodeKind::Select { cases, .. } => cases,
347                            // NOTE(eddyb) only `Select`s can have outputs right now.
348                            _ => unreachable!(),
349                        };
350                        for &case in cases {
351                            self.mark_used(func.at(case).def().outputs[output_idx as usize]);
352                        }
353                    }
354                    Value::DataInstOutput(inst) => {
355                        for &input in &func.at(inst).def().inputs {
356                            self.mark_used(input);
357                        }
358                    }
359                }
360            }
361        }
362    }
363
364    // HACK(eddyb) it's simpler to first ensure `loop_body_to_loop` is computed,
365    // just to allow the later unordered propagation to always work.
366    let propagator = {
367        let mut visitor = VisitAllControlRegionsAndNodes {
368            state: Propagator {
369                func_body_region: func_def_body.body,
370                loop_body_to_loop: Default::default(),
371                used: Default::default(),
372                queue: Default::default(),
373            },
374            visit_control_region: |_: &mut _, _| {},
375            visit_control_node:
376                |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
377                    if let ControlNodeKind::Loop { body, .. } = func_at_control_node.def().kind {
378                        propagator
379                            .loop_body_to_loop
380                            .insert(body, func_at_control_node.position);
381                    }
382                },
383        };
384        func_def_body.inner_visit_with(&mut visitor);
385        visitor.state
386    };
387
388    // HACK(eddyb) this kind of random-access is easier than using `spirt::transform`.
389    let mut all_control_nodes = vec![];
390
391    let used_values = {
392        let mut visitor = VisitAllControlRegionsAndNodes {
393            state: propagator,
394            visit_control_region: |_: &mut _, _| {},
395            visit_control_node:
396                |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
397                    all_control_nodes.push(func_at_control_node.position);
398
399                    let mut mark_used_and_propagate = |v| {
400                        propagator.mark_used(v);
401                        propagator.propagate_used(func_at_control_node.at(()));
402                    };
403                    match &func_at_control_node.def().kind {
404                        &ControlNodeKind::Block { insts } => {
405                            for func_at_inst in func_at_control_node.at(insts) {
406                                // Ignore pure instructions (i.e. they're only used
407                                // if their output value is used, from somewhere else).
408                                if let DataInstKind::SpvInst(spv_inst) =
409                                    &cx[func_at_inst.def().form].kind
410                                {
411                                    // HACK(eddyb) small selection relevant for now,
412                                    // but should be extended using e.g. a bitset.
413                                    if [wk.OpNop, wk.OpCompositeInsert].contains(&spv_inst.opcode) {
414                                        continue;
415                                    }
416                                }
417                                mark_used_and_propagate(Value::DataInstOutput(
418                                    func_at_inst.position,
419                                ));
420                            }
421                        }
422
423                        &ControlNodeKind::Select { scrutinee: v, .. }
424                        | &ControlNodeKind::Loop {
425                            repeat_condition: v,
426                            ..
427                        } => mark_used_and_propagate(v),
428
429                        ControlNodeKind::ExitInvocation {
430                            kind: spirt::cfg::ExitInvocationKind::SpvInst(_),
431                            inputs,
432                        } => {
433                            for &v in inputs {
434                                mark_used_and_propagate(v);
435                            }
436                        }
437                    }
438                },
439        };
440        func_def_body.inner_visit_with(&mut visitor);
441
442        let mut propagator = visitor.state;
443        for &v in &func_def_body.at_body().def().outputs {
444            propagator.mark_used(v);
445            propagator.propagate_used(func_def_body.at(()));
446        }
447
448        assert!(propagator.queue.is_empty());
449        propagator.used
450    };
451
452    // FIXME(eddyb) entity-keyed dense maps might be better for performance,
453    // but would require separate maps for separate `Value` cases.
454    let mut value_replacements = FxHashMap::default();
455
456    // Remove anything that didn't end up marked as used (directly or indirectly).
457    for control_node in all_control_nodes {
458        let control_node_def = func_def_body.at(control_node).def();
459        match &control_node_def.kind {
460            &ControlNodeKind::Block { insts } => {
461                let mut all_nops = true;
462                let mut func_at_inst_iter = func_def_body.at_mut(insts).into_iter();
463                while let Some(mut func_at_inst) = func_at_inst_iter.next() {
464                    if let DataInstKind::SpvInst(spv_inst) =
465                        &cx[func_at_inst.reborrow().def().form].kind
466                        && spv_inst.opcode == wk.OpNop
467                    {
468                        continue;
469                    }
470                    if !used_values.contains(&Value::DataInstOutput(func_at_inst.position)) {
471                        // Replace the removed `DataInstDef` itself with `OpNop`,
472                        // removing the ability to use its "name" as a value.
473                        //
474                        // FIXME(eddyb) cache the interned `OpNop`.
475                        *func_at_inst.def() = DataInstDef {
476                            attrs: Default::default(),
477                            form: cx.intern(DataInstFormDef {
478                                kind: DataInstKind::SpvInst(wk.OpNop.into()),
479                                output_type: None,
480                            }),
481                            inputs: iter::empty().collect(),
482                        };
483                        continue;
484                    }
485                    all_nops = false;
486                }
487                // HACK(eddyb) because we can't remove list elements yet, we
488                // instead replace blocks of `OpNop`s with empty ones.
489                if all_nops {
490                    func_def_body.at_mut(control_node).def().kind = ControlNodeKind::Block {
491                        insts: Default::default(),
492                    };
493                }
494            }
495
496            ControlNodeKind::Select { cases, .. } => {
497                // FIXME(eddyb) remove this cloning.
498                let cases = cases.clone();
499
500                let mut new_idx = 0;
501                for original_idx in 0..control_node_def.outputs.len() {
502                    let original_output = Value::ControlNodeOutput {
503                        control_node,
504                        output_idx: original_idx as u32,
505                    };
506
507                    if !used_values.contains(&original_output) {
508                        // Remove the output definition and corresponding value from all cases.
509                        func_def_body
510                            .at_mut(control_node)
511                            .def()
512                            .outputs
513                            .remove(new_idx);
514                        for &case in &cases {
515                            func_def_body.at_mut(case).def().outputs.remove(new_idx);
516                        }
517                        continue;
518                    }
519
520                    // Record remappings for any still-used outputs that got "shifted over".
521                    if original_idx != new_idx {
522                        let new_output = Value::ControlNodeOutput {
523                            control_node,
524                            output_idx: new_idx as u32,
525                        };
526                        value_replacements.insert(original_output, new_output);
527                    }
528                    new_idx += 1;
529                }
530            }
531
532            ControlNodeKind::Loop {
533                body,
534                initial_inputs,
535                ..
536            } => {
537                let body = *body;
538
539                let mut new_idx = 0;
540                for original_idx in 0..initial_inputs.len() {
541                    let original_input = Value::ControlRegionInput {
542                        region: body,
543                        input_idx: original_idx as u32,
544                    };
545
546                    if !used_values.contains(&original_input) {
547                        // Remove the input definition and corresponding values.
548                        match &mut func_def_body.at_mut(control_node).def().kind {
549                            ControlNodeKind::Loop { initial_inputs, .. } => {
550                                initial_inputs.remove(new_idx);
551                            }
552                            _ => unreachable!(),
553                        }
554                        let body_def = func_def_body.at_mut(body).def();
555                        body_def.inputs.remove(new_idx);
556                        body_def.outputs.remove(new_idx);
557                        continue;
558                    }
559
560                    // Record remappings for any still-used inputs that got "shifted over".
561                    if original_idx != new_idx {
562                        let new_input = Value::ControlRegionInput {
563                            region: body,
564                            input_idx: new_idx as u32,
565                        };
566                        value_replacements.insert(original_input, new_input);
567                    }
568                    new_idx += 1;
569                }
570            }
571
572            ControlNodeKind::ExitInvocation { .. } => {}
573        }
574    }
575
576    if !value_replacements.is_empty() {
577        func_def_body.inner_in_place_transform_with(&mut ReplaceValueWith(|v| match v {
578            Value::Const(_) => None,
579            _ => value_replacements.get(&v).copied(),
580        }));
581    }
582}