1pub(crate) mod controlflow;
4pub(crate) mod debuginfo;
5pub(crate) mod diagnostics;
6pub(crate) mod explicit_layout;
7mod fuse_selects;
8mod reduce;
9pub(crate) mod validate;
10
11use lazy_static::lazy_static;
12use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
13use spirt::func_at::FuncAt;
14use spirt::transform::InnerInPlaceTransform;
15use spirt::visit::{InnerVisit, Visitor};
16use spirt::{
17 AttrSet, Const, Context, ControlNode, ControlNodeKind, ControlRegion, DataInstDef,
18 DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityOrientedDenseMap, Func,
19 FuncDefBody, GlobalVar, Module, Type, Value, spv,
20};
21use std::collections::VecDeque;
22use std::iter;
23
// Defines `SpvSpecWithExtras`: `spv::spec::Spec` (plus its `WellKnown` set)
// extended with extra well-known entries, grouped by kind, that are resolved
// by name (via `stringify!`) when the shared instance is first accessed.
//
// Each `$group` must be one of the lookup closures provided by the internal
// `PerWellKnownGroup` value below (`opcode`, `operand_kind`, `decoration`,
// `storage_class`), and the invocation has to list exactly those groups.
macro_rules! def_spv_spec_with_extra_well_known {
    ($($group:ident: $ty:ty = [$($entry:ident),+ $(,)?]),+ $(,)?) => {
        struct SpvSpecWithExtras {
            __base_spec: &'static spv::spec::Spec,

            well_known: SpvWellKnownWithExtras,
        }

        #[allow(non_snake_case)]
        pub struct SpvWellKnownWithExtras {
            __base_well_known: &'static spv::spec::WellKnown,

            $($(pub $entry: $ty,)+)+
        }

        // Both wrappers transparently expose everything from the base spec.
        impl std::ops::Deref for SpvSpecWithExtras {
            type Target = spv::spec::Spec;
            fn deref(&self) -> &spv::spec::Spec {
                self.__base_spec
            }
        }

        impl std::ops::Deref for SpvWellKnownWithExtras {
            type Target = spv::spec::WellKnown;
            fn deref(&self) -> &spv::spec::WellKnown {
                self.__base_well_known
            }
        }

        impl SpvSpecWithExtras {
            /// Returns the shared, lazily-initialized instance.
            ///
            /// # Panics
            ///
            /// On first access, if any extra entry can't be found by name.
            #[inline(always)]
            #[must_use]
            pub fn get() -> &'static SpvSpecWithExtras {
                lazy_static! {
                    static ref SPEC: SpvSpecWithExtras = {
                        // One lookup closure per `$group` (generic, so that every
                        // field can hold its own closure type), used below as
                        // `(lookups.$group)(name)`.
                        #[allow(non_camel_case_types)]
                        struct PerWellKnownGroup<$($group),+> {
                            $($group: $group),+
                        }

                        let base_spec = spv::spec::Spec::get();
                        let base_wk = &base_spec.well_known;

                        // Both operand kinds are expected to be value enums, whose
                        // variants get looked up by name below.
                        let [decoration_variants, storage_class_variants] =
                            [base_wk.Decoration, base_wk.StorageClass].map(|kind| {
                                match kind.def() {
                                    spv::spec::OperandKindDef::ValueEnum { variants } => variants,
                                    _ => unreachable!(),
                                }
                            });

                        let lookups = PerWellKnownGroup {
                            opcode: |name| base_spec.instructions.lookup(name).unwrap(),
                            operand_kind: |name| base_spec.operand_kinds.lookup(name).unwrap(),
                            decoration: |name| decoration_variants.lookup(name).unwrap().into(),
                            storage_class: |name| {
                                storage_class_variants.lookup(name).unwrap().into()
                            },
                        };

                        SpvSpecWithExtras {
                            __base_spec: base_spec,
                            well_known: SpvWellKnownWithExtras {
                                __base_well_known: base_wk,
                                $($($entry: (lookups.$group)(stringify!($entry)),)+)+
                            },
                        }
                    };
                }
                &SPEC
            }
        }
    };
}
96def_spv_spec_with_extra_well_known! {
97 opcode: spv::spec::Opcode = [
98 OpTypeVoid,
99
100 OpConstantComposite,
101
102 OpBitcast,
103 OpCompositeInsert,
104 OpCompositeExtract,
105 OpCompositeConstruct,
106
107 OpCopyMemory,
108 ],
109 operand_kind: spv::spec::OperandKind = [
110 Capability,
111 ExecutionModel,
112 ImageFormat,
113 MemoryAccess,
114 ],
115 decoration: u32 = [
116 UserTypeGOOGLE,
117 MatrixStride,
118 ],
119 storage_class: u32 = [
120 PushConstant,
121 Uniform,
122 StorageBuffer,
123 PhysicalStorageBuffer,
124 ],
125}
126
127pub(super) fn run_func_passes<P>(
132 module: &mut Module,
133 passes: &[impl AsRef<str>],
134 mut before_pass: impl FnMut(&'static str, &Module) -> P,
136 mut after_pass: impl FnMut(Option<&Module>, P),
137) {
138 let cx = &module.cx();
139
140 let all_funcs = {
142 let mut collector = ReachableUseCollector {
143 cx,
144 module,
145
146 seen_types: FxIndexSet::default(),
147 seen_consts: FxIndexSet::default(),
148 seen_data_inst_forms: FxIndexSet::default(),
149 seen_global_vars: FxIndexSet::default(),
150 seen_funcs: FxIndexSet::default(),
151 };
152 for (export_key, &exportee) in &module.exports {
153 export_key.inner_visit_with(&mut collector);
154 exportee.inner_visit_with(&mut collector);
155 }
156 collector.seen_funcs
157 };
158
159 for name in passes {
160 let name = name.as_ref();
161
162 if name == "qptr" {
164 let layout_config = &spirt::qptr::LayoutConfig {
165 abstract_bool_size_align: (1, 1),
166 logical_ptr_size_align: (4, 4),
167 ..spirt::qptr::LayoutConfig::VULKAN_SCALAR_LAYOUT
168 };
169
170 let profiler = before_pass("qptr::lower_from_spv_ptrs", module);
171 spirt::passes::qptr::lower_from_spv_ptrs(module, layout_config);
172 after_pass(Some(module), profiler);
173
174 let profiler = before_pass("qptr::analyze_uses", module);
175 spirt::passes::qptr::analyze_uses(module, layout_config);
176 after_pass(Some(module), profiler);
177
178 let profiler = before_pass("qptr::lift_to_spv_ptrs", module);
179 spirt::passes::qptr::lift_to_spv_ptrs(module, layout_config);
180 after_pass(Some(module), profiler);
181
182 continue;
183 }
184
185 let (full_name, pass_fn): (_, fn(_, &mut _)) = match name {
186 "reduce" => ("spirt_passes::reduce", reduce::reduce_in_func),
187 "fuse_selects" => (
188 "spirt_passes::fuse_selects",
189 fuse_selects::fuse_selects_in_func,
190 ),
191 _ => panic!("unknown `--spirt-passes={name}`"),
192 };
193
194 let profiler = before_pass(full_name, module);
195 for &func in &all_funcs {
196 if let DeclDef::Present(func_def_body) = &mut module.funcs[func].def {
197 pass_fn(cx, func_def_body);
198
199 remove_unused_values_in_func(cx, func_def_body);
201 }
202 }
203 after_pass(Some(module), profiler);
204 }
205}
206
207struct ReachableUseCollector<'a> {
209 cx: &'a Context,
210 module: &'a Module,
211
212 seen_types: FxIndexSet<Type>,
214 seen_consts: FxIndexSet<Const>,
215 seen_data_inst_forms: FxIndexSet<DataInstForm>,
216 seen_global_vars: FxIndexSet<GlobalVar>,
217 seen_funcs: FxIndexSet<Func>,
218}
219
220impl Visitor<'_> for ReachableUseCollector<'_> {
221 fn visit_attr_set_use(&mut self, _attrs: AttrSet) {}
223 fn visit_type_use(&mut self, ty: Type) {
224 if self.seen_types.insert(ty) {
225 self.visit_type_def(&self.cx[ty]);
226 }
227 }
228 fn visit_const_use(&mut self, ct: Const) {
229 if self.seen_consts.insert(ct) {
230 self.visit_const_def(&self.cx[ct]);
231 }
232 }
233 fn visit_data_inst_form_use(&mut self, data_inst_form: DataInstForm) {
234 if self.seen_data_inst_forms.insert(data_inst_form) {
235 self.visit_data_inst_form_def(&self.cx[data_inst_form]);
236 }
237 }
238
239 fn visit_global_var_use(&mut self, gv: GlobalVar) {
240 if self.seen_global_vars.insert(gv) {
241 self.visit_global_var_decl(&self.module.global_vars[gv]);
242 }
243 }
244 fn visit_func_use(&mut self, func: Func) {
245 if self.seen_funcs.insert(func) {
246 self.visit_func_decl(&self.module.funcs[func]);
247 }
248 }
249}
250
251struct VisitAllControlRegionsAndNodes<S, VCR, VCN> {
253 state: S,
254 visit_control_region: VCR,
255 visit_control_node: VCN,
256}
257const _: () = {
258 use spirt::{func_at::*, visit::*, *};
259
260 impl<
261 'a,
262 S,
263 VCR: FnMut(&mut S, FuncAt<'a, ControlRegion>),
264 VCN: FnMut(&mut S, FuncAt<'a, ControlNode>),
265 > Visitor<'a> for VisitAllControlRegionsAndNodes<S, VCR, VCN>
266 {
267 fn visit_attr_set_use(&mut self, _: AttrSet) {}
270 fn visit_type_use(&mut self, _: Type) {}
271 fn visit_const_use(&mut self, _: Const) {}
272 fn visit_data_inst_form_use(&mut self, _: DataInstForm) {}
273 fn visit_global_var_use(&mut self, _: GlobalVar) {}
274 fn visit_func_use(&mut self, _: Func) {}
275
276 fn visit_control_region_def(&mut self, func_at_control_region: FuncAt<'a, ControlRegion>) {
277 (self.visit_control_region)(&mut self.state, func_at_control_region);
278 func_at_control_region.inner_visit_with(self);
279 }
280 fn visit_control_node_def(&mut self, func_at_control_node: FuncAt<'a, ControlNode>) {
281 (self.visit_control_node)(&mut self.state, func_at_control_node);
282 func_at_control_node.inner_visit_with(self);
283 }
284 }
285};
286
287struct ReplaceValueWith<F>(F);
289const _: () = {
290 use spirt::{transform::*, *};
291
292 impl<F: FnMut(Value) -> Option<Value>> Transformer for ReplaceValueWith<F> {
293 fn transform_value_use(&mut self, v: &Value) -> Transformed<Value> {
294 self.0(*v).map_or(Transformed::Unchanged, Transformed::Changed)
295 }
296 }
297};
298
299fn remove_unused_values_in_func(cx: &Context, func_def_body: &mut FuncDefBody) {
304 if func_def_body.unstructured_cfg.is_some() {
306 return;
307 }
308
309 let wk = &SpvSpecWithExtras::get().well_known;
310
311 struct Propagator {
312 func_body_region: ControlRegion,
313
314 loop_body_to_loop: EntityOrientedDenseMap<ControlRegion, ControlNode>,
316
317 used: FxHashSet<Value>,
320
321 queue: VecDeque<Value>,
322 }
323 impl Propagator {
324 fn mark_used(&mut self, v: Value) {
325 if let Value::Const(_) = v {
326 return;
327 }
328 if let Value::ControlRegionInput {
329 region,
330 input_idx: _,
331 } = v
332 && region == self.func_body_region
333 {
334 return;
335 }
336 if self.used.insert(v) {
337 self.queue.push_back(v);
338 }
339 }
340 fn propagate_used(&mut self, func: FuncAt<'_, ()>) {
341 while let Some(v) = self.queue.pop_front() {
342 match v {
343 Value::Const(_) => unreachable!(),
344 Value::ControlRegionInput { region, input_idx } => {
345 let loop_node = self.loop_body_to_loop[region];
346 let initial_inputs = match &func.at(loop_node).def().kind {
347 ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs,
348 _ => unreachable!(),
350 };
351 self.mark_used(initial_inputs[input_idx as usize]);
352 self.mark_used(func.at(region).def().outputs[input_idx as usize]);
353 }
354 Value::ControlNodeOutput {
355 control_node,
356 output_idx,
357 } => {
358 let cases = match &func.at(control_node).def().kind {
359 ControlNodeKind::Select { cases, .. } => cases,
360 _ => unreachable!(),
362 };
363 for &case in cases {
364 self.mark_used(func.at(case).def().outputs[output_idx as usize]);
365 }
366 }
367 Value::DataInstOutput(inst) => {
368 for &input in &func.at(inst).def().inputs {
369 self.mark_used(input);
370 }
371 }
372 }
373 }
374 }
375 }
376
377 let propagator = {
380 let mut visitor = VisitAllControlRegionsAndNodes {
381 state: Propagator {
382 func_body_region: func_def_body.body,
383 loop_body_to_loop: Default::default(),
384 used: Default::default(),
385 queue: Default::default(),
386 },
387 visit_control_region: |_: &mut _, _| {},
388 visit_control_node:
389 |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
390 if let ControlNodeKind::Loop { body, .. } = func_at_control_node.def().kind {
391 propagator
392 .loop_body_to_loop
393 .insert(body, func_at_control_node.position);
394 }
395 },
396 };
397 func_def_body.inner_visit_with(&mut visitor);
398 visitor.state
399 };
400
401 let mut all_control_nodes = vec![];
403
404 let used_values = {
405 let mut visitor = VisitAllControlRegionsAndNodes {
406 state: propagator,
407 visit_control_region: |_: &mut _, _| {},
408 visit_control_node:
409 |propagator: &mut Propagator, func_at_control_node: FuncAt<'_, ControlNode>| {
410 all_control_nodes.push(func_at_control_node.position);
411
412 let mut mark_used_and_propagate = |v| {
413 propagator.mark_used(v);
414 propagator.propagate_used(func_at_control_node.at(()));
415 };
416 match &func_at_control_node.def().kind {
417 &ControlNodeKind::Block { insts } => {
418 for func_at_inst in func_at_control_node.at(insts) {
419 if let DataInstKind::SpvInst(spv_inst) =
422 &cx[func_at_inst.def().form].kind
423 {
424 if [wk.OpNop, wk.OpCompositeInsert].contains(&spv_inst.opcode) {
427 continue;
428 }
429 }
430 mark_used_and_propagate(Value::DataInstOutput(
431 func_at_inst.position,
432 ));
433 }
434 }
435
436 &ControlNodeKind::Select { scrutinee: v, .. }
437 | &ControlNodeKind::Loop {
438 repeat_condition: v,
439 ..
440 } => mark_used_and_propagate(v),
441
442 ControlNodeKind::ExitInvocation {
443 kind: spirt::cfg::ExitInvocationKind::SpvInst(_),
444 inputs,
445 } => {
446 for &v in inputs {
447 mark_used_and_propagate(v);
448 }
449 }
450 }
451 },
452 };
453 func_def_body.inner_visit_with(&mut visitor);
454
455 let mut propagator = visitor.state;
456 for &v in &func_def_body.at_body().def().outputs {
457 propagator.mark_used(v);
458 propagator.propagate_used(func_def_body.at(()));
459 }
460
461 assert!(propagator.queue.is_empty());
462 propagator.used
463 };
464
465 let mut value_replacements = FxHashMap::default();
468
469 for control_node in all_control_nodes {
471 let control_node_def = func_def_body.at(control_node).def();
472 match &control_node_def.kind {
473 &ControlNodeKind::Block { insts } => {
474 let mut all_nops = true;
475 let mut func_at_inst_iter = func_def_body.at_mut(insts).into_iter();
476 while let Some(mut func_at_inst) = func_at_inst_iter.next() {
477 if let DataInstKind::SpvInst(spv_inst) =
478 &cx[func_at_inst.reborrow().def().form].kind
479 && spv_inst.opcode == wk.OpNop
480 {
481 continue;
482 }
483 if !used_values.contains(&Value::DataInstOutput(func_at_inst.position)) {
484 *func_at_inst.def() = DataInstDef {
489 attrs: Default::default(),
490 form: cx.intern(DataInstFormDef {
491 kind: DataInstKind::SpvInst(wk.OpNop.into()),
492 output_type: None,
493 }),
494 inputs: iter::empty().collect(),
495 };
496 continue;
497 }
498 all_nops = false;
499 }
500 if all_nops {
503 func_def_body.at_mut(control_node).def().kind = ControlNodeKind::Block {
504 insts: Default::default(),
505 };
506 }
507 }
508
509 ControlNodeKind::Select { cases, .. } => {
510 let cases = cases.clone();
512
513 let mut new_idx = 0;
514 for original_idx in 0..control_node_def.outputs.len() {
515 let original_output = Value::ControlNodeOutput {
516 control_node,
517 output_idx: original_idx as u32,
518 };
519
520 if !used_values.contains(&original_output) {
521 func_def_body
523 .at_mut(control_node)
524 .def()
525 .outputs
526 .remove(new_idx);
527 for &case in &cases {
528 func_def_body.at_mut(case).def().outputs.remove(new_idx);
529 }
530 continue;
531 }
532
533 if original_idx != new_idx {
535 let new_output = Value::ControlNodeOutput {
536 control_node,
537 output_idx: new_idx as u32,
538 };
539 value_replacements.insert(original_output, new_output);
540 }
541 new_idx += 1;
542 }
543 }
544
545 ControlNodeKind::Loop {
546 body,
547 initial_inputs,
548 ..
549 } => {
550 let body = *body;
551
552 let mut new_idx = 0;
553 for original_idx in 0..initial_inputs.len() {
554 let original_input = Value::ControlRegionInput {
555 region: body,
556 input_idx: original_idx as u32,
557 };
558
559 if !used_values.contains(&original_input) {
560 match &mut func_def_body.at_mut(control_node).def().kind {
562 ControlNodeKind::Loop { initial_inputs, .. } => {
563 initial_inputs.remove(new_idx);
564 }
565 _ => unreachable!(),
566 }
567 let body_def = func_def_body.at_mut(body).def();
568 body_def.inputs.remove(new_idx);
569 body_def.outputs.remove(new_idx);
570 continue;
571 }
572
573 if original_idx != new_idx {
575 let new_input = Value::ControlRegionInput {
576 region: body,
577 input_idx: new_idx as u32,
578 };
579 value_replacements.insert(original_input, new_input);
580 }
581 new_idx += 1;
582 }
583 }
584
585 ControlNodeKind::ExitInvocation { .. } => {}
586 }
587 }
588
589 if !value_replacements.is_empty() {
590 func_def_body.inner_in_place_transform_with(&mut ReplaceValueWith(|v| match v {
591 Value::Const(_) => None,
592 _ => value_replacements.get(&v).copied(),
593 }));
594 }
595}