rustc_codegen_spirv/linker/spirt_passes/fuse_selects.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use spirt::func_at::FuncAt;
use spirt::transform::InnerInPlaceTransform;
use spirt::visit::InnerVisit;
use spirt::{Context, ControlNodeKind, ControlRegion, FuncDefBody, SelectionKind, Value};
use std::mem;

use super::{ReplaceValueWith, VisitAllControlRegionsAndNodes};

/// Combine consecutive `Select`s in `func_def_body`.
///
/// Every control region of the function is scanned for `Select`s of kind
/// `SelectionKind::BoolCond`. For each such "base" `Select`, the siblings that
/// follow it in the same region are examined in order:
/// - empty `Block`s are skipped over;
/// - a `BoolCond` `Select` whose scrutinee compares equal (`==`) to the base
///   one's, and which has no outputs of its own, is fused into the base: the
///   children of each of its cases are moved to the end of the corresponding
///   case of the base `Select` ("then" with "then", "else" with "else"), and
///   the now-redundant `Select` is replaced by an empty `Block`;
/// - anything else (including a same-condition `Select` that has outputs)
///   ends the scan for that base `Select`.
///
/// Roughly, `if c { A } else { B }; if c { C } else { D }` becomes
/// `if c { A; C } else { B; D }` (followed by an empty `Block`).
///
/// Because the moved code ends up *inside* the base `Select`'s case, any use it
/// makes of the base `Select`'s outputs is rewritten to the value that specific
/// case yields (see the example in the comment inside the loop below).
///
/// `_cx` is not used by this pass.
pub(crate) fn fuse_selects_in_func(_cx: &Context, func_def_body: &mut FuncDefBody) {
    // HACK(eddyb) this kind of random-access is easier than using `spirt::transform`.
    // Positions of every control region, collected up front by the visitor
    // below and then processed one at a time in the loop that follows.
    let mut all_regions = vec![];

    func_def_body.inner_visit_with(&mut VisitAllControlRegionsAndNodes {
        state: (),
        visit_control_region: |_: &mut (), func_at_control_region: FuncAt<'_, ControlRegion>| {
            all_regions.push(func_at_control_region.position);
        },
        // Only region positions are needed; nodes are ignored.
        visit_control_node: |_: &mut (), _| {},
    });

    // NOTE(review): each region is processed exactly once, in the order the
    // visitor reported it. Children appended to a case region that was already
    // processed are not re-scanned by this call. Whether that can happen
    // depends on `VisitAllControlRegionsAndNodes`'s traversal order (confirm).
    for region in all_regions {
        let mut func_at_children_iter = func_def_body.at_mut(region).at_children().into_iter();
        while let Some(func_at_child) = func_at_children_iter.next() {
            // Position of the potential base `Select`. It is used below to
            // recognize `Value::ControlNodeOutput`s that refer to it.
            let base_control_node = func_at_child.position;
            if let ControlNodeKind::Select {
                kind: SelectionKind::BoolCond,
                scrutinee,
                cases,
            } = &func_at_child.def().kind
            {
                // Copied/cloned out so they remain usable while the scan
                // below takes mutable borrows of the function.
                let &base_cond = scrutinee;
                let base_cases = cases.clone();

                // Scan ahead for candidate `Select`s (with the same condition).
                // NOTE(review): this scans through a `reborrow()` of the outer
                // iterator. If that copies the iterator's position (rather than
                // sharing it), the outer loop resumes right after the base node
                // and will revisit the nodes scanned here, with fused ones
                // already turned into empty `Block`s. Confirm against `spirt`.
                let mut fusion_candidate_iter = func_at_children_iter.reborrow();
                while let Some(func_at_fusion_candidate) = fusion_candidate_iter.next() {
                    let fusion_candidate = func_at_fusion_candidate.position;
                    // Position-less handle, used for random access to any
                    // region/node of the function (both `Select`s' cases).
                    let mut func = func_at_fusion_candidate.at(());
                    let fusion_candidate_def = func.reborrow().at(fusion_candidate).def();
                    match &fusion_candidate_def.kind {
                        // HACK(eddyb) ignore empty blocks (created by
                        // e.g. `remove_unused_values_in_func`).
                        ControlNodeKind::Block { insts } if insts.is_empty() => {}

                        ControlNodeKind::Select {
                            kind: SelectionKind::BoolCond,
                            scrutinee: fusion_candidate_cond,
                            cases: fusion_candidate_cases,
                        } if *fusion_candidate_cond == base_cond => {
                            // FIXME(eddyb) handle outputs from the second `Select`.
                            // For now, a same-condition `Select` with outputs
                            // ends the scan instead of being fused.
                            if !fusion_candidate_def.outputs.is_empty() {
                                break;
                            }

                            let cases_to_fuse = fusion_candidate_cases.clone();

                            // Concatenate the `Select`s' respective cases
                            // ("then" with "then", "else" with "else", etc.).
                            // NOTE(review): `zip` stops at the shorter list, so
                            // this assumes both `BoolCond` `Select`s have the
                            // same number of cases.
                            for (&base_case, &case_to_fuse) in base_cases.iter().zip(&cases_to_fuse)
                            {
                                // Detach the candidate case's children, leaving
                                // that case region with an empty child list.
                                let children_of_case_to_fuse =
                                    mem::take(&mut func.reborrow().at(case_to_fuse).def().children);

                                // Replace uses of the outputs of the first `Select`,
                                // in the second one's case, with the specific values
                                // (e.g. `let y = if c { x } ...; if c { f(y) }`
                                // has to become `let y = if c { f(x); x } ...`).
                                //
                                // FIXME(eddyb) avoid cloning here.
                                let outputs_of_base_case =
                                    func.reborrow().at(base_case).def().outputs.clone();
                                func.reborrow()
                                    .at(children_of_case_to_fuse)
                                    .into_iter()
                                    .inner_in_place_transform_with(&mut ReplaceValueWith(
                                        |v| match v {
                                            // NOTE(review): indexes assuming the
                                            // base case yields one value per
                                            // output of the base `Select`.
                                            Value::ControlNodeOutput {
                                                control_node,
                                                output_idx,
                                            } if control_node == base_control_node => {
                                                Some(outputs_of_base_case[output_idx as usize])
                                            }

                                            // Every other value is left unchanged.
                                            _ => None,
                                        },
                                    ));

                                // Move the (rewritten) children to the end of
                                // the matching base case.
                                func.control_regions[base_case]
                                    .children
                                    .append(children_of_case_to_fuse, func.control_nodes);
                            }

                            // HACK(eddyb) because we can't remove list elements yet,
                            // we instead replace the `Select` with an empty `Block`.
                            // (Such empty `Block`s are themselves skipped over by
                            // later scans, per the first match arm above.)
                            func.reborrow().at(fusion_candidate).def().kind =
                                ControlNodeKind::Block {
                                    insts: Default::default(),
                                };
                        }

                        // Any other node (including a `Select` on a different
                        // condition) ends the scan, so code is never moved
                        // across a non-empty intervening node.
                        _ => break,
                    }
                }
            }
        }
    }
}