// rustc_codegen_spirv/linker/spirt_passes/reduce.rs

use rustc_data_structures::fx::FxHashMap;
use smallvec::SmallVec;
use spirt::func_at::{FuncAt, FuncAtMut};
use spirt::transform::InnerInPlaceTransform;
use spirt::visit::InnerVisit;
use spirt::{
    Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeDef, ControlNodeKind,
    ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, DataInstDef,
    DataInstFormDef, DataInstKind, EntityOrientedDenseMap, FuncDefBody, SelectionKind, Type,
    TypeDef, TypeKind, Value, spv,
};
use std::collections::hash_map::Entry;
use std::rc::Rc;
use std::{iter, slice};

use super::{ReplaceValueWith, VisitAllControlRegionsAndNodes};
17
18/// Apply "reduction rules" to `func_def_body`, replacing (pure) computations
19/// with one of their inputs or a constant (e.g. `x + 0 => x` or `1 + 2 => 3`),
20/// and at most only adding more `Select` outputs/`Loop` state (where necessary)
21/// but never any new instructions (unlike e.g. LLVM's instcombine).
22pub(crate) fn reduce_in_func(cx: &Context, func_def_body: &mut FuncDefBody) {
23    let wk = &super::SpvSpecWithExtras::get().well_known;
24
25    let parent_map = ParentMap::new(func_def_body);
26
27    // FIXME(eddyb) entity-keyed dense maps might be better for performance,
28    // but would require separate maps for separate `Value` cases.
29    let mut value_replacements = FxHashMap::default();
30
31    let mut reduction_cache = FxHashMap::default();
32
33    // HACK(eddyb) this is an annoying workaround for iterator invalidation
34    // (SPIR-T iterators don't cope well with the underlying data changing).
35    //
36    // FIXME(eddyb) replace SPIR-T `FuncAtMut<EntityListIter<T>>` with some
37    // kind of "list cursor", maybe even allowing removal during traversal.
38    let mut reduction_queue = vec![];
39
40    #[derive(Copy, Clone)]
41    enum ReductionTarget {
42        /// Replace uses of a `DataInst` with a reduced `Value`.
43        DataInst(DataInst),
44
45        /// Replace an `OpSwitch` `ControlNode` with an `if`-`else` one.
46        //
47        // HACK(eddyb) see comment in `handle_control_node` for more details.
48        SwitchToIfElse(ControlNode),
49    }
50
51    loop {
52        let old_value_replacements_len = value_replacements.len();
53
54        // HACK(eddyb) we want to transform `DataInstDef`s, while having the ability
55        // to (mutably) traverse the function, but `in_place_transform_data_inst_def`
56        // only gives us a `&mut DataInstDef` (without the `FuncAtMut` around it).
57        //
58        // HACK(eddyb) ignore the above, for now it's pretty bad due to iterator
59        // invalidation (see comment on `let reduction_queue` too).
60        let mut handle_control_node =
61            |func_at_control_node: FuncAt<'_, ControlNode>| match func_at_control_node.def() {
62                &ControlNodeDef {
63                    kind: ControlNodeKind::Block { insts },
64                    ..
65                } => {
66                    for func_at_inst in func_at_control_node.at(insts) {
67                        if let Ok(redu) = Reducible::try_from((cx, func_at_inst.def())) {
68                            let redu_target = ReductionTarget::DataInst(func_at_inst.position);
69                            reduction_queue.push((redu_target, redu));
70                        }
71                    }
72                }
73
74                ControlNodeDef {
75                    kind:
76                        ControlNodeKind::Select {
77                            kind,
78                            scrutinee,
79                            cases,
80                        },
81                    outputs,
82                } => {
83                    // FIXME(eddyb) this should probably be ran in the queue loop
84                    // below, to more quickly benefit from previous reductions.
85                    for i in 0..u32::try_from(outputs.len()).unwrap() {
86                        let output = Value::ControlNodeOutput {
87                            control_node: func_at_control_node.position,
88                            output_idx: i,
89                        };
90                        if let Entry::Vacant(entry) = value_replacements.entry(output) {
91                            let per_case_value = cases.iter().map(|&case| {
92                                func_at_control_node.at(case).def().outputs[i as usize]
93                            });
94                            if let Some(reduced) = try_reduce_select(
95                                cx,
96                                &parent_map,
97                                func_at_control_node.position,
98                                kind,
99                                *scrutinee,
100                                per_case_value,
101                            ) {
102                                entry.insert(reduced);
103                            }
104                        }
105                    }
106
107                    // HACK(eddyb) turn `switch x { case 0: A; case 1: B; default: ... }`
108                    // into `if ... {B} else {A}`, when `x` ends up limited in `0..=1`,
109                    // (such `switch`es come from e.g. `match`-ing enums w/ 2 variants)
110                    // allowing us to bypass SPIR-T current (and temporary) lossiness
111                    // wrt `default: OpUnreachable` (i.e. we prove the `default:` can't
112                    // be entered based on `x` not having values other than `0` or `1`)
113                    if let SelectionKind::SpvInst(spv_inst) = kind
114                        && spv_inst.opcode == wk.OpSwitch
115                        && cases.len() == 3
116                    {
117                        // FIXME(eddyb) this kind of `OpSwitch` decoding logic should
118                        // be done by SPIR-T ahead of time, not here.
119                        let num_logical_imms = cases.len() - 1;
120                        assert_eq!(spv_inst.imms.len() % num_logical_imms, 0);
121                        let logical_imm_size = spv_inst.imms.len() / num_logical_imms;
122                        // FIXME(eddyb) collect to array instead.
123                        let logical_imms_as_u32s: SmallVec<[_; 2]> = spv_inst
124                            .imms
125                            .chunks(logical_imm_size)
126                            .map(spv_imm_checked_trunc32)
127                            .collect();
128
129                        // FIMXE(eddyb) support more values than just `0..=1`.
130                        if logical_imms_as_u32s[..] == [Some(0), Some(1)] {
131                            let redu = Reducible {
132                                op: PureOp::IntToBool,
133                                output_type: cx.intern(TypeDef {
134                                    attrs: Default::default(),
135                                    kind: TypeKind::SpvInst {
136                                        spv_inst: wk.OpTypeBool.into(),
137                                        type_and_const_inputs: iter::empty().collect(),
138                                    },
139                                }),
140                                input: *scrutinee,
141                            };
142                            let redu_target =
143                                ReductionTarget::SwitchToIfElse(func_at_control_node.position);
144                            reduction_queue.push((redu_target, redu));
145                        }
146                    }
147                }
148
149                ControlNodeDef {
150                    kind:
151                        ControlNodeKind::Loop {
152                            body,
153                            initial_inputs,
154                            ..
155                        },
156                    ..
157                } => {
158                    // FIXME(eddyb) this should probably be ran in the queue loop
159                    // below, to more quickly benefit from previous reductions.
160                    let body_outputs = &func_at_control_node.at(*body).def().outputs;
161                    for (i, (&initial_input, &body_output)) in
162                        initial_inputs.iter().zip(body_outputs).enumerate()
163                    {
164                        let body_input = Value::ControlRegionInput {
165                            region: *body,
166                            input_idx: i as u32,
167                        };
168                        if body_output == body_input {
169                            value_replacements
170                                .entry(body_input)
171                                .or_insert(initial_input);
172                        }
173                    }
174                }
175
176                &ControlNodeDef {
177                    kind: ControlNodeKind::ExitInvocation { .. },
178                    ..
179                } => {}
180            };
181        func_def_body.inner_visit_with(&mut VisitAllControlRegionsAndNodes {
182            state: (),
183            visit_control_region: |_: &mut (), _| {},
184            visit_control_node: |_: &mut (), func_at_control_node| {
185                handle_control_node(func_at_control_node);
186            },
187        });
188
189        // FIXME(eddyb) should this loop become the only loop, by having loop
190        // reductions push the new instruction to `reduction_queue`? the problem
191        // then is that it's not trivial to figure out what else might benefit
192        // from another full scan, so perhaps the only solution is "demand-driven"
193        // (recursing into use->def, instead of processing defs).
194        let mut any_changes = false;
195        for (redu_target, redu) in reduction_queue.drain(..) {
196            if let Some(v) = redu.try_reduce(
197                cx,
198                func_def_body.at_mut(()),
199                &value_replacements,
200                &parent_map,
201                &mut reduction_cache,
202            ) {
203                any_changes = true;
204                match redu_target {
205                    ReductionTarget::DataInst(inst) => {
206                        value_replacements.insert(Value::DataInstOutput(inst), v);
207
208                        // Replace the reduced `DataInstDef` itself with `OpNop`,
209                        // removing the ability to use its "name" as a value.
210                        //
211                        // FIXME(eddyb) cache the interned `OpNop`.
212                        *func_def_body.at_mut(inst).def() = DataInstDef {
213                            attrs: Default::default(),
214                            form: cx.intern(DataInstFormDef {
215                                kind: DataInstKind::SpvInst(wk.OpNop.into()),
216                                output_type: None,
217                            }),
218                            inputs: iter::empty().collect(),
219                        };
220                    }
221
222                    // HACK(eddyb) see comment in `handle_control_node` for more details.
223                    ReductionTarget::SwitchToIfElse(control_node) => {
224                        let control_node_def = func_def_body.at_mut(control_node).def();
225                        match &control_node_def.kind {
226                            ControlNodeKind::Select { cases, .. } => match cases[..] {
227                                [_default, case_0, case_1] => {
228                                    control_node_def.kind = ControlNodeKind::Select {
229                                        kind: SelectionKind::BoolCond,
230                                        scrutinee: v,
231                                        cases: [case_1, case_0].iter().copied().collect(),
232                                    };
233                                }
234                                _ => unreachable!(),
235                            },
236                            _ => unreachable!(),
237                        }
238                    }
239                }
240            }
241        }
242
243        if !any_changes && old_value_replacements_len == value_replacements.len() {
244            break;
245        }
246
247        func_def_body.inner_in_place_transform_with(&mut ReplaceValueWith(|mut v| {
248            let old = v;
249            loop {
250                match v {
251                    Value::Const(_) => break,
252                    _ => match value_replacements.get(&v) {
253                        Some(&new) => v = new,
254                        None => break,
255                    },
256                }
257            }
258            if v != old {
259                any_changes = true;
260                Some(v)
261            } else {
262                None
263            }
264        }));
265    }
266}
267
268// FIXME(eddyb) maybe this kind of "parent map" should be provided by SPIR-T?
269#[derive(Default)]
270struct ParentMap {
271    data_inst_parent: EntityOrientedDenseMap<DataInst, ControlNode>,
272    control_node_parent: EntityOrientedDenseMap<ControlNode, ControlRegion>,
273    control_region_parent: EntityOrientedDenseMap<ControlRegion, ControlNode>,
274}
275
276impl ParentMap {
277    fn new(func_def_body: &FuncDefBody) -> Self {
278        let mut visitor = VisitAllControlRegionsAndNodes {
279            state: Self::default(),
280            visit_control_region:
281                |this: &mut Self, func_at_control_region: FuncAt<'_, ControlRegion>| {
282                    for func_at_child_control_node in func_at_control_region.at_children() {
283                        this.control_node_parent.insert(
284                            func_at_child_control_node.position,
285                            func_at_control_region.position,
286                        );
287                    }
288                },
289            visit_control_node: |this: &mut Self, func_at_control_node: FuncAt<'_, ControlNode>| {
290                let child_regions = match &func_at_control_node.def().kind {
291                    &ControlNodeKind::Block { insts } => {
292                        for func_at_inst in func_at_control_node.at(insts) {
293                            this.data_inst_parent
294                                .insert(func_at_inst.position, func_at_control_node.position);
295                        }
296                        &[][..]
297                    }
298
299                    ControlNodeKind::Select { cases, .. } => cases,
300                    ControlNodeKind::Loop { body, .. } => slice::from_ref(body),
301                    ControlNodeKind::ExitInvocation { .. } => &[][..],
302                };
303                for &child_region in child_regions {
304                    this.control_region_parent
305                        .insert(child_region, func_at_control_node.position);
306                }
307            },
308        };
309        func_def_body.inner_visit_with(&mut visitor);
310        visitor.state
311    }
312}
313
314/// If possible, find a single `Value` from `cases` (or even `scrutinee`),
315/// which would always be a valid result for `Select(kind, scrutinee, cases)`,
316/// regardless of which case gets (dynamically) taken.
317fn try_reduce_select(
318    cx: &Context,
319    parent_map: &ParentMap,
320    select_control_node: ControlNode,
321    // FIXME(eddyb) are these redundant with the `ControlNode` above?
322    kind: &SelectionKind,
323    scrutinee: Value,
324    cases: impl Iterator<Item = Value>,
325) -> Option<Value> {
326    let wk = &super::SpvSpecWithExtras::get().well_known;
327
328    let as_spv_const = |v: Value| match v {
329        Value::Const(ct) => match &cx[ct].kind {
330            ConstKind::SpvInst {
331                spv_inst_and_const_inputs,
332            } => {
333                let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs;
334                Some(spv_inst.opcode)
335            }
336            _ => None,
337        },
338        _ => None,
339    };
340
341    // Ignore `OpUndef`s, as they can be legally substituted with any other value.
342    let mut first_undef = None;
343    let mut non_undef_cases = cases.filter(|&case| {
344        let is_undef = as_spv_const(case) == Some(wk.OpUndef);
345        if is_undef && first_undef.is_none() {
346            first_undef = Some(case);
347        }
348        !is_undef
349    });
350    match (non_undef_cases.next(), non_undef_cases.next()) {
351        (None, _) => first_undef,
352
353        // `Select(c: bool, true, false)` can be replaced with just `c`.
354        (Some(x), Some(y))
355            if matches!(kind, SelectionKind::BoolCond)
356                && as_spv_const(x) == Some(wk.OpConstantTrue)
357                && as_spv_const(y) == Some(wk.OpConstantFalse) =>
358        {
359            assert!(non_undef_cases.next().is_none() && first_undef.is_none());
360
361            Some(scrutinee)
362        }
363
364        (Some(x), y) => {
365            if y.into_iter().chain(non_undef_cases).all(|z| z == x) {
366                // HACK(eddyb) closure here serves as `try` block.
367                let is_x_valid_outside_select = || {
368                    // Constants are always valid.
369                    if let Value::Const(_) = x {
370                        return Some(());
371                    }
372
373                    // HACK(eddyb) if the same value appears in two different
374                    // cases, it's definitely dominating the whole `Select`.
375                    if y.is_some() {
376                        return Some(());
377                    }
378
379                    // In general, `x` dominating the `Select` is what would
380                    // allow lifting an use of it outside the `Select`.
381                    let region_defining_x = match x {
382                        Value::Const(_) => unreachable!(),
383                        Value::ControlRegionInput { region, .. } => region,
384                        Value::ControlNodeOutput { control_node, .. } => {
385                            *parent_map.control_node_parent.get(control_node)?
386                        }
387                        Value::DataInstOutput(inst) => *parent_map
388                            .control_node_parent
389                            .get(*parent_map.data_inst_parent.get(inst)?)?,
390                    };
391
392                    // Fast-reject: if `x` is defined immediately inside one of
393                    // `select_control_node`'s cases, it's not a dominator.
394                    if parent_map.control_region_parent.get(region_defining_x)
395                        == Some(&select_control_node)
396                    {
397                        return None;
398                    }
399
400                    // Since we know `x` is used inside the `Select`, this only
401                    // needs to check that `x` is defined in a region that the
402                    // `Select` is nested in, as the only other possibility is
403                    // that the `x` is defined inside the `Select` - that is,
404                    // one of `x` and `Select` always dominates the other.
405                    //
406                    // FIXME(eddyb) this could be more efficient with some kind
407                    // of "region depth" precomputation but a potentially-slower
408                    // check doubles as a sanity check, for now.
409                    let mut region_containing_select =
410                        *parent_map.control_node_parent.get(select_control_node)?;
411                    loop {
412                        if region_containing_select == region_defining_x {
413                            return Some(());
414                        }
415                        region_containing_select = *parent_map.control_node_parent.get(
416                            *parent_map
417                                .control_region_parent
418                                .get(region_containing_select)?,
419                        )?;
420                    }
421                };
422                if is_x_valid_outside_select().is_some() {
423                    return Some(x);
424                }
425            }
426
427            None
428        }
429    }
430}
431
432/// Pure operation that transforms one `Value` into another `Value`.
433//
434// FIXME(eddyb) move this elsewhere? also, how should binops etc. be supported?
435// (one approach could be having a "focus input" that can be dynamic, with the
436// other inputs being `Const`s, i.e. partially applying all but one input)
437#[derive(Copy, Clone, PartialEq, Eq, Hash)]
438enum PureOp {
439    BitCast,
440    CompositeExtract {
441        elem_idx: spv::Imm,
442    },
443
444    /// Maps `0` to `false`, and `1` to `true`, but any other input values won't
445    /// allow reduction, which is used to signal `0..=1` isn't being guaranteed.
446    //
447    // HACK(eddyb) not a real operation, but a helper used to extract a `bool`
448    // equivalent for an `OpSwitch`'s scrutinee.
449    // FIXME(eddyb) proper SPIR-T range analysis should be implemented and such
450    // a reduction not attempted at all if the range is larger than `0..=1`
451    // (also, the actual operation can be replaced with `x == 1` or `x != 0`)
452    IntToBool,
453}
454
455impl TryFrom<&spv::Inst> for PureOp {
456    type Error = ();
457    fn try_from(spv_inst: &spv::Inst) -> Result<Self, ()> {
458        let wk = &super::SpvSpecWithExtras::get().well_known;
459
460        let op = spv_inst.opcode;
461        Ok(match spv_inst.imms[..] {
462            [] if op == wk.OpBitcast => Self::BitCast,
463
464            // FIXME(eddyb) support more than one index at a time, somehow.
465            [elem_idx] if op == wk.OpCompositeExtract => Self::CompositeExtract { elem_idx },
466
467            _ => return Err(()),
468        })
469    }
470}
471
472impl TryFrom<PureOp> for spv::Inst {
473    type Error = ();
474    fn try_from(op: PureOp) -> Result<Self, ()> {
475        let wk = &super::SpvSpecWithExtras::get().well_known;
476
477        let (opcode, imms) = match op {
478            PureOp::BitCast => (wk.OpBitcast, iter::empty().collect()),
479            PureOp::CompositeExtract { elem_idx } => {
480                (wk.OpCompositeExtract, iter::once(elem_idx).collect())
481            }
482
483            // HACK(eddyb) this is the only reason this is `TryFrom` not `From`.
484            PureOp::IntToBool => return Err(()),
485        };
486        Ok(Self { opcode, imms })
487    }
488}
489
490/// Potentially-reducible application of a `PureOp` (`op`) to `input`.
491#[derive(Copy, Clone, PartialEq, Eq, Hash)]
492struct Reducible<V = Value> {
493    op: PureOp,
494    output_type: Type,
495    input: V,
496}
497
498impl<V> Reducible<V> {
499    fn with_input<V2>(self, new_input: V2) -> Reducible<V2> {
500        Reducible {
501            op: self.op,
502            output_type: self.output_type,
503            input: new_input,
504        }
505    }
506}
507
508// FIXME(eddyb) instead of taking a `&Context`, could `Reducible` hold a `DataInstForm`?
509impl TryFrom<(&Context, &DataInstDef)> for Reducible {
510    type Error = ();
511    fn try_from((cx, inst_def): (&Context, &DataInstDef)) -> Result<Self, ()> {
512        let inst_form_def = &cx[inst_def.form];
513        if let DataInstKind::SpvInst(spv_inst) = &inst_form_def.kind {
514            let op = PureOp::try_from(spv_inst)?;
515            let output_type = inst_form_def.output_type.unwrap();
516            if let [input] = inst_def.inputs[..] {
517                return Ok(Self {
518                    op,
519                    output_type,
520                    input,
521                });
522            }
523        }
524        Err(())
525    }
526}
527
528impl Reducible {
529    // HACK(eddyb) `IntToBool` is the only reason this can return `None`.
530    fn try_into_inst(self, cx: &Context) -> Option<DataInstDef> {
531        let Self {
532            op,
533            output_type,
534            input,
535        } = self;
536        Some(DataInstDef {
537            attrs: Default::default(),
538            form: cx.intern(DataInstFormDef {
539                kind: DataInstKind::SpvInst(op.try_into().ok()?),
540                output_type: Some(output_type),
541            }),
542            inputs: iter::once(input).collect(),
543        })
544    }
545}
546
547/// Returns `Some(lowest32)` iff `imms` contains one *logical* SPIR-V immediate
548/// representing a (little-endian) integer which truncates (if wider than 32 bits)
549/// to `lowest32`, losslessly (i.e. the rest of the bits are all zeros).
550//
551// FIXME(eddyb) move this into some kind of utility/common helpers place.
552fn spv_imm_checked_trunc32(imms: &[spv::Imm]) -> Option<u32> {
553    match imms {
554        &[spv::Imm::Short(_, lowest32)] | &[spv::Imm::LongStart(_, lowest32), ..]
555            if imms[1..]
556                .iter()
557                .all(|imm| matches!(imm, spv::Imm::LongCont(_, 0))) =>
558        {
559            Some(lowest32)
560        }
561        _ => None,
562    }
563}
564
565impl Reducible<Const> {
566    // FIXME(eddyb) in theory this should always return `Some`.
567    fn try_reduce_const(&self, cx: &Context) -> Option<Const> {
568        let wk = &super::SpvSpecWithExtras::get().well_known;
569
570        let ct_def = &cx[self.input];
571        match (self.op, &ct_def.kind) {
572            (
573                _,
574                ConstKind::SpvInst {
575                    spv_inst_and_const_inputs,
576                },
577            ) if spv_inst_and_const_inputs.0.opcode == wk.OpUndef => Some(cx.intern(ConstDef {
578                attrs: ct_def.attrs,
579                ty: self.output_type,
580                kind: ct_def.kind.clone(),
581            })),
582
583            (
584                PureOp::BitCast,
585                ConstKind::SpvInst {
586                    spv_inst_and_const_inputs,
587                },
588            ) if spv_inst_and_const_inputs.0.opcode == wk.OpConstant => {
589                // `OpTypeInt`/`OpTypeFloat` bit width.
590                let scalar_width = |ty: Type| match &cx[ty].kind {
591                    TypeKind::SpvInst { spv_inst, .. }
592                        if [wk.OpTypeInt, wk.OpTypeFloat].contains(&spv_inst.opcode) =>
593                    {
594                        Some(spv_inst.imms[0])
595                    }
596                    _ => None,
597                };
598
599                match (scalar_width(ct_def.ty), scalar_width(self.output_type)) {
600                    (Some(from), Some(to)) if from == to => Some(cx.intern(ConstDef {
601                        attrs: ct_def.attrs,
602                        ty: self.output_type,
603                        kind: ct_def.kind.clone(),
604                    })),
605                    _ => None,
606                }
607            }
608
609            (
610                PureOp::CompositeExtract {
611                    elem_idx: spv::Imm::Short(_, elem_idx),
612                },
613                ConstKind::SpvInst {
614                    spv_inst_and_const_inputs,
615                },
616            ) if spv_inst_and_const_inputs.0.opcode == wk.OpConstantComposite => {
617                let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs;
618                Some(const_inputs[elem_idx as usize])
619            }
620
621            (
622                PureOp::IntToBool,
623                ConstKind::SpvInst {
624                    spv_inst_and_const_inputs,
625                },
626            ) if spv_inst_and_const_inputs.0.opcode == wk.OpConstant => {
627                let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs;
628                let bool_const_op = match spv_imm_checked_trunc32(&spv_inst.imms[..]) {
629                    Some(0) => wk.OpConstantFalse,
630                    Some(1) => wk.OpConstantTrue,
631                    _ => return None,
632                };
633                Some(cx.intern(ConstDef {
634                    attrs: Default::default(),
635                    ty: self.output_type,
636                    kind: ConstKind::SpvInst {
637                        spv_inst_and_const_inputs: Rc::new((
638                            bool_const_op.into(),
639                            iter::empty().collect(),
640                        )),
641                    },
642                }))
643            }
644
645            _ => None,
646        }
647    }
648}
649
650/// Outcome of a single step of a reduction (which may require more steps).
651enum ReductionStep {
652    Complete(Value),
653    Partial(Reducible),
654}
655
656impl Reducible<&DataInstDef> {
657    // FIXME(eddyb) force the input to actually be itself some kind of pure op.
658    fn try_reduce_output_of_data_inst(&self, cx: &Context) -> Option<ReductionStep> {
659        let wk = &super::SpvSpecWithExtras::get().well_known;
660
661        let input_inst_def = self.input;
662        if let DataInstKind::SpvInst(input_spv_inst) = &cx[input_inst_def.form].kind {
663            // NOTE(eddyb) do not destroy information left in e.g. comments.
664            #[allow(clippy::match_same_arms)]
665            match self.op {
666                PureOp::BitCast => {
667                    // FIXME(eddyb) reduce chains of bitcasts.
668                }
669
670                PureOp::CompositeExtract { elem_idx } => {
671                    if input_spv_inst.opcode == wk.OpCompositeInsert
672                        && input_spv_inst.imms.len() == 1
673                    {
674                        let new_elem = input_inst_def.inputs[0];
675                        let prev_composite = input_inst_def.inputs[1];
676                        return Some(if input_spv_inst.imms[0] == elem_idx {
677                            ReductionStep::Complete(new_elem)
678                        } else {
679                            ReductionStep::Partial(self.with_input(prev_composite))
680                        });
681                    }
682                }
683
684                PureOp::IntToBool => {
685                    // FIXME(eddyb) look into what instructions might end up
686                    // being used to transform booleans into integers.
687                }
688            }
689        }
690
691        None
692    }
693}
694
695impl Reducible {
696    // FIXME(eddyb) make this into some kind of local `ReduceCx` method.
697    fn try_reduce(
698        mut self,
699        cx: &Context,
700        // FIXME(eddyb) come up with a better convention for this!
701        func: FuncAtMut<'_, ()>,
702
703        value_replacements: &FxHashMap<Value, Value>,
704
705        parent_map: &ParentMap,
706
707        cache: &mut FxHashMap<Self, Option<Value>>,
708    ) -> Option<Value> {
709        // FIXME(eddyb) should we care about the cache *before* this loop below?
710
711        // HACK(eddyb) eagerly apply `value_replacements`.
712        // FIXME(eddyb) this could do the union-find trick of shortening chains
713        // the first time they're encountered, but also, if this process was more
714        // "demand-driven" (recursing into use->def, instead of processing defs),
715        // it might not require any of this complication.
716        while let Some(&replacement) = value_replacements.get(&self.input) {
717            self.input = replacement;
718        }
719
720        if let Some(&cached) = cache.get(&self) {
721            return cached;
722        }
723
724        let result = self.try_reduce_uncached(cx, func, value_replacements, parent_map, cache);
725
726        cache.insert(self, result);
727
728        result
729    }
730
731    // FIXME(eddyb) make this into some kind of local `ReduceCx` method.
732    fn try_reduce_uncached(
733        self,
734        cx: &Context,
735        // FIXME(eddyb) come up with a better convention for this!
736        mut func: FuncAtMut<'_, ()>,
737
738        value_replacements: &FxHashMap<Value, Value>,
739
740        parent_map: &ParentMap,
741
742        cache: &mut FxHashMap<Self, Option<Value>>,
743    ) -> Option<Value> {
744        match self.input {
745            Value::Const(ct) => self.with_input(ct).try_reduce_const(cx).map(Value::Const),
746            Value::ControlRegionInput {
747                region,
748                input_idx: state_idx,
749            } => {
750                let loop_node = *parent_map.control_region_parent.get(region)?;
751                // HACK(eddyb) this can't be a closure due to lifetime elision.
752                fn loop_initial_states(
753                    func_at_loop_node: FuncAtMut<'_, ControlNode>,
754                ) -> &mut SmallVec<[Value; 2]> {
755                    match &mut func_at_loop_node.def().kind {
756                        ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs,
757                        _ => unreachable!(),
758                    }
759                }
760
761                let input_from_initial_state =
762                    loop_initial_states(func.reborrow().at(loop_node))[state_idx as usize];
763                let input_from_updated_state =
764                    func.reborrow().at(region).def().outputs[state_idx as usize];
765
766                let output_from_initial_state = self
767                    .with_input(input_from_initial_state)
768                    .try_reduce(cx, func.reborrow(), value_replacements, parent_map, cache)?;
769                // HACK(eddyb) this is here because it can fail, see the comment
770                // on `output_from_updated_state` for what's actually going on.
771                let output_from_updated_state_inst = self
772                    .with_input(input_from_updated_state)
773                    .try_into_inst(cx)?;
774
775                // Now that the reduction succeeded for the initial state,
776                // we can proceed with augmenting the loop with the extra state.
777                loop_initial_states(func.reborrow().at(loop_node)).push(output_from_initial_state);
778
779                let loop_state_decls = &mut func.reborrow().at(region).def().inputs;
780                let new_loop_state_idx = u32::try_from(loop_state_decls.len()).unwrap();
781                loop_state_decls.push(ControlRegionInputDecl {
782                    attrs: Default::default(),
783                    ty: self.output_type,
784                });
785
786                // HACK(eddyb) generating the instruction wholesale again is not
787                // the most efficient way to go about this, but avoiding getting
788                // stuck in a loop while processing a loop is also important.
789                //
790                // FIXME(eddyb) attempt to replace this with early-inserting in
791                // `cache` *then* returning.
792                let output_from_updated_state = func
793                    .data_insts
794                    .define(cx, output_from_updated_state_inst.into());
795                func.reborrow()
796                    .at(region)
797                    .def()
798                    .outputs
799                    .push(Value::DataInstOutput(output_from_updated_state));
800
801                // FIXME(eddyb) move this into some kind of utility/common helpers.
802                let loop_body_last_block = func
803                    .reborrow()
804                    .at(region)
805                    .def()
806                    .children
807                    .iter()
808                    .last
809                    .filter(|&node| {
810                        matches!(
811                            func.reborrow().at(node).def().kind,
812                            ControlNodeKind::Block { .. }
813                        )
814                    })
815                    .unwrap_or_else(|| {
816                        let new_block = func.control_nodes.define(
817                            cx,
818                            ControlNodeDef {
819                                kind: ControlNodeKind::Block {
820                                    insts: Default::default(),
821                                },
822                                outputs: Default::default(),
823                            }
824                            .into(),
825                        );
826                        func.control_regions[region]
827                            .children
828                            .insert_last(new_block, func.control_nodes);
829                        new_block
830                    });
831                match &mut func.control_nodes[loop_body_last_block].kind {
832                    ControlNodeKind::Block { insts } => {
833                        insts.insert_last(output_from_updated_state, func.data_insts);
834                    }
835                    _ => unreachable!(),
836                }
837
838                Some(Value::ControlRegionInput {
839                    region,
840                    input_idx: new_loop_state_idx,
841                })
842            }
843            Value::ControlNodeOutput {
844                control_node,
845                output_idx,
846            } => {
847                let cases = match &func.reborrow().at(control_node).def().kind {
848                    ControlNodeKind::Select { cases, .. } => cases,
849                    // NOTE(eddyb) only `Select`s can have outputs right now.
850                    _ => unreachable!(),
851                };
852
853                // FIXME(eddyb) remove all the cloning and undo additions of new
854                // outputs "upstream", if they end up unused (or let DCE do it?).
855                let cases = cases.clone();
856                let per_case_new_output: SmallVec<[_; 2]> = cases
857                    .iter()
858                    .map(|&case| {
859                        let per_case_input =
860                            func.reborrow().at(case).def().outputs[output_idx as usize];
861                        self.with_input(per_case_input).try_reduce(
862                            cx,
863                            func.reborrow(),
864                            value_replacements,
865                            parent_map,
866                            cache,
867                        )
868                    })
869                    .collect::<Option<_>>()?;
870
871                // Try to avoid introducing a new output, by reducing the merge
872                // of the per-case output values to a single value, if possible.
873                let (kind, scrutinee) = match &func.reborrow().at(control_node).def().kind {
874                    ControlNodeKind::Select {
875                        kind, scrutinee, ..
876                    } => (kind, *scrutinee),
877                    _ => unreachable!(),
878                };
879                if let Some(v) = try_reduce_select(
880                    cx,
881                    parent_map,
882                    control_node,
883                    kind,
884                    scrutinee,
885                    per_case_new_output.iter().copied(),
886                ) {
887                    return Some(v);
888                }
889
890                // Merge the per-case output values into a new output.
891                let control_node_output_decls = &mut func.reborrow().at(control_node).def().outputs;
892                let new_output_idx = u32::try_from(control_node_output_decls.len()).unwrap();
893                control_node_output_decls.push(ControlNodeOutputDecl {
894                    attrs: Default::default(),
895                    ty: self.output_type,
896                });
897                for (&case, new_output) in cases.iter().zip(per_case_new_output) {
898                    let per_case_outputs = &mut func.reborrow().at(case).def().outputs;
899                    assert_eq!(per_case_outputs.len(), new_output_idx as usize);
900                    per_case_outputs.push(new_output);
901                }
902                Some(Value::ControlNodeOutput {
903                    control_node,
904                    output_idx: new_output_idx,
905                })
906            }
907            Value::DataInstOutput(inst) => {
908                let inst_def = &*func.reborrow().at(inst).def();
909                match self
910                    .with_input(inst_def)
911                    .try_reduce_output_of_data_inst(cx)?
912                {
913                    ReductionStep::Complete(v) => Some(v),
914                    // FIXME(eddyb) actually use a loop instead of recursing here.
915                    ReductionStep::Partial(redu) => {
916                        redu.try_reduce(cx, func, value_replacements, parent_map, cache)
917                    }
918                }
919            }
920        }
921    }
922}