1use super::id;
2use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
3use rspirv::spirv::{Op, Word};
4use rustc_data_structures::fx::{FxHashMap, FxHashSet};
5use rustc_middle::bug;
6
7pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
8 module
9 .types_global_values
10 .iter()
11 .filter_map(|inst| Some((inst.result_id?, inst.clone())))
12 .collect()
13}
14
15fn composite_count(types: &FxHashMap<Word, Instruction>, ty_id: Word) -> Option<usize> {
16 let ty = types.get(&ty_id)?;
17 match ty.class.opcode {
18 Op::TypeStruct => Some(ty.operands.len()),
19 Op::TypeVector => Some(ty.operands[1].unwrap_literal_bit32() as usize),
20 Op::TypeArray => {
21 let length_id = ty.operands[1].unwrap_id_ref();
22 let const_inst = types.get(&length_id)?;
23 if const_inst.class.opcode != Op::Constant {
24 return None;
25 }
26 let const_ty = types.get(&const_inst.result_type.unwrap())?;
27 if const_ty.class.opcode != Op::TypeInt {
28 return None;
29 }
30 let const_value = match const_inst.operands[0] {
31 Operand::LiteralBit32(v) => v as usize,
32 Operand::LiteralBit64(v) => v as usize,
33 _ => bug!(),
34 };
35 Some(const_value)
36 }
37 _ => None,
38 }
39}
40
41pub fn composite_construct(types: &FxHashMap<Word, Instruction>, function: &mut Function) {
44 let defs = function
45 .all_inst_iter()
46 .filter_map(|inst| Some((inst.result_id?, inst.clone())))
47 .collect::<FxHashMap<Word, Instruction>>();
48 for block in &mut function.blocks {
49 for inst in &mut block.instructions {
50 if inst.class.opcode != Op::CompositeInsert {
51 continue;
52 }
53 let component_count = match composite_count(types, inst.result_type.unwrap()) {
55 Some(c) => c,
56 None => continue,
57 };
58 let mut components = vec![None; component_count];
62 let mut cur_inst: &Instruction = inst;
63 while cur_inst.class.opcode == Op::CompositeInsert {
65 if cur_inst.operands.len() != 3 {
66 break;
68 }
69 let value = cur_inst.operands[0].unwrap_id_ref();
70 let index = cur_inst.operands[2].unwrap_literal_bit32() as usize;
71 if index >= components.len() {
72 break;
75 }
76 if components[index].is_none() {
77 components[index] = Some(value);
78 }
79 cur_inst = match defs.get(&cur_inst.operands[1].unwrap_id_ref()) {
81 Some(i) => i,
82 None => break,
83 };
84 }
85 if let Some(composite_construct_operands) = components
88 .into_iter()
89 .map(|v| v.map(Operand::IdRef))
90 .collect::<Option<Vec<_>>>()
91 {
92 *inst = Instruction::new(
95 Op::CompositeConstruct,
96 inst.result_type,
97 inst.result_id,
98 composite_construct_operands,
99 );
100 }
101 }
102 }
103}
104
105#[derive(Debug)]
106enum IdentifiedOperand {
107 Vector(Word),
109 Scalars(Vec<Word>),
112 NonValue(Operand),
115}
116
117fn get_composite_and_index(
120 types: &FxHashMap<Word, Instruction>,
121 defs: &FxHashMap<Word, Instruction>,
122 id: Word,
123 vector_width: u32,
124) -> Option<(Word, u32)> {
125 let inst = defs.get(&id)?;
126 if inst.class.opcode != Op::CompositeExtract {
127 return None;
128 }
129 if inst.operands.len() != 2 {
130 return None;
132 }
133 let composite = inst.operands[0].unwrap_id_ref();
134 let index = inst.operands[1].unwrap_literal_bit32();
135
136 let composite_def = defs.get(&composite).or_else(|| types.get(&composite))?;
137 let vector_def = types.get(&composite_def.result_type.unwrap())?;
138
139 if vector_def.class.opcode != Op::TypeVector
143 || vector_width != vector_def.operands[1].unwrap_literal_bit32()
144 {
145 return None;
146 }
147
148 Some((composite, index))
149}
150
151fn match_vector_operand(
155 types: &FxHashMap<Word, Instruction>,
156 defs: &FxHashMap<Word, Instruction>,
157 results: &[&Instruction],
158 operand_index: usize,
159 vector_width: u32,
160) -> Option<Word> {
161 let operand_zero = match results[0].operands[operand_index] {
162 Operand::IdRef(id) => id,
163 _ => {
164 return None;
165 }
166 };
167 let composite_zero = match get_composite_and_index(types, defs, operand_zero, vector_width) {
169 Some((composite_zero, 0)) => composite_zero,
170 _ => {
171 return None;
172 }
173 };
174 for (expected_index, result) in results.iter().enumerate().skip(1) {
176 let operand = match result.operands[operand_index] {
177 Operand::IdRef(id) => id,
178 _ => {
179 return None;
180 }
181 };
182 let (composite, actual_index) =
183 match get_composite_and_index(types, defs, operand, vector_width) {
184 Some(x) => x,
185 None => {
186 return None;
187 }
188 };
189 if composite != composite_zero || expected_index != actual_index as usize {
192 return None;
193 }
194 }
195 Some(composite_zero)
196}
197
198fn match_vector_or_scalars_operand(
202 types: &FxHashMap<Word, Instruction>,
203 defs: &FxHashMap<Word, Instruction>,
204 results: &[&Instruction],
205 operand_index: usize,
206 vector_width: u32,
207) -> Option<IdentifiedOperand> {
208 if let Some(composite) = match_vector_operand(types, defs, results, operand_index, vector_width)
209 {
210 Some(IdentifiedOperand::Vector(composite))
211 } else {
212 let operands = results
213 .iter()
214 .map(|inst| match inst.operands[operand_index] {
215 Operand::IdRef(id) => Some(id),
216 _ => None,
217 })
218 .collect::<Option<Vec<_>>>()?;
219 Some(IdentifiedOperand::Scalars(operands))
220 }
221}
222
223fn match_all_same_operand(results: &[&Instruction], operand_index: usize) -> Option<Operand> {
226 let operand_zero = &results[0].operands[operand_index];
227 if results
228 .iter()
229 .skip(1)
230 .all(|inst| &inst.operands[operand_index] == operand_zero)
231 {
232 Some(operand_zero.clone())
233 } else {
234 None
235 }
236}
237
238fn match_operands(
241 types: &FxHashMap<Word, Instruction>,
242 defs: &FxHashMap<Word, Instruction>,
243 results: &[&Instruction],
244 vector_width: u32,
245) -> Option<Vec<IdentifiedOperand>> {
246 let operation_opcode = results[0].class.opcode;
247 if results.iter().skip(1).any(|r| {
249 r.class.opcode != operation_opcode || r.operands.len() != results[0].operands.len()
250 }) {
251 return None;
252 }
253 match operation_opcode {
255 Op::IAdd
256 | Op::FAdd
257 | Op::ISub
258 | Op::FSub
259 | Op::IMul
260 | Op::FMul
261 | Op::UDiv
262 | Op::SDiv
263 | Op::FDiv
264 | Op::UMod
265 | Op::SRem
266 | Op::FRem
267 | Op::FMod
268 | Op::ShiftRightLogical
269 | Op::ShiftRightArithmetic
270 | Op::ShiftLeftLogical
271 | Op::BitwiseOr
272 | Op::BitwiseXor
273 | Op::BitwiseAnd => {
274 let left = match_vector_or_scalars_operand(types, defs, results, 0, vector_width)?;
275 let right = match_vector_or_scalars_operand(types, defs, results, 1, vector_width)?;
276 match (left, right) {
277 (IdentifiedOperand::Scalars(_), IdentifiedOperand::Scalars(_)) => None,
279 (left, right) => Some(vec![left, right]),
280 }
281 }
282 Op::SNegate | Op::FNegate | Op::Not | Op::BitReverse => {
283 let value = match_vector_operand(types, defs, results, 0, vector_width)?;
284 Some(vec![IdentifiedOperand::Vector(value)])
285 }
286 Op::ExtInst => {
287 let set = match_all_same_operand(results, 0)?;
288 let instruction = match_all_same_operand(results, 1)?;
289 let parameters = (2..results[0].operands.len())
290 .map(|i| match_vector_or_scalars_operand(types, defs, results, i, vector_width));
291 let operands = IntoIterator::into_iter([
293 Some(IdentifiedOperand::NonValue(set)),
294 Some(IdentifiedOperand::NonValue(instruction)),
295 ])
296 .chain(parameters)
297 .collect::<Option<Vec<_>>>()?;
298 if operands
299 .iter()
300 .skip(2)
301 .all(|p| matches!(p, &IdentifiedOperand::Scalars(_)))
302 {
303 return None;
305 }
306 Some(operands)
307 }
308 _ => None,
309 }
310}
311
312fn process_instruction(
313 header: &mut ModuleHeader,
314 types: &FxHashMap<Word, Instruction>,
315 defs: &FxHashMap<Word, Instruction>,
316 instructions: &mut Vec<Instruction>,
317 instruction_index: &mut usize,
318) -> Option<Instruction> {
319 let inst = &instructions[*instruction_index];
320 if inst.class.opcode != Op::CompositeConstruct {
322 return None;
323 }
324 let inst_result_id = inst.result_id.unwrap();
325 let vector_ty = inst.result_type.unwrap();
326 let vector_ty_inst = match types.get(&vector_ty) {
327 Some(inst) => inst,
328 _ => return None,
329 };
330 if vector_ty_inst.class.opcode != Op::TypeVector {
331 return None;
332 }
333 let vector_width = vector_ty_inst.operands[1].unwrap_literal_bit32();
334 let results = inst
336 .operands
337 .iter()
338 .map(|op| defs.get(&op.unwrap_id_ref()))
339 .collect::<Option<Vec<_>>>()?;
340
341 let operation_opcode = results[0].class.opcode;
342 let composite_arguments = match_operands(types, defs, &results, vector_width)?;
344
345 if operation_opcode == Op::FMul
348 && composite_arguments.len() == 2
349 && let (&IdentifiedOperand::Vector(composite), IdentifiedOperand::Scalars(scalars))
350 | (IdentifiedOperand::Scalars(scalars), &IdentifiedOperand::Vector(composite)) =
351 (&composite_arguments[0], &composite_arguments[1])
352 {
353 let scalar = scalars[0];
354 if scalars.iter().skip(1).all(|&s| s == scalar) {
355 return Some(Instruction::new(
356 Op::VectorTimesScalar,
357 inst.result_type,
358 inst.result_id,
359 vec![Operand::IdRef(composite), Operand::IdRef(scalar)],
360 ));
361 }
362 }
363
364 let operands = composite_arguments
367 .into_iter()
368 .map(|operand| match operand {
369 IdentifiedOperand::Vector(composite) => Operand::IdRef(composite),
370 IdentifiedOperand::NonValue(operand) => operand,
371 IdentifiedOperand::Scalars(scalars) => {
372 let id = super::id(header);
373 instructions.insert(
376 *instruction_index,
377 Instruction::new(
378 Op::CompositeConstruct,
379 Some(vector_ty),
380 Some(id),
381 scalars.into_iter().map(Operand::IdRef).collect(),
382 ),
383 );
384 *instruction_index += 1;
385 Operand::IdRef(id)
386 }
387 })
388 .collect();
389
390 Some(Instruction::new(
391 operation_opcode,
392 Some(vector_ty),
393 Some(inst_result_id),
394 operands,
395 ))
396}
397
398pub fn vector_ops(
415 header: &mut ModuleHeader,
416 types: &FxHashMap<Word, Instruction>,
417 function: &mut Function,
418) {
419 let defs = function
420 .all_inst_iter()
421 .filter_map(|inst| Some((inst.result_id?, inst.clone())))
422 .collect::<FxHashMap<Word, Instruction>>();
423 for block in &mut function.blocks {
424 let mut instruction_index = 0;
430 while instruction_index < block.instructions.len() {
431 if let Some(result) = process_instruction(
432 header,
433 types,
434 &defs,
435 &mut block.instructions,
436 &mut instruction_index,
437 ) {
438 block.instructions[instruction_index] = result;
441 }
442
443 instruction_index += 1;
444 }
445 }
446}
447
448fn can_fuse_bool(
449 types: &FxHashMap<Word, Instruction>,
450 defs: &FxHashMap<Word, (usize, Instruction)>,
451 inst: &Instruction,
452) -> bool {
453 fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
454 let inst = types.get(&val)?;
455 if inst.class.opcode != Op::Constant {
456 return None;
457 }
458 match inst.operands[0] {
459 Operand::LiteralBit32(v) => Some(v),
460 _ => None,
461 }
462 }
463
464 fn visit(
465 types: &FxHashMap<Word, Instruction>,
466 defs: &FxHashMap<Word, (usize, Instruction)>,
467 visited: &mut FxHashSet<Word>,
468 value: Word,
469 ) -> bool {
470 if visited.insert(value) {
471 let inst = match defs.get(&value) {
472 Some((_, inst)) => inst,
473 None => return false,
474 };
475 match inst.class.opcode {
476 Op::Select => {
477 constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1)
478 && constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0)
479 }
480 Op::Phi => inst
481 .operands
482 .iter()
483 .step_by(2)
484 .all(|op| visit(types, defs, visited, op.unwrap_id_ref())),
485 _ => false,
486 }
487 } else {
488 true
489 }
490 }
491
492 if inst.class.opcode != Op::INotEqual
493 || constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0)
494 {
495 return false;
496 }
497 let int_value = inst.operands[0].unwrap_id_ref();
498
499 visit(types, defs, &mut FxHashSet::default(), int_value)
500}
501
502fn fuse_bool(
503 header: &mut ModuleHeader,
504 defs: &FxHashMap<Word, (usize, Instruction)>,
505 phis_to_insert: &mut Vec<(usize, Instruction)>,
506 already_mapped: &mut FxHashMap<Word, Word>,
507 bool_ty: Word,
508 int_value: Word,
509) -> Word {
510 if let Some(&result) = already_mapped.get(&int_value) {
511 return result;
512 }
513 let (block_of_inst, inst) = defs.get(&int_value).unwrap();
514 match inst.class.opcode {
515 Op::Select => inst.operands[0].unwrap_id_ref(),
516 Op::Phi => {
517 let result_id = id(header);
518 already_mapped.insert(int_value, result_id);
519 let new_phi_args = inst
520 .operands
521 .chunks(2)
522 .flat_map(|arr| {
523 let phi_value = &arr[0];
524 let block = &arr[1];
525 [
526 Operand::IdRef(fuse_bool(
527 header,
528 defs,
529 phis_to_insert,
530 already_mapped,
531 bool_ty,
532 phi_value.unwrap_id_ref(),
533 )),
534 block.clone(),
535 ]
536 })
537 .collect::<Vec<_>>();
538 let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args);
539 phis_to_insert.push((*block_of_inst, inst));
540 result_id
541 }
542 _ => bug!("can_fuse_bool should have prevented this case"),
543 }
544}
545
546pub fn bool_fusion(
570 header: &mut ModuleHeader,
571 types: &FxHashMap<Word, Instruction>,
572 function: &mut Function,
573) {
574 let defs: FxHashMap<Word, (usize, Instruction)> = function
575 .blocks
576 .iter()
577 .enumerate()
578 .flat_map(|(block_id, block)| {
579 block
580 .instructions
581 .iter()
582 .filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone()))))
583 })
584 .collect();
585 let mut rewrite_rules = FxHashMap::default();
586 let mut phis_to_insert = Default::default();
587 let mut already_mapped = Default::default();
588 for block in &mut function.blocks {
589 for inst in &mut block.instructions {
590 if can_fuse_bool(types, &defs, inst) {
591 let rewrite_to = fuse_bool(
592 header,
593 &defs,
594 &mut phis_to_insert,
595 &mut already_mapped,
596 inst.result_type.unwrap(),
597 inst.operands[0].unwrap_id_ref(),
598 );
599 rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to);
600 *inst = Instruction::new(Op::Nop, None, None, Vec::new());
601 }
602 }
603 }
604 for (block, phi) in phis_to_insert {
605 function.blocks[block].instructions.insert(0, phi);
606 }
607 super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
608}