1pub(crate) mod controlflow;
4pub(crate) mod debuginfo;
5pub(crate) mod diagnostics;
6mod fuse_selects;
7mod reduce;
8pub(crate) mod validate;
9
10use lazy_static::lazy_static;
11use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
12use spirt::func_at::FuncAt;
13use spirt::transform::InnerInPlaceTransform;
14use spirt::visit::{InnerVisit, Visitor};
15use spirt::{
16 AttrSet, Const, Context, ControlNode, ControlNodeKind, ControlRegion, DataInstDef,
17 DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityOrientedDenseMap, Func,
18 FuncDefBody, GlobalVar, Module, Type, Value, spv,
19};
20use std::collections::VecDeque;
21use std::iter;
22
/// Defines `SpvSpecWithExtras` and `SpvWellKnownWithExtras`, wrappers around
/// SPIR-T's `spv::spec::Spec` and `spv::spec::WellKnown` (still reachable from
/// them through `Deref`) that add one extra `pub` field per listed `$entry`.
///
/// Entries are grouped by lookup kind: each `$group` selects the lookup
/// closure used to initialize its entries from their own identifier (via
/// `stringify!`), and `$ty` is the type of those fields. The invocation must
/// list exactly the groups `opcode`, `operand_kind` and `decoration` (once
/// each), because the generated `PerWellKnownGroup` struct gets one field per
/// listed group, while `lookup_fns` always initializes exactly those three.
///
/// All entries are looked up (and `unwrap`ped) when `SpvSpecWithExtras::get()`
/// is first called, so a name unknown to SPIR-T panics at that point.
macro_rules! def_spv_spec_with_extra_well_known {
    ($($group:ident: $ty:ty = [$($entry:ident),+ $(,)?]),+ $(,)?) => {
        // Callers read the extra entries via `.well_known`, and everything
        // else from the base `Spec` through `Deref`.
        struct SpvSpecWithExtras {
            __base_spec: &'static spv::spec::Spec,

            well_known: SpvWellKnownWithExtras,
        }

        // Field names are the entries' identifiers (e.g. `OpTypeVoid`), which
        // are not snake_case, hence the `allow`.
        #[allow(non_snake_case)]
        pub struct SpvWellKnownWithExtras {
            __base_well_known: &'static spv::spec::WellKnown,

            $($(pub $entry: $ty,)+)+
        }

        impl std::ops::Deref for SpvSpecWithExtras {
            type Target = spv::spec::Spec;
            fn deref(&self) -> &Self::Target {
                self.__base_spec
            }
        }

        // Lets base `WellKnown` entries be accessed directly on the extended table.
        impl std::ops::Deref for SpvWellKnownWithExtras {
            type Target = spv::spec::WellKnown;
            fn deref(&self) -> &Self::Target {
                self.__base_well_known
            }
        }

        impl SpvSpecWithExtras {
            // Returns the global instance, which is built lazily on first use.
            #[inline(always)]
            #[must_use]
            pub fn get() -> &'static SpvSpecWithExtras {
                lazy_static! {
                    static ref SPEC: SpvSpecWithExtras = {
                        // One type parameter (and field) per group, named after
                        // the group, so each group's lookup closure (each with
                        // its own unnameable type) can be selected by that name.
                        #[allow(non_camel_case_types)]
                        struct PerWellKnownGroup<$($group),+> {
                            $($group: $group),+
                        }

                        let spv_spec = spv::spec::Spec::get();
                        let wk = &spv_spec.well_known;

                        // `Decoration` is expected to be a value enum; its
                        // variants are what decoration names are looked up in.
                        let decorations = match wk.Decoration.def() {
                            spv::spec::OperandKindDef::ValueEnum { variants } => variants,
                            _ => unreachable!(),
                        };

                        // Name-based lookup for each group (`unwrap` panics on
                        // a name that SPIR-T's spec doesn't know about).
                        let lookup_fns = PerWellKnownGroup {
                            opcode: |name| spv_spec.instructions.lookup(name).unwrap(),
                            operand_kind: |name| spv_spec.operand_kinds.lookup(name).unwrap(),
                            decoration: |name| decorations.lookup(name).unwrap().into(),
                        };

                        SpvSpecWithExtras {
                            __base_spec: spv_spec,

                            well_known: SpvWellKnownWithExtras {
                                __base_well_known: &spv_spec.well_known,

                                // Each entry is looked up using its own identifier as the name.
                                $($($entry: (lookup_fns.$group)(stringify!($entry)),)+)+
                            },
                        }
                    };
                }
                &SPEC
            }
        }
    };
}
// Extra "well-known" SPIR-T spec entries, available as fields of
// `SpvSpecWithExtras::get().well_known`. Entries not listed here (e.g. `OpNop`)
// still resolve through `Deref` to SPIR-T's own `WellKnown` table.
def_spv_spec_with_extra_well_known! {
    opcode: spv::spec::Opcode = [
        OpTypeVoid,

        OpConstantComposite,

        OpBitcast,
        OpCompositeInsert,
        OpCompositeExtract,
    ],
    operand_kind: spv::spec::OperandKind = [
        Capability,
        ExecutionModel,
        ImageFormat,
    ],
    // Decorations are stored as their raw `u32` value.
    decoration: u32 = [
        UserTypeGOOGLE,
    ],
}
113
114pub(super) fn run_func_passes<P>(
119 module: &mut Module,
120 passes: &[impl AsRef<str>],
121 mut before_pass: impl FnMut(&'static str, &Module) -> P,
123 mut after_pass: impl FnMut(Option<&Module>, P),
124) {
125 let cx = &module.cx();
126
127 let all_funcs = {
129 let mut collector = ReachableUseCollector {
130 cx,
131 module,
132
133 seen_types: FxIndexSet::default(),
134 seen_consts: FxIndexSet::default(),
135 seen_data_inst_forms: FxIndexSet::default(),
136 seen_global_vars: FxIndexSet::default(),
137 seen_funcs: FxIndexSet::default(),
138 };
139 for (export_key, &exportee) in &module.exports {
140 export_key.inner_visit_with(&mut collector);
141 exportee.inner_visit_with(&mut collector);
142 }
143 collector.seen_funcs
144 };
145
146 for name in passes {
147 let name = name.as_ref();
148
149 if name == "qptr" {
151 let layout_config = &spirt::qptr::LayoutConfig {
152 abstract_bool_size_align: (1, 1),
153 logical_ptr_size_align: (4, 4),
154 ..spirt::qptr::LayoutConfig::VULKAN_SCALAR_LAYOUT
155 };
156
157 let profiler = before_pass("qptr::lower_from_spv_ptrs", module);
158 spirt::passes::qptr::lower_from_spv_ptrs(module, layout_config);
159 after_pass(Some(module), profiler);
160
161 let profiler = before_pass("qptr::analyze_uses", module);
162 spirt::passes::qptr::analyze_uses(module, layout_config);
163 after_pass(Some(module), profiler);
164
165 let profiler = before_pass("qptr::lift_to_spv_ptrs", module);
166 spirt::passes::qptr::lift_to_spv_ptrs(module, layout_config);
167 after_pass(Some(module), profiler);
168
169 continue;
170 }
171
172 let (full_name, pass_fn): (_, fn(_, &mut _)) = match name {
173 "reduce" => ("spirt_passes::reduce", reduce::reduce_in_func),
174 "fuse_selects" => (
175 "spirt_passes::fuse_selects",
176 fuse_selects::fuse_selects_in_func,
177 ),
178 _ => panic!("unknown `--spirt-passes={name}`"),
179 };
180
181 let profiler = before_pass(full_name, module);
182 for &func in &all_funcs {
183 if let DeclDef::Present(func_def_body) = &mut module.funcs[func].def {
184 pass_fn(cx, func_def_body);
185
186 remove_unused_values_in_func(cx, func_def_body);
188 }
189 }
190 after_pass(Some(module), profiler);
191 }
192}
193
194struct ReachableUseCollector<'a> {
196 cx: &'a Context,
197 module: &'a Module,
198
199 seen_types: FxIndexSet<Type>,
201 seen_consts: FxIndexSet<Const>,
202 seen_data_inst_forms: FxIndexSet<DataInstForm>,
203 seen_global_vars: FxIndexSet<GlobalVar>,
204 seen_funcs: FxIndexSet<Func>,
205}
206
207impl Visitor<'_> for ReachableUseCollector<'_> {
208 fn visit_attr_set_use(&mut self, _attrs: AttrSet) {}
210 fn visit_type_use(&mut self, ty: Type) {
211 if self.seen_types.insert(ty) {
212 self.visit_type_def(&self.cx[ty]);
213 }
214 }
215 fn visit_const_use(&mut self, ct: Const) {
216 if self.seen_consts.insert(ct) {
217 self.visit_const_def(&self.cx[ct]);
218 }
219 }
220 fn visit_data_inst_form_use(&mut self, data_inst_form: DataInstForm) {
221 if self.seen_data_inst_forms.insert(data_inst_form) {
222 self.visit_data_inst_form_def(&self.cx[data_inst_form]);
223 }
224 }
225
226 fn visit_global_var_use(&mut self, gv: GlobalVar) {
227 if self.seen_global_vars.insert(gv) {
228 self.visit_global_var_decl(&self.module.global_vars[gv]);
229 }
230 }
231 fn visit_func_use(&mut self, func: Func) {
232 if self.seen_funcs.insert(func) {
233 self.visit_func_decl(&self.module.funcs[func]);
234 }
235 }
236}
237
238struct VisitAllControlRegionsAndNodes<S, VCR, VCN> {
240 state: S,
241 visit_control_region: VCR,
242 visit_control_node: VCN,
243}
244const _: () = {
245 use spirt::{func_at::*, visit::*, *};
246
247 impl<
248 'a,
249 S,
250 VCR: FnMut(&mut S, FuncAt<'a, ControlRegion>),
251 VCN: FnMut(&mut S, FuncAt<'a, ControlNode>),
252 > Visitor<'a> for VisitAllControlRegionsAndNodes<S, VCR, VCN>
253 {
254 fn visit_attr_set_use(&mut self, _: AttrSet) {}
257 fn visit_type_use(&mut self, _: Type) {}
258 fn visit_const_use(&mut self, _: Const) {}
259 fn visit_data_inst_form_use(&mut self, _: DataInstForm) {}
260 fn visit_global_var_use(&mut self, _: GlobalVar) {}
261 fn visit_func_use(&mut self, _: Func) {}
262
263 fn visit_control_region_def(&mut self, func_at_control_region: FuncAt<'a, ControlRegion>) {
264 (self.visit_control_region)(&mut self.state, func_at_control_region);
265 func_at_control_region.inner_visit_with(self);
266 }
267 fn visit_control_node_def(&mut self, func_at_control_node: FuncAt<'a, ControlNode>) {
268 (self.visit_control_node)(&mut self.state, func_at_control_node);
269 func_at_control_node.inner_visit_with(self);
270 }
271 }
272};
273
274struct ReplaceValueWith<F>(F);
276const _: () = {
277 use spirt::{transform::*, *};
278
279 impl<F: FnMut(Value) -> Option<Value>> Transformer for ReplaceValueWith<F> {
280 fn transform_value_use(&mut self, v: &Value) -> Transformed<Value> {
281 self.0(*v).map_or(Transformed::Unchanged, Transformed::Changed)
282 }
283 }
284};
285
286fn remove_unused_values_in_func(cx: &Context, func_def_body: &mut FuncDefBody) {
291 if func_def_body.unstructured_cfg.is_some() {
293 return;
294 }
295
296 let wk = &SpvSpecWithExtras::get().well_known;
297
298 struct Propagator {
299 func_body_region: ControlRegion,
300
301 loop_body_to_loop: EntityOrientedDenseMap<ControlRegion, ControlNode>,
303
304 used: FxHashSet<Value>,
307
308 queue: VecDeque<Value>,
309 }
310 impl Propagator {
311 fn mark_used(&mut self, v: Value) {
312 if let Value::Const(_) = v {
313 return;
314 }
315 if let Value::ControlRegionInput {
316 region,
317 input_idx: _,
318 } = v
319 && region == self.func_body_region
320 {
321 return;
322 }
323 if self.used.insert(v) {
324 self.queue.push_back(v);
325 }
326 }
327 fn propagate_used(&mut self, func: FuncAt<'_, ()>) {
328 while let Some(v) = self.queue.pop_front() {
329 match v {
330 Value::Const(_) => unreachable!(),
331 Value::ControlRegionInput { region, input_idx } => {
332 let loop_node = self.loop_body_to_loop[region];
333 let initial_inputs = match &func.at(loop_node).def().kind {
334 ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs,
335 _ => unreachable!(),
337 };
338 self.mark_used(initial_inputs[input_idx as usize]);
339 self.mark_used(func.at(region).def().outputs[input_idx as usize]);
340 }
341 Value::ControlNodeOutput {
342 control_node,
343 output_idx,
344 } => {
345 let cases = match &func.at(control_node).def().kind {
346 ControlNodeKind::Select { cases, .. } => cases,
347 _ => unreachable!(),
349 };
350 for &case in cases {
351 self.mark_used(func.at(case).def().outputs[output_idx as usize]);
352 }
353 }
354 Value::DataInstOutput(inst) => {
355 for &input in &func.at(inst).def().inputs {
356 self.mark_used(input);
357 }
358 }
359 }
360 }
361 }
362 }
363
364 let propagator = {
367 let mut visitor = VisitAllControlRegionsAndNodes {
368 state: Propagator {
369 func_body_region: func_def_body.body,
370 loop_body_to_loop: Default::default(),
371 used: Default::default(),
372 queue: Default::default(),
373 },
374 visit_control_region: |_: &mut _, _| {},
375 visit_control_node:
376 |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
377 if let ControlNodeKind::Loop { body, .. } = func_at_control_node.def().kind {
378 propagator
379 .loop_body_to_loop
380 .insert(body, func_at_control_node.position);
381 }
382 },
383 };
384 func_def_body.inner_visit_with(&mut visitor);
385 visitor.state
386 };
387
388 let mut all_control_nodes = vec![];
390
391 let used_values = {
392 let mut visitor = VisitAllControlRegionsAndNodes {
393 state: propagator,
394 visit_control_region: |_: &mut _, _| {},
395 visit_control_node:
396 |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
397 all_control_nodes.push(func_at_control_node.position);
398
399 let mut mark_used_and_propagate = |v| {
400 propagator.mark_used(v);
401 propagator.propagate_used(func_at_control_node.at(()));
402 };
403 match &func_at_control_node.def().kind {
404 &ControlNodeKind::Block { insts } => {
405 for func_at_inst in func_at_control_node.at(insts) {
406 if let DataInstKind::SpvInst(spv_inst) =
409 &cx[func_at_inst.def().form].kind
410 {
411 if [wk.OpNop, wk.OpCompositeInsert].contains(&spv_inst.opcode) {
414 continue;
415 }
416 }
417 mark_used_and_propagate(Value::DataInstOutput(
418 func_at_inst.position,
419 ));
420 }
421 }
422
423 &ControlNodeKind::Select { scrutinee: v, .. }
424 | &ControlNodeKind::Loop {
425 repeat_condition: v,
426 ..
427 } => mark_used_and_propagate(v),
428
429 ControlNodeKind::ExitInvocation {
430 kind: spirt::cfg::ExitInvocationKind::SpvInst(_),
431 inputs,
432 } => {
433 for &v in inputs {
434 mark_used_and_propagate(v);
435 }
436 }
437 }
438 },
439 };
440 func_def_body.inner_visit_with(&mut visitor);
441
442 let mut propagator = visitor.state;
443 for &v in &func_def_body.at_body().def().outputs {
444 propagator.mark_used(v);
445 propagator.propagate_used(func_def_body.at(()));
446 }
447
448 assert!(propagator.queue.is_empty());
449 propagator.used
450 };
451
452 let mut value_replacements = FxHashMap::default();
455
456 for control_node in all_control_nodes {
458 let control_node_def = func_def_body.at(control_node).def();
459 match &control_node_def.kind {
460 &ControlNodeKind::Block { insts } => {
461 let mut all_nops = true;
462 let mut func_at_inst_iter = func_def_body.at_mut(insts).into_iter();
463 while let Some(mut func_at_inst) = func_at_inst_iter.next() {
464 if let DataInstKind::SpvInst(spv_inst) =
465 &cx[func_at_inst.reborrow().def().form].kind
466 && spv_inst.opcode == wk.OpNop
467 {
468 continue;
469 }
470 if !used_values.contains(&Value::DataInstOutput(func_at_inst.position)) {
471 *func_at_inst.def() = DataInstDef {
476 attrs: Default::default(),
477 form: cx.intern(DataInstFormDef {
478 kind: DataInstKind::SpvInst(wk.OpNop.into()),
479 output_type: None,
480 }),
481 inputs: iter::empty().collect(),
482 };
483 continue;
484 }
485 all_nops = false;
486 }
487 if all_nops {
490 func_def_body.at_mut(control_node).def().kind = ControlNodeKind::Block {
491 insts: Default::default(),
492 };
493 }
494 }
495
496 ControlNodeKind::Select { cases, .. } => {
497 let cases = cases.clone();
499
500 let mut new_idx = 0;
501 for original_idx in 0..control_node_def.outputs.len() {
502 let original_output = Value::ControlNodeOutput {
503 control_node,
504 output_idx: original_idx as u32,
505 };
506
507 if !used_values.contains(&original_output) {
508 func_def_body
510 .at_mut(control_node)
511 .def()
512 .outputs
513 .remove(new_idx);
514 for &case in &cases {
515 func_def_body.at_mut(case).def().outputs.remove(new_idx);
516 }
517 continue;
518 }
519
520 if original_idx != new_idx {
522 let new_output = Value::ControlNodeOutput {
523 control_node,
524 output_idx: new_idx as u32,
525 };
526 value_replacements.insert(original_output, new_output);
527 }
528 new_idx += 1;
529 }
530 }
531
532 ControlNodeKind::Loop {
533 body,
534 initial_inputs,
535 ..
536 } => {
537 let body = *body;
538
539 let mut new_idx = 0;
540 for original_idx in 0..initial_inputs.len() {
541 let original_input = Value::ControlRegionInput {
542 region: body,
543 input_idx: original_idx as u32,
544 };
545
546 if !used_values.contains(&original_input) {
547 match &mut func_def_body.at_mut(control_node).def().kind {
549 ControlNodeKind::Loop { initial_inputs, .. } => {
550 initial_inputs.remove(new_idx);
551 }
552 _ => unreachable!(),
553 }
554 let body_def = func_def_body.at_mut(body).def();
555 body_def.inputs.remove(new_idx);
556 body_def.outputs.remove(new_idx);
557 continue;
558 }
559
560 if original_idx != new_idx {
562 let new_input = Value::ControlRegionInput {
563 region: body,
564 input_idx: new_idx as u32,
565 };
566 value_replacements.insert(original_input, new_input);
567 }
568 new_idx += 1;
569 }
570 }
571
572 ControlNodeKind::ExitInvocation { .. } => {}
573 }
574 }
575
576 if !value_replacements.is_empty() {
577 func_def_body.inner_in_place_transform_with(&mut ReplaceValueWith(|v| match v {
578 Value::Const(_) => None,
579 _ => value_replacements.get(&v).copied(),
580 }));
581 }
582}