rustc_codegen_spirv/linker/spirt_passes/
mod.rs

1//! SPIR-T pass infrastructure and supporting utilities.
2
3pub(crate) mod controlflow;
4pub(crate) mod debuginfo;
5pub(crate) mod diagnostics;
6pub(crate) mod explicit_layout;
7mod fuse_selects;
8mod reduce;
9pub(crate) mod validate;
10
11use lazy_static::lazy_static;
12use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
13use spirt::func_at::FuncAt;
14use spirt::transform::InnerInPlaceTransform;
15use spirt::visit::{InnerVisit, Visitor};
16use spirt::{
17    AttrSet, Const, Context, ControlNode, ControlNodeKind, ControlRegion, DataInstDef,
18    DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityOrientedDenseMap, Func,
19    FuncDefBody, GlobalVar, Module, Type, Value, spv,
20};
21use std::collections::VecDeque;
22use std::iter;
23
24// HACK(eddyb) `spv::spec::Spec` with extra `WellKnown`s (that should be upstreamed).
25macro_rules! def_spv_spec_with_extra_well_known {
26    ($($group:ident: $ty:ty = [$($entry:ident),+ $(,)?]),+ $(,)?) => {
27        struct SpvSpecWithExtras {
28            __base_spec: &'static spv::spec::Spec,
29
30            well_known: SpvWellKnownWithExtras,
31        }
32
33        #[allow(non_snake_case)]
34        pub struct SpvWellKnownWithExtras {
35            __base_well_known: &'static spv::spec::WellKnown,
36
37            $($(pub $entry: $ty,)+)+
38        }
39
40        impl std::ops::Deref for SpvSpecWithExtras {
41            type Target = spv::spec::Spec;
42            fn deref(&self) -> &Self::Target {
43                self.__base_spec
44            }
45        }
46
47        impl std::ops::Deref for SpvWellKnownWithExtras {
48            type Target = spv::spec::WellKnown;
49            fn deref(&self) -> &Self::Target {
50                self.__base_well_known
51            }
52        }
53
54        impl SpvSpecWithExtras {
55            #[inline(always)]
56            #[must_use]
57            pub fn get() -> &'static SpvSpecWithExtras {
58                lazy_static! {
59                    static ref SPEC: SpvSpecWithExtras = {
60                        #[allow(non_camel_case_types)]
61                        struct PerWellKnownGroup<$($group),+> {
62                            $($group: $group),+
63                        }
64
65                        let spv_spec = spv::spec::Spec::get();
66                        let wk = &spv_spec.well_known;
67
68                        let [decorations, storage_classes] = [wk.Decoration, wk.StorageClass].map(|kind| match kind.def() {
69                            spv::spec::OperandKindDef::ValueEnum { variants } => variants,
70                            _ => unreachable!(),
71                        });
72
73                        let lookup_fns = PerWellKnownGroup {
74                            opcode: |name| spv_spec.instructions.lookup(name).unwrap(),
75                            operand_kind: |name| spv_spec.operand_kinds.lookup(name).unwrap(),
76                            decoration: |name| decorations.lookup(name).unwrap().into(),
77                            storage_class: |name| storage_classes.lookup(name).unwrap().into(),
78                        };
79
80                        SpvSpecWithExtras {
81                            __base_spec: spv_spec,
82
83                            well_known: SpvWellKnownWithExtras {
84                                __base_well_known: &spv_spec.well_known,
85
86                                $($($entry: (lookup_fns.$group)(stringify!($entry)),)+)+
87                            },
88                        }
89                    };
90                }
91                &SPEC
92            }
93        }
94    };
95}
96def_spv_spec_with_extra_well_known! {
97    opcode: spv::spec::Opcode = [
98        OpTypeVoid,
99
100        OpConstantComposite,
101
102        OpBitcast,
103        OpCompositeInsert,
104        OpCompositeExtract,
105        OpCompositeConstruct,
106
107        OpCopyMemory,
108    ],
109    operand_kind: spv::spec::OperandKind = [
110        Capability,
111        ExecutionModel,
112        ImageFormat,
113        MemoryAccess,
114    ],
115    decoration: u32 = [
116        UserTypeGOOGLE,
117        MatrixStride,
118    ],
119    storage_class: u32 = [
120        PushConstant,
121        Uniform,
122        StorageBuffer,
123        PhysicalStorageBuffer,
124    ],
125}
126
127/// Run intra-function passes on all `Func` definitions in the `Module`.
128//
129// FIXME(eddyb) introduce a proper "pass manager".
130// FIXME(eddyb) why does this focus on functions, it could just be module passes??
131pub(super) fn run_func_passes<P>(
132    module: &mut Module,
133    passes: &[impl AsRef<str>],
134    // FIXME(eddyb) this is a very poor approximation of a "profiler" abstraction.
135    mut before_pass: impl FnMut(&'static str, &Module) -> P,
136    mut after_pass: impl FnMut(Option<&Module>, P),
137) {
138    let cx = &module.cx();
139
140    // FIXME(eddyb) reuse this collection work in some kind of "pass manager".
141    let all_funcs = {
142        let mut collector = ReachableUseCollector {
143            cx,
144            module,
145
146            seen_types: FxIndexSet::default(),
147            seen_consts: FxIndexSet::default(),
148            seen_data_inst_forms: FxIndexSet::default(),
149            seen_global_vars: FxIndexSet::default(),
150            seen_funcs: FxIndexSet::default(),
151        };
152        for (export_key, &exportee) in &module.exports {
153            export_key.inner_visit_with(&mut collector);
154            exportee.inner_visit_with(&mut collector);
155        }
156        collector.seen_funcs
157    };
158
159    for name in passes {
160        let name = name.as_ref();
161
162        // HACK(eddyb) not really a function pass.
163        if name == "qptr" {
164            let layout_config = &spirt::qptr::LayoutConfig {
165                abstract_bool_size_align: (1, 1),
166                logical_ptr_size_align: (4, 4),
167                ..spirt::qptr::LayoutConfig::VULKAN_SCALAR_LAYOUT
168            };
169
170            let profiler = before_pass("qptr::lower_from_spv_ptrs", module);
171            spirt::passes::qptr::lower_from_spv_ptrs(module, layout_config);
172            after_pass(Some(module), profiler);
173
174            let profiler = before_pass("qptr::analyze_uses", module);
175            spirt::passes::qptr::analyze_uses(module, layout_config);
176            after_pass(Some(module), profiler);
177
178            let profiler = before_pass("qptr::lift_to_spv_ptrs", module);
179            spirt::passes::qptr::lift_to_spv_ptrs(module, layout_config);
180            after_pass(Some(module), profiler);
181
182            continue;
183        }
184
185        let (full_name, pass_fn): (_, fn(_, &mut _)) = match name {
186            "reduce" => ("spirt_passes::reduce", reduce::reduce_in_func),
187            "fuse_selects" => (
188                "spirt_passes::fuse_selects",
189                fuse_selects::fuse_selects_in_func,
190            ),
191            _ => panic!("unknown `--spirt-passes={name}`"),
192        };
193
194        let profiler = before_pass(full_name, module);
195        for &func in &all_funcs {
196            if let DeclDef::Present(func_def_body) = &mut module.funcs[func].def {
197                pass_fn(cx, func_def_body);
198
199                // FIXME(eddyb) avoid doing this except where changes occurred.
200                remove_unused_values_in_func(cx, func_def_body);
201            }
202        }
203        after_pass(Some(module), profiler);
204    }
205}
206
207// FIXME(eddyb) this is just copy-pasted from `spirt` and should be reusable.
208struct ReachableUseCollector<'a> {
209    cx: &'a Context,
210    module: &'a Module,
211
212    // FIXME(eddyb) build some automation to avoid ever repeating these.
213    seen_types: FxIndexSet<Type>,
214    seen_consts: FxIndexSet<Const>,
215    seen_data_inst_forms: FxIndexSet<DataInstForm>,
216    seen_global_vars: FxIndexSet<GlobalVar>,
217    seen_funcs: FxIndexSet<Func>,
218}
219
220impl Visitor<'_> for ReachableUseCollector<'_> {
221    // FIXME(eddyb) build some automation to avoid ever repeating these.
222    fn visit_attr_set_use(&mut self, _attrs: AttrSet) {}
223    fn visit_type_use(&mut self, ty: Type) {
224        if self.seen_types.insert(ty) {
225            self.visit_type_def(&self.cx[ty]);
226        }
227    }
228    fn visit_const_use(&mut self, ct: Const) {
229        if self.seen_consts.insert(ct) {
230            self.visit_const_def(&self.cx[ct]);
231        }
232    }
233    fn visit_data_inst_form_use(&mut self, data_inst_form: DataInstForm) {
234        if self.seen_data_inst_forms.insert(data_inst_form) {
235            self.visit_data_inst_form_def(&self.cx[data_inst_form]);
236        }
237    }
238
239    fn visit_global_var_use(&mut self, gv: GlobalVar) {
240        if self.seen_global_vars.insert(gv) {
241            self.visit_global_var_decl(&self.module.global_vars[gv]);
242        }
243    }
244    fn visit_func_use(&mut self, func: Func) {
245        if self.seen_funcs.insert(func) {
246            self.visit_func_decl(&self.module.funcs[func]);
247        }
248    }
249}
250
251// FIXME(eddyb) maybe this should be provided by `spirt::visit`.
252struct VisitAllControlRegionsAndNodes<S, VCR, VCN> {
253    state: S,
254    visit_control_region: VCR,
255    visit_control_node: VCN,
256}
257const _: () = {
258    use spirt::{func_at::*, visit::*, *};
259
260    impl<
261        'a,
262        S,
263        VCR: FnMut(&mut S, FuncAt<'a, ControlRegion>),
264        VCN: FnMut(&mut S, FuncAt<'a, ControlNode>),
265    > Visitor<'a> for VisitAllControlRegionsAndNodes<S, VCR, VCN>
266    {
267        // FIXME(eddyb) this is excessive, maybe different kinds of
268        // visitors should exist for module-level and func-level?
269        fn visit_attr_set_use(&mut self, _: AttrSet) {}
270        fn visit_type_use(&mut self, _: Type) {}
271        fn visit_const_use(&mut self, _: Const) {}
272        fn visit_data_inst_form_use(&mut self, _: DataInstForm) {}
273        fn visit_global_var_use(&mut self, _: GlobalVar) {}
274        fn visit_func_use(&mut self, _: Func) {}
275
276        fn visit_control_region_def(&mut self, func_at_control_region: FuncAt<'a, ControlRegion>) {
277            (self.visit_control_region)(&mut self.state, func_at_control_region);
278            func_at_control_region.inner_visit_with(self);
279        }
280        fn visit_control_node_def(&mut self, func_at_control_node: FuncAt<'a, ControlNode>) {
281            (self.visit_control_node)(&mut self.state, func_at_control_node);
282            func_at_control_node.inner_visit_with(self);
283        }
284    }
285};
286
287// FIXME(eddyb) maybe this should be provided by `spirt::transform`.
288struct ReplaceValueWith<F>(F);
289const _: () = {
290    use spirt::{transform::*, *};
291
292    impl<F: FnMut(Value) -> Option<Value>> Transformer for ReplaceValueWith<F> {
293        fn transform_value_use(&mut self, v: &Value) -> Transformed<Value> {
294            self.0(*v).map_or(Transformed::Unchanged, Transformed::Changed)
295        }
296    }
297};
298
299/// Clean up after a pass by removing unused (pure) `Value` definitions from
300/// a function body (both `DataInst`s and `ControlRegion` inputs/outputs).
301//
302// FIXME(eddyb) should this be a dedicated pass?
303fn remove_unused_values_in_func(cx: &Context, func_def_body: &mut FuncDefBody) {
304    // Avoid having to support unstructured control-flow.
305    if func_def_body.unstructured_cfg.is_some() {
306        return;
307    }
308
309    let wk = &SpvSpecWithExtras::get().well_known;
310
311    struct Propagator {
312        func_body_region: ControlRegion,
313
314        // FIXME(eddyb) maybe this kind of "parent map" should be provided by SPIR-T?
315        loop_body_to_loop: EntityOrientedDenseMap<ControlRegion, ControlNode>,
316
317        // FIXME(eddyb) entity-keyed dense sets might be better for performance,
318        // but would require separate sets/maps for separate `Value` cases.
319        used: FxHashSet<Value>,
320
321        queue: VecDeque<Value>,
322    }
323    impl Propagator {
324        fn mark_used(&mut self, v: Value) {
325            if let Value::Const(_) = v {
326                return;
327            }
328            if let Value::ControlRegionInput {
329                region,
330                input_idx: _,
331            } = v
332                && region == self.func_body_region
333            {
334                return;
335            }
336            if self.used.insert(v) {
337                self.queue.push_back(v);
338            }
339        }
340        fn propagate_used(&mut self, func: FuncAt<'_, ()>) {
341            while let Some(v) = self.queue.pop_front() {
342                match v {
343                    Value::Const(_) => unreachable!(),
344                    Value::ControlRegionInput { region, input_idx } => {
345                        let loop_node = self.loop_body_to_loop[region];
346                        let initial_inputs = match &func.at(loop_node).def().kind {
347                            ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs,
348                            // NOTE(eddyb) only `Loop`s' bodies can have inputs right now.
349                            _ => unreachable!(),
350                        };
351                        self.mark_used(initial_inputs[input_idx as usize]);
352                        self.mark_used(func.at(region).def().outputs[input_idx as usize]);
353                    }
354                    Value::ControlNodeOutput {
355                        control_node,
356                        output_idx,
357                    } => {
358                        let cases = match &func.at(control_node).def().kind {
359                            ControlNodeKind::Select { cases, .. } => cases,
360                            // NOTE(eddyb) only `Select`s can have outputs right now.
361                            _ => unreachable!(),
362                        };
363                        for &case in cases {
364                            self.mark_used(func.at(case).def().outputs[output_idx as usize]);
365                        }
366                    }
367                    Value::DataInstOutput(inst) => {
368                        for &input in &func.at(inst).def().inputs {
369                            self.mark_used(input);
370                        }
371                    }
372                }
373            }
374        }
375    }
376
377    // HACK(eddyb) it's simpler to first ensure `loop_body_to_loop` is computed,
378    // just to allow the later unordered propagation to always work.
379    let propagator = {
380        let mut visitor = VisitAllControlRegionsAndNodes {
381            state: Propagator {
382                func_body_region: func_def_body.body,
383                loop_body_to_loop: Default::default(),
384                used: Default::default(),
385                queue: Default::default(),
386            },
387            visit_control_region: |_: &mut _, _| {},
388            visit_control_node:
389                |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
390                    if let ControlNodeKind::Loop { body, .. } = func_at_control_node.def().kind {
391                        propagator
392                            .loop_body_to_loop
393                            .insert(body, func_at_control_node.position);
394                    }
395                },
396        };
397        func_def_body.inner_visit_with(&mut visitor);
398        visitor.state
399    };
400
401    // HACK(eddyb) this kind of random-access is easier than using `spirt::transform`.
402    let mut all_control_nodes = vec![];
403
404    let used_values = {
405        let mut visitor = VisitAllControlRegionsAndNodes {
406            state: propagator,
407            visit_control_region: |_: &mut _, _| {},
408            visit_control_node:
409                |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
410                    all_control_nodes.push(func_at_control_node.position);
411
412                    let mut mark_used_and_propagate = |v| {
413                        propagator.mark_used(v);
414                        propagator.propagate_used(func_at_control_node.at(()));
415                    };
416                    match &func_at_control_node.def().kind {
417                        &ControlNodeKind::Block { insts } => {
418                            for func_at_inst in func_at_control_node.at(insts) {
419                                // Ignore pure instructions (i.e. they're only used
420                                // if their output value is used, from somewhere else).
421                                if let DataInstKind::SpvInst(spv_inst) =
422                                    &cx[func_at_inst.def().form].kind
423                                {
424                                    // HACK(eddyb) small selection relevant for now,
425                                    // but should be extended using e.g. a bitset.
426                                    if [wk.OpNop, wk.OpCompositeInsert].contains(&spv_inst.opcode) {
427                                        continue;
428                                    }
429                                }
430                                mark_used_and_propagate(Value::DataInstOutput(
431                                    func_at_inst.position,
432                                ));
433                            }
434                        }
435
436                        &ControlNodeKind::Select { scrutinee: v, .. }
437                        | &ControlNodeKind::Loop {
438                            repeat_condition: v,
439                            ..
440                        } => mark_used_and_propagate(v),
441
442                        ControlNodeKind::ExitInvocation {
443                            kind: spirt::cfg::ExitInvocationKind::SpvInst(_),
444                            inputs,
445                        } => {
446                            for &v in inputs {
447                                mark_used_and_propagate(v);
448                            }
449                        }
450                    }
451                },
452        };
453        func_def_body.inner_visit_with(&mut visitor);
454
455        let mut propagator = visitor.state;
456        for &v in &func_def_body.at_body().def().outputs {
457            propagator.mark_used(v);
458            propagator.propagate_used(func_def_body.at(()));
459        }
460
461        assert!(propagator.queue.is_empty());
462        propagator.used
463    };
464
465    // FIXME(eddyb) entity-keyed dense maps might be better for performance,
466    // but would require separate maps for separate `Value` cases.
467    let mut value_replacements = FxHashMap::default();
468
469    // Remove anything that didn't end up marked as used (directly or indirectly).
470    for control_node in all_control_nodes {
471        let control_node_def = func_def_body.at(control_node).def();
472        match &control_node_def.kind {
473            &ControlNodeKind::Block { insts } => {
474                let mut all_nops = true;
475                let mut func_at_inst_iter = func_def_body.at_mut(insts).into_iter();
476                while let Some(mut func_at_inst) = func_at_inst_iter.next() {
477                    if let DataInstKind::SpvInst(spv_inst) =
478                        &cx[func_at_inst.reborrow().def().form].kind
479                        && spv_inst.opcode == wk.OpNop
480                    {
481                        continue;
482                    }
483                    if !used_values.contains(&Value::DataInstOutput(func_at_inst.position)) {
484                        // Replace the removed `DataInstDef` itself with `OpNop`,
485                        // removing the ability to use its "name" as a value.
486                        //
487                        // FIXME(eddyb) cache the interned `OpNop`.
488                        *func_at_inst.def() = DataInstDef {
489                            attrs: Default::default(),
490                            form: cx.intern(DataInstFormDef {
491                                kind: DataInstKind::SpvInst(wk.OpNop.into()),
492                                output_type: None,
493                            }),
494                            inputs: iter::empty().collect(),
495                        };
496                        continue;
497                    }
498                    all_nops = false;
499                }
500                // HACK(eddyb) because we can't remove list elements yet, we
501                // instead replace blocks of `OpNop`s with empty ones.
502                if all_nops {
503                    func_def_body.at_mut(control_node).def().kind = ControlNodeKind::Block {
504                        insts: Default::default(),
505                    };
506                }
507            }
508
509            ControlNodeKind::Select { cases, .. } => {
510                // FIXME(eddyb) remove this cloning.
511                let cases = cases.clone();
512
513                let mut new_idx = 0;
514                for original_idx in 0..control_node_def.outputs.len() {
515                    let original_output = Value::ControlNodeOutput {
516                        control_node,
517                        output_idx: original_idx as u32,
518                    };
519
520                    if !used_values.contains(&original_output) {
521                        // Remove the output definition and corresponding value from all cases.
522                        func_def_body
523                            .at_mut(control_node)
524                            .def()
525                            .outputs
526                            .remove(new_idx);
527                        for &case in &cases {
528                            func_def_body.at_mut(case).def().outputs.remove(new_idx);
529                        }
530                        continue;
531                    }
532
533                    // Record remappings for any still-used outputs that got "shifted over".
534                    if original_idx != new_idx {
535                        let new_output = Value::ControlNodeOutput {
536                            control_node,
537                            output_idx: new_idx as u32,
538                        };
539                        value_replacements.insert(original_output, new_output);
540                    }
541                    new_idx += 1;
542                }
543            }
544
545            ControlNodeKind::Loop {
546                body,
547                initial_inputs,
548                ..
549            } => {
550                let body = *body;
551
552                let mut new_idx = 0;
553                for original_idx in 0..initial_inputs.len() {
554                    let original_input = Value::ControlRegionInput {
555                        region: body,
556                        input_idx: original_idx as u32,
557                    };
558
559                    if !used_values.contains(&original_input) {
560                        // Remove the input definition and corresponding values.
561                        match &mut func_def_body.at_mut(control_node).def().kind {
562                            ControlNodeKind::Loop { initial_inputs, .. } => {
563                                initial_inputs.remove(new_idx);
564                            }
565                            _ => unreachable!(),
566                        }
567                        let body_def = func_def_body.at_mut(body).def();
568                        body_def.inputs.remove(new_idx);
569                        body_def.outputs.remove(new_idx);
570                        continue;
571                    }
572
573                    // Record remappings for any still-used inputs that got "shifted over".
574                    if original_idx != new_idx {
575                        let new_input = Value::ControlRegionInput {
576                            region: body,
577                            input_idx: new_idx as u32,
578                        };
579                        value_replacements.insert(original_input, new_input);
580                    }
581                    new_idx += 1;
582                }
583            }
584
585            ControlNodeKind::ExitInvocation { .. } => {}
586        }
587    }
588
589    if !value_replacements.is_empty() {
590        func_def_body.inner_in_place_transform_with(&mut ReplaceValueWith(|v| match v {
591            Value::Const(_) => None,
592            _ => value_replacements.get(&v).copied(),
593        }));
594    }
595}