1use super::id;
2use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
3use rspirv::spirv::{Op, Word};
4use rustc_data_structures::fx::{FxHashMap, FxHashSet};
5use rustc_middle::bug;
6
7pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
8    module
9        .types_global_values
10        .iter()
11        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
12        .collect()
13}
14
15fn composite_count(types: &FxHashMap<Word, Instruction>, ty_id: Word) -> Option<usize> {
16    let ty = types.get(&ty_id)?;
17    match ty.class.opcode {
18        Op::TypeStruct => Some(ty.operands.len()),
19        Op::TypeVector => Some(ty.operands[1].unwrap_literal_bit32() as usize),
20        Op::TypeArray => {
21            let length_id = ty.operands[1].unwrap_id_ref();
22            let const_inst = types.get(&length_id)?;
23            if const_inst.class.opcode != Op::Constant {
24                return None;
25            }
26            let const_ty = types.get(&const_inst.result_type.unwrap())?;
27            if const_ty.class.opcode != Op::TypeInt {
28                return None;
29            }
30            let const_value = match const_inst.operands[0] {
31                Operand::LiteralBit32(v) => v as usize,
32                Operand::LiteralBit64(v) => v as usize,
33                _ => bug!(),
34            };
35            Some(const_value)
36        }
37        _ => None,
38    }
39}
40
41pub fn composite_construct(types: &FxHashMap<Word, Instruction>, function: &mut Function) {
44    let defs = function
45        .all_inst_iter()
46        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
47        .collect::<FxHashMap<Word, Instruction>>();
48    for block in &mut function.blocks {
49        for inst in &mut block.instructions {
50            if inst.class.opcode != Op::CompositeInsert {
51                continue;
52            }
53            let component_count = match composite_count(types, inst.result_type.unwrap()) {
55                Some(c) => c,
56                None => continue,
57            };
58            let mut components = vec![None; component_count];
62            let mut cur_inst: &Instruction = inst;
63            while cur_inst.class.opcode == Op::CompositeInsert {
65                if cur_inst.operands.len() != 3 {
66                    break;
68                }
69                let value = cur_inst.operands[0].unwrap_id_ref();
70                let index = cur_inst.operands[2].unwrap_literal_bit32() as usize;
71                if index >= components.len() {
72                    break;
75                }
76                if components[index].is_none() {
77                    components[index] = Some(value);
78                }
79                cur_inst = match defs.get(&cur_inst.operands[1].unwrap_id_ref()) {
81                    Some(i) => i,
82                    None => break,
83                };
84            }
85            if let Some(composite_construct_operands) = components
88                .into_iter()
89                .map(|v| v.map(Operand::IdRef))
90                .collect::<Option<Vec<_>>>()
91            {
92                *inst = Instruction::new(
95                    Op::CompositeConstruct,
96                    inst.result_type,
97                    inst.result_id,
98                    composite_construct_operands,
99                );
100            }
101        }
102    }
103}
104
/// Describes how one operand position of a group of per-lane scalar
/// instructions maps onto a single vector instruction.
#[derive(Debug)]
enum IdentifiedOperand {
    /// Lane `i` of every instruction uses lane `i` extracted from this one
    /// vector id, so the vector can be used directly.
    Vector(Word),
    /// The lanes use these (otherwise unrelated) scalar ids, one per lane. In
    /// the general case they get packed into a fresh vector before use.
    Scalars(Vec<Word>),
    /// A non-value operand (e.g. the extended-instruction set or instruction
    /// number of an `OpExtInst`). It is identical for all lanes and is passed
    /// through unchanged.
    NonValue(Operand),
}
116
117fn get_composite_and_index(
120    types: &FxHashMap<Word, Instruction>,
121    defs: &FxHashMap<Word, Instruction>,
122    id: Word,
123    vector_width: u32,
124) -> Option<(Word, u32)> {
125    let inst = defs.get(&id)?;
126    if inst.class.opcode != Op::CompositeExtract {
127        return None;
128    }
129    if inst.operands.len() != 2 {
130        return None;
132    }
133    let composite = inst.operands[0].unwrap_id_ref();
134    let index = inst.operands[1].unwrap_literal_bit32();
135
136    let composite_def = defs.get(&composite).or_else(|| types.get(&composite))?;
137    let vector_def = types.get(&composite_def.result_type.unwrap())?;
138
139    if vector_def.class.opcode != Op::TypeVector
143        || vector_width != vector_def.operands[1].unwrap_literal_bit32()
144    {
145        return None;
146    }
147
148    Some((composite, index))
149}
150
151fn match_vector_operand(
155    types: &FxHashMap<Word, Instruction>,
156    defs: &FxHashMap<Word, Instruction>,
157    results: &[&Instruction],
158    operand_index: usize,
159    vector_width: u32,
160) -> Option<Word> {
161    let operand_zero = match results[0].operands[operand_index] {
162        Operand::IdRef(id) => id,
163        _ => {
164            return None;
165        }
166    };
167    let composite_zero = match get_composite_and_index(types, defs, operand_zero, vector_width) {
169        Some((composite_zero, 0)) => composite_zero,
170        _ => {
171            return None;
172        }
173    };
174    for (expected_index, result) in results.iter().enumerate().skip(1) {
176        let operand = match result.operands[operand_index] {
177            Operand::IdRef(id) => id,
178            _ => {
179                return None;
180            }
181        };
182        let (composite, actual_index) =
183            match get_composite_and_index(types, defs, operand, vector_width) {
184                Some(x) => x,
185                None => {
186                    return None;
187                }
188            };
189        if composite != composite_zero || expected_index != actual_index as usize {
192            return None;
193        }
194    }
195    Some(composite_zero)
196}
197
198fn match_vector_or_scalars_operand(
202    types: &FxHashMap<Word, Instruction>,
203    defs: &FxHashMap<Word, Instruction>,
204    results: &[&Instruction],
205    operand_index: usize,
206    vector_width: u32,
207) -> Option<IdentifiedOperand> {
208    if let Some(composite) = match_vector_operand(types, defs, results, operand_index, vector_width)
209    {
210        Some(IdentifiedOperand::Vector(composite))
211    } else {
212        let operands = results
213            .iter()
214            .map(|inst| match inst.operands[operand_index] {
215                Operand::IdRef(id) => Some(id),
216                _ => None,
217            })
218            .collect::<Option<Vec<_>>>()?;
219        Some(IdentifiedOperand::Scalars(operands))
220    }
221}
222
223fn match_all_same_operand(results: &[&Instruction], operand_index: usize) -> Option<Operand> {
226    let operand_zero = &results[0].operands[operand_index];
227    if results
228        .iter()
229        .skip(1)
230        .all(|inst| &inst.operands[operand_index] == operand_zero)
231    {
232        Some(operand_zero.clone())
233    } else {
234        None
235    }
236}
237
238fn match_operands(
241    types: &FxHashMap<Word, Instruction>,
242    defs: &FxHashMap<Word, Instruction>,
243    results: &[&Instruction],
244    vector_width: u32,
245) -> Option<Vec<IdentifiedOperand>> {
246    let operation_opcode = results[0].class.opcode;
247    if results.iter().skip(1).any(|r| {
249        r.class.opcode != operation_opcode || r.operands.len() != results[0].operands.len()
250    }) {
251        return None;
252    }
253    match operation_opcode {
255        Op::IAdd
256        | Op::FAdd
257        | Op::ISub
258        | Op::FSub
259        | Op::IMul
260        | Op::FMul
261        | Op::UDiv
262        | Op::SDiv
263        | Op::FDiv
264        | Op::UMod
265        | Op::SRem
266        | Op::FRem
267        | Op::FMod
268        | Op::ShiftRightLogical
269        | Op::ShiftRightArithmetic
270        | Op::ShiftLeftLogical
271        | Op::BitwiseOr
272        | Op::BitwiseXor
273        | Op::BitwiseAnd => {
274            let left = match_vector_or_scalars_operand(types, defs, results, 0, vector_width)?;
275            let right = match_vector_or_scalars_operand(types, defs, results, 1, vector_width)?;
276            match (left, right) {
277                (IdentifiedOperand::Scalars(_), IdentifiedOperand::Scalars(_)) => None,
279                (left, right) => Some(vec![left, right]),
280            }
281        }
282        Op::SNegate | Op::FNegate | Op::Not | Op::BitReverse => {
283            let value = match_vector_operand(types, defs, results, 0, vector_width)?;
284            Some(vec![IdentifiedOperand::Vector(value)])
285        }
286        Op::ExtInst => {
287            let set = match_all_same_operand(results, 0)?;
288            let instruction = match_all_same_operand(results, 1)?;
289            let parameters = (2..results[0].operands.len())
290                .map(|i| match_vector_or_scalars_operand(types, defs, results, i, vector_width));
291            let operands = IntoIterator::into_iter([
293                Some(IdentifiedOperand::NonValue(set)),
294                Some(IdentifiedOperand::NonValue(instruction)),
295            ])
296            .chain(parameters)
297            .collect::<Option<Vec<_>>>()?;
298            if operands
299                .iter()
300                .skip(2)
301                .all(|p| matches!(p, &IdentifiedOperand::Scalars(_)))
302            {
303                return None;
305            }
306            Some(operands)
307        }
308        _ => None,
309    }
310}
311
312fn process_instruction(
313    header: &mut ModuleHeader,
314    types: &FxHashMap<Word, Instruction>,
315    defs: &FxHashMap<Word, Instruction>,
316    instructions: &mut Vec<Instruction>,
317    instruction_index: &mut usize,
318) -> Option<Instruction> {
319    let inst = &instructions[*instruction_index];
320    if inst.class.opcode != Op::CompositeConstruct {
322        return None;
323    }
324    let inst_result_id = inst.result_id.unwrap();
325    let vector_ty = inst.result_type.unwrap();
326    let vector_ty_inst = match types.get(&vector_ty) {
327        Some(inst) => inst,
328        _ => return None,
329    };
330    if vector_ty_inst.class.opcode != Op::TypeVector {
331        return None;
332    }
333    let vector_width = vector_ty_inst.operands[1].unwrap_literal_bit32();
334    let results = inst
336        .operands
337        .iter()
338        .map(|op| defs.get(&op.unwrap_id_ref()))
339        .collect::<Option<Vec<_>>>()?;
340
341    let operation_opcode = results[0].class.opcode;
342    let composite_arguments = match_operands(types, defs, &results, vector_width)?;
344
345    if operation_opcode == Op::FMul
348        && composite_arguments.len() == 2
349        && let (&IdentifiedOperand::Vector(composite), IdentifiedOperand::Scalars(scalars))
350        | (IdentifiedOperand::Scalars(scalars), &IdentifiedOperand::Vector(composite)) =
351            (&composite_arguments[0], &composite_arguments[1])
352    {
353        let scalar = scalars[0];
354        if scalars.iter().skip(1).all(|&s| s == scalar) {
355            return Some(Instruction::new(
356                Op::VectorTimesScalar,
357                inst.result_type,
358                inst.result_id,
359                vec![Operand::IdRef(composite), Operand::IdRef(scalar)],
360            ));
361        }
362    }
363
364    let operands = composite_arguments
367        .into_iter()
368        .map(|operand| match operand {
369            IdentifiedOperand::Vector(composite) => Operand::IdRef(composite),
370            IdentifiedOperand::NonValue(operand) => operand,
371            IdentifiedOperand::Scalars(scalars) => {
372                let id = super::id(header);
373                instructions.insert(
376                    *instruction_index,
377                    Instruction::new(
378                        Op::CompositeConstruct,
379                        Some(vector_ty),
380                        Some(id),
381                        scalars.into_iter().map(Operand::IdRef).collect(),
382                    ),
383                );
384                *instruction_index += 1;
385                Operand::IdRef(id)
386            }
387        })
388        .collect();
389
390    Some(Instruction::new(
391        operation_opcode,
392        Some(vector_ty),
393        Some(inst_result_id),
394        operands,
395    ))
396}
397
398pub fn vector_ops(
415    header: &mut ModuleHeader,
416    types: &FxHashMap<Word, Instruction>,
417    function: &mut Function,
418) {
419    let defs = function
420        .all_inst_iter()
421        .filter_map(|inst| Some((inst.result_id?, inst.clone())))
422        .collect::<FxHashMap<Word, Instruction>>();
423    for block in &mut function.blocks {
424        let mut instruction_index = 0;
430        while instruction_index < block.instructions.len() {
431            if let Some(result) = process_instruction(
432                header,
433                types,
434                &defs,
435                &mut block.instructions,
436                &mut instruction_index,
437            ) {
438                block.instructions[instruction_index] = result;
441            }
442
443            instruction_index += 1;
444        }
445    }
446}
447
448fn can_fuse_bool(
449    types: &FxHashMap<Word, Instruction>,
450    defs: &FxHashMap<Word, (usize, Instruction)>,
451    inst: &Instruction,
452) -> bool {
453    fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
454        let inst = types.get(&val)?;
455        if inst.class.opcode != Op::Constant {
456            return None;
457        }
458        match inst.operands[0] {
459            Operand::LiteralBit32(v) => Some(v),
460            _ => None,
461        }
462    }
463
464    fn visit(
465        types: &FxHashMap<Word, Instruction>,
466        defs: &FxHashMap<Word, (usize, Instruction)>,
467        visited: &mut FxHashSet<Word>,
468        value: Word,
469    ) -> bool {
470        if visited.insert(value) {
471            let inst = match defs.get(&value) {
472                Some((_, inst)) => inst,
473                None => return false,
474            };
475            match inst.class.opcode {
476                Op::Select => {
477                    constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1)
478                        && constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0)
479                }
480                Op::Phi => inst
481                    .operands
482                    .iter()
483                    .step_by(2)
484                    .all(|op| visit(types, defs, visited, op.unwrap_id_ref())),
485                _ => false,
486            }
487        } else {
488            true
489        }
490    }
491
492    if inst.class.opcode != Op::INotEqual
493        || constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0)
494    {
495        return false;
496    }
497    let int_value = inst.operands[0].unwrap_id_ref();
498
499    visit(types, defs, &mut FxHashSet::default(), int_value)
500}
501
502fn fuse_bool(
503    header: &mut ModuleHeader,
504    defs: &FxHashMap<Word, (usize, Instruction)>,
505    phis_to_insert: &mut Vec<(usize, Instruction)>,
506    already_mapped: &mut FxHashMap<Word, Word>,
507    bool_ty: Word,
508    int_value: Word,
509) -> Word {
510    if let Some(&result) = already_mapped.get(&int_value) {
511        return result;
512    }
513    let (block_of_inst, inst) = defs.get(&int_value).unwrap();
514    match inst.class.opcode {
515        Op::Select => inst.operands[0].unwrap_id_ref(),
516        Op::Phi => {
517            let result_id = id(header);
518            already_mapped.insert(int_value, result_id);
519            let new_phi_args = inst
520                .operands
521                .chunks(2)
522                .flat_map(|arr| {
523                    let phi_value = &arr[0];
524                    let block = &arr[1];
525                    [
526                        Operand::IdRef(fuse_bool(
527                            header,
528                            defs,
529                            phis_to_insert,
530                            already_mapped,
531                            bool_ty,
532                            phi_value.unwrap_id_ref(),
533                        )),
534                        block.clone(),
535                    ]
536                })
537                .collect::<Vec<_>>();
538            let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args);
539            phis_to_insert.push((*block_of_inst, inst));
540            result_id
541        }
542        _ => bug!("can_fuse_bool should have prevented this case"),
543    }
544}
545
546pub fn bool_fusion(
570    header: &mut ModuleHeader,
571    types: &FxHashMap<Word, Instruction>,
572    function: &mut Function,
573) {
574    let defs: FxHashMap<Word, (usize, Instruction)> = function
575        .blocks
576        .iter()
577        .enumerate()
578        .flat_map(|(block_id, block)| {
579            block
580                .instructions
581                .iter()
582                .filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone()))))
583        })
584        .collect();
585    let mut rewrite_rules = FxHashMap::default();
586    let mut phis_to_insert = Default::default();
587    let mut already_mapped = Default::default();
588    for block in &mut function.blocks {
589        for inst in &mut block.instructions {
590            if can_fuse_bool(types, &defs, inst) {
591                let rewrite_to = fuse_bool(
592                    header,
593                    &defs,
594                    &mut phis_to_insert,
595                    &mut already_mapped,
596                    inst.result_type.unwrap(),
597                    inst.operands[0].unwrap_id_ref(),
598                );
599                rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to);
600                *inst = Instruction::new(Op::Nop, None, None, Vec::new());
601            }
602        }
603    }
604    for (block, phi) in phis_to_insert {
605        function.blocks[block].instructions.insert(0, phi);
606    }
607    super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
608}