rustc_codegen_spirv/linker/spirt_passes/
fuse_selects.rs

1use spirt::func_at::FuncAt;
2use spirt::transform::InnerInPlaceTransform;
3use spirt::visit::InnerVisit;
4use spirt::{Context, ControlNodeKind, ControlRegion, FuncDefBody, SelectionKind, Value};
5use std::mem;
6
7use super::{ReplaceValueWith, VisitAllControlRegionsAndNodes};
8
/// Combine consecutive `Select`s in `func_def_body`.
///
/// Every control region of the function is revisited. Within it, each
/// `SelectionKind::BoolCond` `Select` acts as a "base", and the sibling control
/// nodes that follow it are scanned in order:
/// - empty `Block`s are skipped over;
/// - a `BoolCond` `Select` with the same condition (scrutinee) and no outputs
///   is fused into the base. The children of each of its cases are moved to
///   the end of the base's corresponding case, and the fused `Select` is then
///   replaced by an empty `Block`. Scanning then continues with the next node;
/// - anything else ends the scan for that base. This includes a non-empty
///   `Block`, a `Select` with a different condition or kind, and a
///   same-condition `Select` that has outputs.
///
/// `_cx` is not used by this pass.
pub(crate) fn fuse_selects_in_func(_cx: &Context, func_def_body: &mut FuncDefBody) {
    // HACK(eddyb) this kind of random-access is easier than using `spirt::transform`.
    // (The positions of all control regions are collected up front, and each
    // region is then revisited mutably through `func_def_body.at_mut(region)`.)
    let mut all_regions = vec![];

    func_def_body.inner_visit_with(&mut VisitAllControlRegionsAndNodes {
        state: (),
        visit_control_region: |_: &mut (), func_at_control_region: FuncAt<'_, ControlRegion>| {
            all_regions.push(func_at_control_region.position);
        },
        visit_control_node: |_: &mut (), _| {},
    });

    for region in all_regions {
        let mut func_at_children_iter = func_def_body.at_mut(region).at_children().into_iter();
        while let Some(func_at_child) = func_at_children_iter.next() {
            // Position of the (potential) base `Select`. It is needed below to
            // recognize uses of its outputs (`Value::ControlNodeOutput`).
            let base_control_node = func_at_child.position;
            if let ControlNodeKind::Select {
                kind: SelectionKind::BoolCond,
                scrutinee,
                cases,
            } = &func_at_child.def().kind
            {
                // Copy the condition and clone the case list out of the borrowed
                // node definition. Presumably this lets that borrow end before
                // the mutations below.
                let &base_cond = scrutinee;
                let base_cases = cases.clone();

                // Scan ahead for candidate `Select`s (with the same condition).
                // NOTE(review): presumably `reborrow()` yields a cursor whose
                // advancement does not move `func_at_children_iter`; confirm
                // against `spirt`. Either way, fused candidates become empty
                // `Block`s, which the outer `if let` above never treats as a base.
                let mut fusion_candidate_iter = func_at_children_iter.reborrow();
                while let Some(func_at_fusion_candidate) = fusion_candidate_iter.next() {
                    let fusion_candidate = func_at_fusion_candidate.position;
                    // Re-target the accessor at `()`. This lets it be pointed at
                    // arbitrary regions and nodes below, either through
                    // `func.reborrow().at(...)` or through its `control_regions`
                    // and `control_nodes` fields.
                    let mut func = func_at_fusion_candidate.at(());
                    let fusion_candidate_def = func.reborrow().at(fusion_candidate).def();
                    match &fusion_candidate_def.kind {
                        // HACK(eddyb) ignore empty blocks (created by
                        // e.g. `remove_unused_values_in_func`).
                        ControlNodeKind::Block { insts } if insts.is_empty() => {}

                        ControlNodeKind::Select {
                            kind: SelectionKind::BoolCond,
                            scrutinee: fusion_candidate_cond,
                            cases: fusion_candidate_cases,
                        } if *fusion_candidate_cond == base_cond => {
                            // FIXME(eddyb) handle outputs from the second `Select`.
                            if !fusion_candidate_def.outputs.is_empty() {
                                break;
                            }

                            // Cloned (like `base_cases` above), presumably so the
                            // borrow of the candidate's definition can end before
                            // `func` is mutated.
                            let cases_to_fuse = fusion_candidate_cases.clone();

                            // Concatenate the `Select`s' respective cases
                            // ("then" with "then", "else" with "else", etc.).
                            for (&base_case, &case_to_fuse) in base_cases.iter().zip(&cases_to_fuse)
                            {
                                // Move the children out of the case being fused.
                                // `mem::take` leaves `Default::default()` in their
                                // place (presumably an empty list).
                                let children_of_case_to_fuse =
                                    mem::take(&mut func.reborrow().at(case_to_fuse).def().children);

                                // Replace uses of the outputs of the first `Select`,
                                // in the second one's case, with the specific values
                                // (e.g. `let y = if c { x } ...; if c { f(y) }`
                                // has to become `let y = if c { f(x); x } ...`).
                                //
                                // FIXME(eddyb) avoid cloning here.
                                let outputs_of_base_case =
                                    func.reborrow().at(base_case).def().outputs.clone();
                                func.reborrow()
                                    .at(children_of_case_to_fuse)
                                    .into_iter()
                                    .inner_in_place_transform_with(&mut ReplaceValueWith(
                                        |v| match v {
                                            Value::ControlNodeOutput {
                                                control_node,
                                                output_idx,
                                            } if control_node == base_control_node => {
                                                Some(outputs_of_base_case[output_idx as usize])
                                            }

                                            _ => None,
                                        },
                                    ));

                                // Append the rewritten children after the existing
                                // children of the base `Select`'s matching case.
                                func.control_regions[base_case]
                                    .children
                                    .append(children_of_case_to_fuse, func.control_nodes);
                            }

                            // HACK(eddyb) because we can't remove list elements yet,
                            // we instead replace the `Select` with an empty `Block`.
                            func.reborrow().at(fusion_candidate).def().kind =
                                ControlNodeKind::Block {
                                    insts: Default::default(),
                                };
                        }

                        // Any other node ends the look-ahead for this base `Select`.
                        _ => break,
                    }
                }
            }
        }
    }
}