// rustc_codegen_spirv/linker/spirt_passes/fuse_selects.rs
1use spirt::func_at::FuncAt;
2use spirt::transform::InnerInPlaceTransform;
3use spirt::visit::InnerVisit;
4use spirt::{Context, ControlNodeKind, ControlRegion, FuncDefBody, SelectionKind, Value};
5use std::mem;
6
7use super::{ReplaceValueWith, VisitAllControlRegionsAndNodes};
8
/// Combine consecutive `Select`s in `func_def_body`.
///
/// Within every control region of the function, a `Select` of kind
/// `SelectionKind::BoolCond` is fused with later `BoolCond` `Select` siblings
/// that use the *same* scrutinee value, as long as only empty `Block`s sit
/// between them. For each fused `Select`:
/// - the children of each of its cases are moved onto the end of the matching
///   case of the first `Select` ("then" onto "then", "else" onto "else");
/// - uses of the first `Select`'s outputs inside those moved children are
///   replaced with the values the matching base case outputs;
/// - the fused `Select` itself is replaced with an empty `Block`.
///
/// The scan ahead from a base `Select` stops at the first node that is not an
/// empty `Block`, or that is not a same-condition `BoolCond` `Select`. It also
/// stops at a same-condition `Select` that has outputs of its own; see the
/// `FIXME` below.
///
/// `_cx` is currently unused.
pub(crate) fn fuse_selects_in_func(_cx: &Context, func_def_body: &mut FuncDefBody) {
    // HACK(eddyb) this kind of random-access is easier than using `spirt::transform`.
    // The positions of every control region are collected up front, so that
    // each region can then be revisited while `func_def_body` is being mutated.
    let mut all_regions = vec![];

    func_def_body.inner_visit_with(&mut VisitAllControlRegionsAndNodes {
        state: (),
        visit_control_region: |_: &mut (), func_at_control_region: FuncAt<'_, ControlRegion>| {
            all_regions.push(func_at_control_region.position);
        },
        // Only region positions are recorded; nodes are reached below by
        // iterating each region's children.
        visit_control_node: |_: &mut (), _| {},
    });

    for region in all_regions {
        let mut func_at_children_iter = func_def_body.at_mut(region).at_children().into_iter();
        while let Some(func_at_child) = func_at_children_iter.next() {
            let base_control_node = func_at_child.position;
            // Only boolean-condition `Select`s can be the base of a fusion.
            if let ControlNodeKind::Select {
                kind: SelectionKind::BoolCond,
                scrutinee,
                cases,
            } = &func_at_child.def().kind
            {
                let &base_cond = scrutinee;
                // Cloned so that no borrow of the base node's definition is
                // held across the scan below.
                let base_cases = cases.clone();

                // Scan ahead for candidate `Select`s (with the same condition).
                // NOTE(review): the scan runs on a reborrow of the children
                // iterator. Presumably that reborrow copies the current position
                // rather than advancing the outer iterator, so the outer loop will
                // still visit the nodes scanned here. Fused candidates have become
                // empty `Block`s by then, so the `if let` above skips them.
                // Confirm this against `spirt`.
                let mut fusion_candidate_iter = func_at_children_iter.reborrow();
                while let Some(func_at_fusion_candidate) = fusion_candidate_iter.next() {
                    let fusion_candidate = func_at_fusion_candidate.position;
                    // Re-anchor at `()` so that any region or node (the candidate,
                    // its cases, the base cases) can be reached via
                    // `func.reborrow().at(...)` below.
                    let mut func = func_at_fusion_candidate.at(());
                    let fusion_candidate_def = func.reborrow().at(fusion_candidate).def();
                    match &fusion_candidate_def.kind {
                        // HACK(eddyb) ignore empty blocks (created by
                        // e.g. `remove_unused_values_in_func`).
                        // This also skips past `Select`s already fused into the
                        // base, because those are replaced with empty `Block`s below.
                        ControlNodeKind::Block { insts } if insts.is_empty() => {}

                        ControlNodeKind::Select {
                            kind: SelectionKind::BoolCond,
                            scrutinee: fusion_candidate_cond,
                            cases: fusion_candidate_cases,
                        } if *fusion_candidate_cond == base_cond => {
                            // FIXME(eddyb) handle outputs from the second `Select`.
                            // Such outputs are not rewired anywhere once this `Select`
                            // is emptied, so a candidate with outputs ends the scan.
                            if !fusion_candidate_def.outputs.is_empty() {
                                break;
                            }

                            let cases_to_fuse = fusion_candidate_cases.clone();

                            // Concatenate the `Select`s' respective cases
                            // ("then" with "then", "else" with "else", etc.).
                            for (&base_case, &case_to_fuse) in base_cases.iter().zip(&cases_to_fuse)
                            {
                                // Detach the candidate case's children, which
                                // leaves that case with an empty child list.
                                let children_of_case_to_fuse =
                                    mem::take(&mut func.reborrow().at(case_to_fuse).def().children);

                                // Replace uses of the outputs of the first `Select`,
                                // in the second one's case, with the specific values
                                // (e.g. `let y = if c { x } ...; if c { f(y) }`
                                // has to become `let y = if c { f(x); x } ...`).
                                //
                                // FIXME(eddyb) avoid cloning here.
                                let outputs_of_base_case =
                                    func.reborrow().at(base_case).def().outputs.clone();
                                func.reborrow()
                                    .at(children_of_case_to_fuse)
                                    .into_iter()
                                    .inner_in_place_transform_with(&mut ReplaceValueWith(
                                        |v| match v {
                                            Value::ControlNodeOutput {
                                                control_node,
                                                output_idx,
                                            } if control_node == base_control_node => {
                                                Some(outputs_of_base_case[output_idx as usize])
                                            }

                                            _ => None,
                                        },
                                    ));

                                // Move the rewritten children onto the end of the
                                // matching base case, after its existing children.
                                func.control_regions[base_case]
                                    .children
                                    .append(children_of_case_to_fuse, func.control_nodes);
                            }

                            // HACK(eddyb) because we can't remove list elements yet,
                            // we instead replace the `Select` with an empty `Block`.
                            func.reborrow().at(fusion_candidate).def().kind =
                                ControlNodeKind::Block {
                                    insts: Default::default(),
                                };
                        }

                        // Any other node ends the scan. This covers a non-empty
                        // `Block`, a `Select` that is not `BoolCond` or has a different
                        // condition, and any other node kind. Fusing a later `Select`
                        // across such a node would move that `Select`'s children
                        // (now inside the base `Select`) ahead of the node.
                        _ => break,
                    }
                }
            }
        }
    }
}