1use crate::linker::ipo::CallGraph;
53use crate::spirv_type_constraints::{self, InstSig, StorageClassPat, TyListPat, TyPat};
54use indexmap::{IndexMap, IndexSet};
55use rspirv::dr::{Builder, Function, Instruction, Module, Operand};
56use rspirv::spirv::{Op, StorageClass, Word};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet};
58use smallvec::SmallVec;
59use std::collections::{BTreeMap, VecDeque};
60use std::ops::{Range, RangeTo};
61use std::{fmt, io, iter, mem, slice};
62use tracing::{debug, error};
63
/// Wraps a formatting closure so it can be used wherever a [`fmt::Debug`] or
/// [`fmt::Display`] value is expected; both impls simply invoke the closure.
struct FmtBy<F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result>(F);

impl<F> fmt::Debug for FmtBy<F>
where
    F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let FmtBy(write_to) = self;
        write_to(f)
    }
}

impl<F> fmt::Display for FmtBy<F>
where
    F: Fn(&mut fmt::Formatter<'_>) -> fmt::Result,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (self.0)(f)
    }
}
78
/// Policy deciding which operands the specializer turns into generic parameters.
pub trait Specialization {
    /// Returns `true` if `operand` should be specialized. Such an operand consumes
    /// exactly one generic parameter (see `Specializer::params_needed_by`).
    fn specialize_operand(&self, operand: &Operand) -> bool;

    /// The concrete operand to fall back to for a specialized parameter.
    /// NOTE(review): presumably used when no concrete value could be inferred;
    /// the caller isn't visible in this chunk — confirm.
    fn concrete_fallback(&self) -> Operand;
}
91
/// A [`Specialization`] built from a predicate closure and a fixed fallback operand.
pub struct SimpleSpecialization<SO: Fn(&Operand) -> bool> {
    /// Predicate used as `Specialization::specialize_operand`.
    pub specialize_operand: SO,
    /// Operand returned (cloned) by `Specialization::concrete_fallback`.
    pub concrete_fallback: Operand,
}
98
99impl<SO: Fn(&Operand) -> bool> Specialization for SimpleSpecialization<SO> {
100 fn specialize_operand(&self, operand: &Operand) -> bool {
101 (self.specialize_operand)(operand)
102 }
103 fn concrete_fallback(&self) -> Operand {
104 self.concrete_fallback.clone()
105 }
106}
107
108pub fn specialize(
109 opts: &super::Options,
110 module: Module,
111 specialization: impl Specialization,
112) -> Module {
113 let dump_instances = &opts.specializer_dump_instances;
114
115 let mut debug_names = FxHashMap::default();
116 if dump_instances.is_some() {
117 debug_names = module
118 .debug_names
119 .iter()
120 .filter(|inst| inst.class.opcode == Op::Name)
121 .map(|inst| {
122 (
123 inst.operands[0].unwrap_id_ref(),
124 inst.operands[1].unwrap_literal_string().to_string(),
125 )
126 })
127 .collect();
128 }
129
130 let mut specializer = Specializer {
131 specialization,
132 debug_names,
133 generics: IndexMap::new(),
134 int_consts: FxHashMap::default(),
135 };
136
137 specializer.collect_generics(&module);
138
139 let mut interface_concrete_instances = IndexSet::new();
145 for inst in &module.entry_points {
146 for interface_operand in &inst.operands[3..] {
147 let interface_id = interface_operand.unwrap_id_ref();
148 if let Some(generic) = specializer.generics.get(&interface_id)
149 && let Some(param_values) = &generic.param_values
150 && param_values.iter().all(|v| matches!(v, Value::Known(_)))
151 {
152 interface_concrete_instances.insert(Instance {
153 generic_id: interface_id,
154 generic_args: param_values
155 .iter()
156 .copied()
157 .map(|v| match v {
158 Value::Known(v) => v,
159 _ => unreachable!(),
160 })
161 .collect(),
162 });
163 }
164 }
165 }
166
167 let call_graph = CallGraph::collect(&module);
168 let mut non_generic_replacements = vec![];
169 for func_idx in call_graph.post_order() {
170 if let Some(replacements) = specializer.infer_function(&module.functions[func_idx]) {
171 non_generic_replacements.push((func_idx, replacements));
172 }
173 }
174
175 let mut expander = Expander::new(&specializer, module);
176
177 for interface_instance in interface_concrete_instances {
179 expander.alloc_instance_id(interface_instance);
180 }
181
182 debug!("non-generic replacements:");
186 for (func_idx, replacements) in non_generic_replacements {
187 let mut func = mem::replace(
188 &mut expander.builder.module_mut().functions[func_idx],
189 Function::new(),
190 );
191 let empty =
192 replacements.with_instance.is_empty() && replacements.with_concrete_or_param.is_empty();
193 if !empty {
194 debug!(" in %{}:", func.def_id().unwrap());
195 }
196 for (loc, operand) in
197 replacements.to_concrete(&[], |instance| expander.alloc_instance_id(instance))
198 {
199 debug!(" {operand} -> {loc:?}");
200 func.index_set(loc, operand.into());
201 }
202 expander.builder.module_mut().functions[func_idx] = func;
203 }
204 expander.propagate_instances();
205
206 if let Some(path) = dump_instances {
207 expander
208 .dump_instances(&mut std::fs::File::create(path).unwrap())
209 .unwrap();
210 }
211
212 expander.expand_module()
213}
214
/// A `Copy` subset of `rspirv`'s `Operand`, limited to the operand kinds the
/// specializer can infer: IDs and storage classes. Converted from `&Operand`
/// via `TryFrom`, and back via `From<CopyOperand> for Operand`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum CopyOperand {
    IdRef(Word),
    StorageClass(StorageClass),
}
222
/// Error from `CopyOperand::try_from` for operands that are neither `IdRef` nor
/// `StorageClass`; carries a clone of the rejected operand.
#[derive(Debug)]
struct NotSupportedAsCopyOperand(
    // The field is only observable through the derived `Debug` impl, which the
    // `dead_code` lint doesn't count as a read, hence the `allow`.
    #[allow(dead_code)] Operand,
);
228
229impl TryFrom<&Operand> for CopyOperand {
230 type Error = NotSupportedAsCopyOperand;
231 fn try_from(operand: &Operand) -> Result<Self, Self::Error> {
232 match *operand {
233 Operand::IdRef(id) => Ok(Self::IdRef(id)),
234 Operand::StorageClass(s) => Ok(Self::StorageClass(s)),
235 _ => Err(NotSupportedAsCopyOperand(operand.clone())),
236 }
237 }
238}
239
240impl From<CopyOperand> for Operand {
241 fn from(op: CopyOperand) -> Self {
242 match op {
243 CopyOperand::IdRef(id) => Self::IdRef(id),
244 CopyOperand::StorageClass(s) => Self::StorageClass(s),
245 }
246 }
247}
248
249impl fmt::Display for CopyOperand {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 match self {
252 Self::IdRef(id) => write!(f, "%{id}"),
253 Self::StorageClass(s) => write!(f, "{s:?}"),
254 }
255 }
256}
257
/// What is known about a variable: nothing yet, a concrete operand, or that it
/// equals another variable of type `T` (e.g. an `InferVar` or a `Param`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
enum Value<T> {
    /// Not (yet) known.
    Unknown,

    /// Known to be this concrete operand.
    Known(CopyOperand),

    /// Same value as the variable `T` (whose own value may itself still be
    /// `Unknown`, or be refined later).
    SameAs(T),
}
275
276impl<T> Value<T> {
277 fn map_var<U>(self, f: impl FnOnce(T) -> U) -> Value<U> {
278 match self {
279 Value::Unknown => Value::Unknown,
280 Value::Known(o) => Value::Known(o),
281 Value::SameAs(var) => Value::SameAs(f(var)),
282 }
283 }
284}
285
/// A generic parameter (by index) of a generic type/global/function, displayed
/// as `$N`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Param(u32);

impl fmt::Display for Param {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Param(index) = *self;
        write!(f, "${}", index)
    }
}

impl Param {
    /// Iterates over every parameter in `range` (`start` inclusive, `end` exclusive).
    fn range_iter(range: &Range<Self>) -> impl Iterator<Item = Self> + Clone {
        let Range {
            start: Param(first),
            end: Param(past_last),
        } = *range;
        (first..past_last).map(Param)
    }
}
304
/// A use of the generic `generic_id` with generic arguments `generic_args`.
/// `GA` varies by use site: e.g. a range of `InferVar`s during inference, a list
/// of `ConcreteOrParam`s in `Replacements`, or a list of `CopyOperand`s once
/// fully concrete.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Instance<GA> {
    generic_id: Word,
    generic_args: GA,
}
311
impl<GA> Instance<GA> {
    /// Borrows the generic args, keeping the same `generic_id`.
    fn as_ref(&self) -> Instance<&GA> {
        Instance {
            generic_id: self.generic_id,
            generic_args: &self.generic_args,
        }
    }

    /// Maps `f` over every generic arg, collecting the results into `GA2`.
    fn map_generic_args<T, U, GA2>(self, f: impl FnMut(T) -> U) -> Instance<GA2>
    where
        GA: IntoIterator<Item = T>,
        GA2: std::iter::FromIterator<U>,
    {
        Instance {
            generic_id: self.generic_id,
            generic_args: self.generic_args.into_iter().map(f).collect(),
        }
    }

    /// Returns a `Display` adapter printing `%<generic_id><arg, arg, ...>`, where
    /// the args are the items yielded by `f(&self.generic_args)`. `f` is called
    /// once, up-front; its iterator is cloned each time the adapter is formatted.
    fn display<'a, T: fmt::Display, GAI: Iterator<Item = T> + Clone>(
        &'a self,
        f: impl FnOnce(&'a GA) -> GAI,
    ) -> impl fmt::Display {
        let &Self {
            generic_id,
            ref generic_args,
        } = self;
        let generic_args_iter = f(generic_args);
        FmtBy(move |f| {
            write!(f, "%{generic_id}<")?;
            for (i, arg) in generic_args_iter.clone().enumerate() {
                if i != 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{arg}")?;
            }
            write!(f, ">")
        })
    }
}
353
/// Where an instruction lives. `Module` covers module-level instructions, and,
/// when indexing into a `Function`, that function's own `OpFunction` (`def`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum InstructionLocation {
    Module,
    /// Index into a function's `parameters`.
    FnParam(usize),
    FnBody {
        /// Index into a function's `blocks`.
        block_idx: usize,

        /// Index into the `instructions` of the block at `block_idx`.
        inst_idx: usize,
    },
}
366
/// Reads/writes a single operand (including the result type) of an instruction,
/// or of one of a function's instructions, addressed by an index of type `I`.
trait OperandIndexGetSet<I> {
    /// Returns a copy of the operand at `index`.
    // NOTE(review): in this chunk `index_get` is only called from other
    // `index_get` impls, presumably why `dead_code` is allowed here.
    #[allow(dead_code)]
    fn index_get(&self, index: I) -> Operand;
    /// Replaces the operand at `index` with `operand`.
    fn index_set(&mut self, index: I, operand: Operand);
}
373
/// Which operand of an instruction: its result type ID, or `operands[i]`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum OperandIdx {
    ResultType,
    Input(usize),
}
379
380impl OperandIndexGetSet<OperandIdx> for Instruction {
381 fn index_get(&self, idx: OperandIdx) -> Operand {
382 match idx {
383 OperandIdx::ResultType => Operand::IdRef(self.result_type.unwrap()),
384 OperandIdx::Input(i) => self.operands[i].clone(),
385 }
386 }
387 fn index_set(&mut self, idx: OperandIdx, operand: Operand) {
388 match idx {
389 OperandIdx::ResultType => self.result_type = Some(operand.unwrap_id_ref()),
390 OperandIdx::Input(i) => self.operands[i] = operand,
391 }
392 }
393}
394
/// Full location of an operand: which instruction, and which operand of it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct OperandLocation {
    inst_loc: InstructionLocation,
    operand_idx: OperandIdx,
}
400
401impl OperandIndexGetSet<OperandLocation> for Instruction {
402 fn index_get(&self, loc: OperandLocation) -> Operand {
403 assert_eq!(loc.inst_loc, InstructionLocation::Module);
404 self.index_get(loc.operand_idx)
405 }
406 fn index_set(&mut self, loc: OperandLocation, operand: Operand) {
407 assert_eq!(loc.inst_loc, InstructionLocation::Module);
408 self.index_set(loc.operand_idx, operand);
409 }
410}
411
412impl OperandIndexGetSet<OperandLocation> for Function {
413 fn index_get(&self, loc: OperandLocation) -> Operand {
414 let inst = match loc.inst_loc {
415 InstructionLocation::Module => self.def.as_ref().unwrap(),
416 InstructionLocation::FnParam(i) => &self.parameters[i],
417 InstructionLocation::FnBody {
418 block_idx,
419 inst_idx,
420 } => &self.blocks[block_idx].instructions[inst_idx],
421 };
422 inst.index_get(loc.operand_idx)
423 }
424 fn index_set(&mut self, loc: OperandLocation, operand: Operand) {
425 let inst = match loc.inst_loc {
426 InstructionLocation::Module => self.def.as_mut().unwrap(),
427 InstructionLocation::FnParam(i) => &mut self.parameters[i],
428 InstructionLocation::FnBody {
429 block_idx,
430 inst_idx,
431 } => &mut self.blocks[block_idx].instructions[inst_idx],
432 };
433 inst.index_set(loc.operand_idx, operand);
434 }
435}
436
/// Either an already-concrete operand, or a reference to one of the enclosing
/// generic's parameters (resolved by `apply_generic_args`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
enum ConcreteOrParam {
    Concrete(CopyOperand),
    Param(Param),
}
444
445impl ConcreteOrParam {
446 fn apply_generic_args(self, generic_args: &[CopyOperand]) -> CopyOperand {
448 match self {
449 Self::Concrete(x) => x,
450 Self::Param(Param(i)) => generic_args[i as usize],
451 }
452 }
453}
454
/// Operand rewrites to perform when (re)emitting an instruction or function,
/// expressed in terms of the enclosing generic's parameters.
#[derive(Debug)]
struct Replacements {
    /// For each instance (whose args may refer to the enclosing generic's
    /// parameters), every operand location to replace with the ID of that
    /// instance's concrete version.
    with_instance: IndexMap<Instance<SmallVec<[ConcreteOrParam; 4]>>, Vec<OperandLocation>>,

    /// Operand locations to replace with a concrete operand, or with one of the
    /// enclosing generic's arguments.
    with_concrete_or_param: Vec<(OperandLocation, ConcreteOrParam)>,
}
465
impl Replacements {
    /// Resolves every replacement to a concrete operand, given the concrete
    /// `generic_args` of the instance being emitted (`specialize` passes `&[]`
    /// for non-generic functions).
    ///
    /// For each `with_instance` entry, `concrete_instance_id` is called once with
    /// that instance's resolved args, and the returned ID is used for all of the
    /// entry's locations; the `with_concrete_or_param` entries follow. The
    /// iterator is lazy, so `concrete_instance_id` only runs as it is consumed.
    fn to_concrete<'a>(
        &'a self,
        generic_args: &'a [CopyOperand],
        mut concrete_instance_id: impl FnMut(Instance<SmallVec<[CopyOperand; 4]>>) -> Word + 'a,
    ) -> impl Iterator<Item = (OperandLocation, CopyOperand)> + 'a {
        self.with_instance
            .iter()
            .flat_map(move |(instance, locations)| {
                let concrete = CopyOperand::IdRef(concrete_instance_id(
                    instance
                        .as_ref()
                        .map_generic_args(|x| x.apply_generic_args(generic_args)),
                ));
                locations.iter().map(move |&loc| (loc, concrete))
            })
            .chain(
                self.with_concrete_or_param
                    .iter()
                    .map(move |&(loc, x)| (loc, x.apply_generic_args(generic_args))),
            )
    }
}
494
/// A type, global value or function that needs generic parameters.
struct Generic {
    /// Number of generic parameters (always non-zero: `collect_generics` only
    /// registers instructions with at least one).
    param_count: u32,

    /// The defining instruction (for functions, their `OpFunction`).
    def: Instruction,

    /// Values inferred for the parameters, or `None` if all were `Unknown`.
    /// Set by `collect_generics`, and, for functions, by `infer_function`.
    param_values: Option<Vec<Value<Param>>>,

    /// Operand replacements for this generic's instructions, in terms of its
    /// own parameters (presumably applied by the expander per instance).
    replacements: Replacements,
}
524
/// Global state of the specialization pass.
struct Specializer<S: Specialization> {
    /// Policy deciding which operands get specialized.
    specialization: S,

    /// `OpName`s by target ID (only collected when dumping instances).
    debug_names: FxHashMap<Word, String>,

    /// Every generic type/global/function, by result ID, in module order.
    generics: IndexMap<Word, Generic>,

    /// Values of `OpConstant`s with a 32-bit literal operand, by result ID.
    /// Used (at least) to print `IndexComposite` indices in `Match` debug output.
    int_consts: FxHashMap<Word, u32>,
}
538
539impl<S: Specialization> Specializer<S> {
540 fn params_needed_by(&self, operand: &Operand) -> (u32, Option<&Generic>) {
544 if self.specialization.specialize_operand(operand) {
545 (1, None)
547 } else if let Operand::IdRef(id) = operand {
548 self.generics
549 .get(id)
550 .map_or((0, None), |generic| (generic.param_count, Some(generic)))
551 } else {
552 (0, None)
553 }
554 }
555
556 fn collect_generics(&mut self, module: &Module) {
557 let types_global_values_and_functions = module
562 .types_global_values
563 .iter()
564 .chain(module.functions.iter().filter_map(|f| f.def.as_ref()));
565
566 let mut forward_declared_pointers = FxHashSet::default();
567 for inst in types_global_values_and_functions {
568 let result_id = if inst.class.opcode == Op::TypeForwardPointer {
569 forward_declared_pointers.insert(inst.operands[0].unwrap_id_ref());
570 inst.operands[0].unwrap_id_ref()
571 } else {
572 let result_id = inst.result_id.unwrap_or_else(|| {
573 unreachable!(
574 "Op{:?} is in `types_global_values` but not have a result ID",
575 inst.class.opcode
576 );
577 });
578 if forward_declared_pointers.remove(&result_id) {
579 assert_eq!(inst.class.opcode, Op::TypePointer);
584 continue;
585 }
586 result_id
587 };
588
589 if inst.class.opcode == Op::Constant
591 && let Operand::LiteralBit32(x) = inst.operands[0]
592 {
593 self.int_consts.insert(result_id, x);
594 }
595
596 let (param_count, param_values, replacements) = {
599 let mut infer_cx = InferCx::new(self);
600 infer_cx.instantiate_instruction(inst, InstructionLocation::Module);
601
602 let param_count = infer_cx.infer_var_values.len() as u32;
603
604 let param_values = infer_cx
606 .infer_var_values
607 .iter()
608 .map(|v| v.map_var(|InferVar(i)| Param(i)));
609 let param_values = if param_values.clone().any(|v| v != Value::Unknown) {
611 Some(param_values.collect())
612 } else {
613 None
614 };
615
616 (
617 param_count,
618 param_values,
619 infer_cx.into_replacements(..Param(param_count)),
620 )
621 };
622
623 if param_count > 0 {
625 self.generics.insert(
626 result_id,
627 Generic {
628 param_count,
629 def: inst.clone(),
630 param_values,
631 replacements,
632 },
633 );
634 }
635 }
636 }
637
638 fn infer_function(&mut self, func: &Function) -> Option<Replacements> {
642 let func_id = func.def_id().unwrap();
643
644 let param_count = self
645 .generics
646 .get(&func_id)
647 .map_or(0, |generic| generic.param_count);
648
649 let (param_values, replacements) = {
650 let mut infer_cx = InferCx::new(self);
651 infer_cx.instantiate_function(func);
652
653 let param_values = infer_cx.infer_var_values[..param_count as usize]
655 .iter()
656 .map(|v| v.map_var(|InferVar(i)| Param(i)));
657 let param_values = if param_values.clone().any(|v| v != Value::Unknown) {
659 Some(param_values.collect())
660 } else {
661 None
662 };
663
664 (
665 param_values,
666 infer_cx.into_replacements(..Param(param_count)),
667 )
668 };
669
670 if let Some(generic) = self.generics.get_mut(&func_id) {
671 assert!(generic.param_values.is_none());
675
676 generic.param_values = param_values;
677 generic.replacements = replacements;
678
679 None
680 } else {
681 Some(replacements)
682 }
683 }
684}
685
/// An inference variable (by index into `InferCx::infer_var_values`),
/// displayed as `?N`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct InferVar(u32);

impl fmt::Display for InferVar {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InferVar(index) = self;
        write!(f, "?{}", index)
    }
}

impl InferVar {
    /// Iterates over the variables in `range` (`start` inclusive, `end` exclusive).
    fn range_iter(range: &Range<Self>) -> impl Iterator<Item = Self> + Clone {
        let (lo, hi) = (range.start.0, range.end.0);
        (lo..hi).map(InferVar)
    }
}
704
/// Inference state for a single instruction (`collect_generics`) or function
/// (`infer_function`).
struct InferCx<'a, S: Specialization> {
    specializer: &'a Specializer<S>,

    /// Current value of every inference variable, indexed by `InferVar.0`.
    infer_var_values: Vec<Value<InferVar>>,

    /// Type of each result ID seen so far (looked up by
    /// `InferOperandListTransform::TypeOfId`).
    type_of_result: IndexMap<Word, InferOperand>,

    /// Operands referring to instances of generics, with the inference
    /// variables used for each instance's generic args.
    instantiated_operands: Vec<(OperandLocation, Instance<Range<InferVar>>)>,

    /// Operands standing for a single inference variable.
    /// NOTE(review): presumably each is replaced by that variable's final
    /// value in `into_replacements` (not visible here).
    inferred_operands: Vec<(OperandLocation, InferVar)>,
}
730
731impl<'a, S: Specialization> InferCx<'a, S> {
732 fn new(specializer: &'a Specializer<S>) -> Self {
733 InferCx {
734 specializer,
735
736 infer_var_values: vec![],
737 type_of_result: IndexMap::new(),
738 instantiated_operands: vec![],
739 inferred_operands: vec![],
740 }
741 }
742}
743
/// An operand as seen during inference (see `from_operand_and_generic_args`).
#[derive(Clone, Debug, PartialEq, Eq)]
enum InferOperand {
    /// Nothing is known (e.g. not representable as a `CopyOperand`).
    Unknown,
    /// An operand being specialized, tracked by a single inference variable.
    Var(InferVar),
    /// A concrete operand needing no generic parameters.
    Concrete(CopyOperand),
    /// A reference to a generic, with inference variables for its args.
    Instance(Instance<Range<InferVar>>),
}
751
impl InferOperand {
    /// Converts `operand` to an `InferOperand`, taking as many inference variables
    /// from the front of `generic_args` as `operand` needs (per
    /// `Specializer::params_needed_by`), and returning the remaining ones.
    ///
    /// - A reference to a generic becomes an `Instance` using the taken variables.
    /// - An operand needing no parameters becomes `Concrete` (or `Unknown` if it
    ///   isn't representable as a `CopyOperand`).
    /// - An operand chosen for specialization becomes a single `Var`.
    ///
    /// NOTE(review): the split isn't checked against `generic_args.end`; callers
    /// must provide enough variables.
    fn from_operand_and_generic_args(
        operand: &Operand,
        generic_args: Range<InferVar>,
        cx: &InferCx<'_, impl Specialization>,
    ) -> (Self, Range<InferVar>) {
        let (needed, generic) = cx.specializer.params_needed_by(operand);
        let split = InferVar(generic_args.start.0 + needed);
        let (generic_args, rest) = (generic_args.start..split, split..generic_args.end);
        (
            if generic.is_some() {
                Self::Instance(Instance {
                    generic_id: operand.unwrap_id_ref(),
                    generic_args,
                })
            } else if needed == 0 {
                CopyOperand::try_from(operand).map_or(Self::Unknown, Self::Concrete)
            } else {
                // Only `specialize_operand` can yield a non-zero count without
                // a generic, and it always needs exactly one parameter.
                assert_eq!(needed, 1);
                Self::Var(generic_args.start)
            },
            rest,
        )
    }

    /// Returns a `Display` adapter for `self` (`_` for `Unknown`), where every
    /// inference variable `?N` is followed by ` = value` unless
    /// `infer_var_value` reports it as `Unknown`.
    fn display_with_infer_var_values<'a>(
        &'a self,
        infer_var_value: impl Fn(InferVar) -> Value<InferVar> + Copy + 'a,
    ) -> impl fmt::Display + '_ {
        FmtBy(move |f| {
            let var_with_value = |v| {
                FmtBy(move |f| {
                    write!(f, "{v}")?;
                    match infer_var_value(v) {
                        Value::Unknown => Ok(()),
                        Value::Known(o) => write!(f, " = {o}"),
                        Value::SameAs(v) => write!(f, " = {v}"),
                    }
                })
            };
            match self {
                Self::Unknown => write!(f, "_"),
                Self::Var(v) => write!(f, "{}", var_with_value(*v)),
                Self::Concrete(o) => write!(f, "{o}"),
                Self::Instance(instance) => write!(
                    f,
                    "{}",
                    instance.display(|generic_args| {
                        InferVar::range_iter(generic_args).map(var_with_value)
                    })
                ),
            }
        })
    }

    /// Like `display_with_infer_var_values`, using `cx`'s current values.
    fn display_with_infer_cx<'a>(
        &'a self,
        cx: &'a InferCx<'_, impl Specialization>,
    ) -> impl fmt::Display + '_ {
        self.display_with_infer_var_values(move |v| {
            let get = |v: InferVar| cx.infer_var_values[v.0 as usize];
            // Follow `SameAs` chains as long as the next value isn't `Unknown`,
            // showing the most resolved value (or the last known alias).
            let mut value = get(v);
            while let Value::SameAs(v) = value {
                let next = get(v);
                if next == Value::Unknown {
                    break;
                }
                value = next;
            }
            value
        })
    }
}
830
831impl fmt::Display for InferOperand {
832 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
833 self.display_with_infer_var_values(|_| Value::Unknown)
834 .fmt(f)
835 }
836}
837
/// Transformation applied by `InferOperandList` to the operands it yields.
#[derive(Copy, Clone, PartialEq, Eq)]
enum InferOperandListTransform {
    /// Skip non-ID operands, and replace each ID with the type of the value it
    /// refers to (`Unknown` when that can't be determined).
    TypeOfId,
}
851
/// A lazily-converted list of an instruction's operands, yielding `InferOperand`s.
#[derive(Clone, PartialEq)]
struct InferOperandList<'a> {
    /// Remaining raw operands.
    operands: &'a [Operand],

    /// Inference variables for the generic args of all of `operands`, consumed
    /// front-to-back as operands are split off.
    all_generic_args: Range<InferVar>,

    /// Optional transformation applied to every yielded operand.
    transform: Option<InferOperandListTransform>,
}
862
impl<'a> InferOperandList<'a> {
    /// Splits off the first operand of the list. Returns it (converted to an
    /// `InferOperand` and transformed per `self.transform`) along with the rest
    /// of the list, or `None` if no operands remain.
    fn split_first(
        &self,
        cx: &InferCx<'_, impl Specialization>,
    ) -> Option<(InferOperand, InferOperandList<'a>)> {
        let mut list = self.clone();
        loop {
            let (first_operand, rest) = list.operands.split_first()?;
            list.operands = rest;

            // Generic args are consumed even for operands skipped below, which
            // keeps `all_generic_args` aligned with `operands`.
            let (first, rest_args) = InferOperand::from_operand_and_generic_args(
                first_operand,
                list.all_generic_args.clone(),
                cx,
            );
            list.all_generic_args = rest_args;

            match self.transform {
                None => {}

                // `TypeOfId` only yields ID operands.
                Some(InferOperandListTransform::TypeOfId) => {
                    if first_operand.id_ref_any().is_none() {
                        continue;
                    }
                }
            }

            let first = match self.transform {
                None => first,

                // Replace the ID with the type of the value it refers to.
                Some(InferOperandListTransform::TypeOfId) => match first {
                    InferOperand::Concrete(CopyOperand::IdRef(id)) => cx
                        .type_of_result
                        .get(&id)
                        .cloned()
                        .unwrap_or(InferOperand::Unknown),
                    InferOperand::Unknown | InferOperand::Var(_) | InferOperand::Concrete(_) => {
                        InferOperand::Unknown
                    }
                    InferOperand::Instance(instance) => {
                        let generic = &cx.specializer.generics[&instance.generic_id];

                        // The "type" of an `OpFunction` is its function type
                        // (operand 1), not its return type (`result_type`).
                        let type_of_result = match generic.def.class.opcode {
                            Op::Function => Some(generic.def.operands[1].unwrap_id_ref()),
                            _ => generic.def.result_type,
                        };

                        match type_of_result {
                            // The type takes its generic args from the front of
                            // the instance's args (presumably because the result
                            // type is instantiated first — confirm against
                            // `instantiate_instruction`).
                            Some(type_of_result) => {
                                InferOperand::from_operand_and_generic_args(
                                    &Operand::IdRef(type_of_result),
                                    instance.generic_args,
                                    cx,
                                )
                                .0
                            }
                            None => InferOperand::Unknown,
                        }
                    }
                },
            };

            return Some((first, list));
        }
    }

    /// Iterates over all remaining operands, as yielded by `split_first`.
    fn iter<'b>(
        &self,
        cx: &'b InferCx<'_, impl Specialization>,
    ) -> impl Iterator<Item = InferOperand> + 'b
    where
        'a: 'b,
    {
        let mut list = self.clone();
        iter::from_fn(move || {
            let (next, rest) = list.split_first(cx)?;
            list = rest;
            Some(next)
        })
    }

    /// Returns a `Display` adapter printing the operands as a list, each shown
    /// with `InferOperand::display_with_infer_cx`.
    fn display_with_infer_cx<'b>(
        &'b self,
        cx: &'b InferCx<'a, impl Specialization>,
    ) -> impl fmt::Display + '_ {
        FmtBy(move |f| {
            f.debug_list()
                .entries(self.iter(cx).map(|operand| {
                    FmtBy(move |f| write!(f, "{}", operand.display_with_infer_cx(cx)))
                }))
                .finish()
        })
    }
}
967
/// Map from small `usize` keys to values, stored densely in a `SmallVec`
/// (index = key); keys below the highest one inserted hold `Default` values.
#[derive(Default)]
struct SmallIntMap<A: smallvec::Array>(SmallVec<A>);
971
972impl<A: smallvec::Array> SmallIntMap<A> {
973 fn get(&self, i: usize) -> Option<&A::Item> {
974 self.0.get(i)
975 }
976
977 fn get_mut_or_default(&mut self, i: usize) -> &mut A::Item
978 where
979 A::Item: Default,
980 {
981 let needed = i + 1;
982 if self.0.len() < needed {
983 self.0.resize_with(needed, Default::default);
984 }
985 &mut self.0[i]
986 }
987}
988
989impl<A: smallvec::Array> IntoIterator for SmallIntMap<A> {
990 type Item = (usize, A::Item);
991 type IntoIter = iter::Enumerate<smallvec::IntoIter<A>>;
992 fn into_iter(self) -> Self::IntoIter {
993 self.0.into_iter().enumerate()
994 }
995}
996
997impl<'a, A: smallvec::Array> IntoIterator for &'a mut SmallIntMap<A> {
998 type Item = (usize, &'a mut A::Item);
999 type IntoIter = iter::Enumerate<slice::IterMut<'a, A::Item>>;
1000 fn into_iter(self) -> Self::IntoIter {
1001 self.0.iter_mut().enumerate()
1002 }
1003}
1004
/// A `TyPat::IndexComposite` match: indexing the composite type (the matching
/// `TyPat::Var`) with `indices` is expected to produce `leaf`.
#[derive(PartialEq)]
struct IndexCompositeMatch<'a> {
    /// Index operands (empty until filled in by `match_inst_sig`).
    indices: &'a [Operand],

    /// The type matched at the end of indexing.
    leaf: InferOperand,
}
1013
/// Result of matching (part of) an instruction against patterns from
/// `spirv_type_constraints`: for each pattern variable (indexed by its number in
/// the pattern), every operand it was matched against. Several operands found
/// for one variable are presumably unified by the caller.
#[must_use]
#[derive(Default)]
struct Match<'a> {
    /// Set when a structural type pattern met an `Unknown` or concrete type (see
    /// `match_ty_pat`), so whether the pattern truly applies wasn't checked.
    ambiguous: bool,

    /// Operands found for each `StorageClassPat::Var`.
    storage_class_var_found: SmallIntMap<[SmallVec<[InferOperand; 2]>; 1]>,

    /// Types found for each `TyPat::Var`.
    ty_var_found: SmallIntMap<[SmallVec<[InferOperand; 4]>; 1]>,

    /// Matches for each `TyPat::IndexComposite(TyPat::Var(i))`, keyed by the same
    /// index `i` as `ty_var_found` (the composite type's variable).
    index_composite_ty_var_found: SmallIntMap<[SmallVec<[IndexCompositeMatch<'a>; 1]>; 1]>,

    /// Type lists found for each `TyListPat::Var`.
    ty_list_var_found: SmallIntMap<[SmallVec<[InferOperandList<'a>; 2]>; 1]>,
}
1043
impl<'a> Match<'a> {
    /// Combines `self` with `other`, a match that must hold at the same time
    /// (e.g. for another operand of the same instruction).
    ///
    /// Ambiguity is OR-ed, and for every variable the operands found by `other`
    /// are appended to those found by `self`.
    fn and(mut self, other: Self) -> Self {
        let Match {
            ambiguous,
            storage_class_var_found,
            ty_var_found,
            index_composite_ty_var_found,
            ty_list_var_found,
        } = &mut self;

        *ambiguous |= other.ambiguous;
        for (i, other_found) in other.storage_class_var_found {
            storage_class_var_found
                .get_mut_or_default(i)
                .extend(other_found);
        }
        for (i, other_found) in other.ty_var_found {
            ty_var_found.get_mut_or_default(i).extend(other_found);
        }
        for (i, other_found) in other.index_composite_ty_var_found {
            index_composite_ty_var_found
                .get_mut_or_default(i)
                .extend(other_found);
        }
        for (i, other_found) in other.ty_list_var_found {
            ty_list_var_found.get_mut_or_default(i).extend(other_found);
        }
        self
    }

    /// Combines `self` with `other`, an alternative match (used by
    /// `TyPat::Either` when the left side isn't an unambiguous winner).
    ///
    /// Ambiguity is OR-ed, and for every variable only the operands found by
    /// both sides are kept; variables only `other` found are dropped.
    fn or(mut self, other: Self) -> Self {
        let Match {
            ambiguous,
            storage_class_var_found,
            ty_var_found,
            index_composite_ty_var_found,
            ty_list_var_found,
        } = &mut self;

        *ambiguous |= other.ambiguous;
        for (i, self_found) in storage_class_var_found {
            let other_found = other
                .storage_class_var_found
                .get(i)
                .map_or(&[][..], |xs| &xs[..]);
            self_found.retain(|x| other_found.contains(x));
        }
        for (i, self_found) in ty_var_found {
            let other_found = other.ty_var_found.get(i).map_or(&[][..], |xs| &xs[..]);
            self_found.retain(|x| other_found.contains(x));
        }
        for (i, self_found) in index_composite_ty_var_found {
            let other_found = other
                .index_composite_ty_var_found
                .get(i)
                .map_or(&[][..], |xs| &xs[..]);
            self_found.retain(|x| other_found.contains(x));
        }
        for (i, self_found) in ty_list_var_found {
            let other_found = other.ty_list_var_found.get(i).map_or(&[][..], |xs| &xs[..]);
            self_found.retain(|x| other_found.contains(x));
        }
        self
    }

    /// Returns a `Debug` adapter printing `Match [...]` (or `Match (ambiguous) [...]`).
    /// For each variable that found anything, its operands are listed joined by
    /// ` = `. `IndexComposite` entries are shown as `composite.idx... = leaf`,
    /// where the composite is what `ty_var_found` holds for the same index, and
    /// indices print as `.N` when they are known constants (`[operand]` otherwise).
    fn debug_with_infer_cx<'b, T: Specialization>(
        &'b self,
        cx: &'b InferCx<'a, T>,
    ) -> impl fmt::Debug + use<'a, 'b, T> {
        /// One entry per non-empty variable, with its found values joined by ` = `.
        fn debug_var_found<'a, A: smallvec::Array<Item = T> + 'a, T: 'a, TD: fmt::Display>(
            var_found: &'a SmallIntMap<impl smallvec::Array<Item = SmallVec<A>>>,
            display: &'a impl Fn(&'a T) -> TD,
        ) -> impl Iterator<Item = impl fmt::Debug + 'a> + 'a {
            var_found
                .0
                .iter()
                .filter(|found| !found.is_empty())
                .map(move |found| {
                    FmtBy(move |f| {
                        let mut found = found.iter().map(display);
                        write!(f, "{}", found.next().unwrap())?;
                        for x in found {
                            write!(f, " = {x}")?;
                        }
                        Ok(())
                    })
                })
        }
        FmtBy(move |f| {
            let Self {
                ambiguous,
                storage_class_var_found,
                ty_var_found,
                index_composite_ty_var_found,
                ty_list_var_found,
            } = self;
            write!(f, "Match{} ", if *ambiguous { " (ambiguous)" } else { "" })?;
            let mut list = f.debug_list();
            list.entries(debug_var_found(storage_class_var_found, &move |operand| {
                operand.display_with_infer_cx(cx)
            }));
            list.entries(debug_var_found(ty_var_found, &move |operand| {
                operand.display_with_infer_cx(cx)
            }));
            list.entries(
                index_composite_ty_var_found
                    .0
                    .iter()
                    .enumerate()
                    .filter(|(_, found)| !found.is_empty())
                    .flat_map(|(i, found)| found.iter().map(move |x| (i, x)))
                    .map(move |(i, IndexCompositeMatch { indices, leaf })| {
                        FmtBy(move |f| {
                            // Print the composite type (every candidate, if
                            // several were found for variable `i`).
                            match ty_var_found.get(i) {
                                Some(found) if found.len() == 1 => {
                                    write!(f, "{}", found[0].display_with_infer_cx(cx))?;
                                }
                                found => {
                                    let found = found.map_or(&[][..], |xs| &xs[..]);
                                    write!(f, "(")?;
                                    for (j, operand) in found.iter().enumerate() {
                                        if j != 0 {
                                            write!(f, " = ")?;
                                        }
                                        write!(f, "{}", operand.display_with_infer_cx(cx))?;
                                    }
                                    write!(f, ")")?;
                                }
                            }
                            for operand in &indices[..] {
                                // Known indices (constant IDs or literals) print
                                // like field accesses.
                                let maybe_idx = match operand {
                                    Operand::IdRef(id) => cx.specializer.int_consts.get(id),
                                    Operand::LiteralBit32(idx) => Some(idx),
                                    _ => None,
                                };
                                match maybe_idx {
                                    Some(idx) => write!(f, ".{idx}")?,
                                    None => write!(f, "[{operand}]")?,
                                }
                            }
                            write!(f, " = {}", leaf.display_with_infer_cx(cx))
                        })
                    }),
            );
            list.entries(debug_var_found(ty_list_var_found, &move |list| {
                list.display_with_infer_cx(cx)
            }));
            list.finish()
        })
    }
}
1201
/// Error meaning a pattern doesn't apply at all to what it was matched against
/// (as opposed to an ambiguous `Match`).
struct Unapplicable;
1204
1205impl<'a, S: Specialization> InferCx<'a, S> {
1206 #[allow(clippy::unused_self)] fn match_storage_class_pat(
1209 &self,
1210 pat: &StorageClassPat,
1211 storage_class: InferOperand,
1212 ) -> Match<'a> {
1213 match pat {
1214 StorageClassPat::Any => Match::default(),
1215 StorageClassPat::Var(i) => {
1216 let mut m = Match::default();
1217 m.storage_class_var_found
1218 .get_mut_or_default(*i)
1219 .push(storage_class);
1220 m
1221 }
1222 }
1223 }
1224
1225 fn match_ty_pat(&self, pat: &TyPat<'_>, ty: InferOperand) -> Result<Match<'a>, Unapplicable> {
1227 match pat {
1228 TyPat::Any => Ok(Match::default()),
1229 TyPat::Var(i) => {
1230 let mut m = Match::default();
1231 m.ty_var_found.get_mut_or_default(*i).push(ty);
1232 Ok(m)
1233 }
1234 TyPat::Either(a, b) => match self.match_ty_pat(a, ty.clone()) {
1235 Ok(m) if !m.ambiguous => Ok(m),
1236 a_result => match (a_result, self.match_ty_pat(b, ty)) {
1237 (Ok(ma), Ok(mb)) => Ok(ma.or(mb)),
1238 (Ok(m), _) | (_, Ok(m)) => Ok(m),
1239 (Err(Unapplicable), Err(Unapplicable)) => Err(Unapplicable),
1240 },
1241 },
1242 TyPat::IndexComposite(composite_pat) => match composite_pat {
1243 TyPat::Var(i) => {
1244 let mut m = Match::default();
1245 m.index_composite_ty_var_found.get_mut_or_default(*i).push(
1246 IndexCompositeMatch {
1247 indices: &[],
1251 leaf: ty,
1252 },
1253 );
1254 Ok(m)
1255 }
1256 _ => unreachable!(
1257 "`IndexComposite({:?})` isn't supported, only type variable
1258 patterns are (for the composite type), e.g. `IndexComposite(T)`",
1259 composite_pat
1260 ),
1261 },
1262 _ => {
1263 let instance = match ty {
1264 InferOperand::Unknown | InferOperand::Concrete(_) => {
1265 return Ok(Match {
1266 ambiguous: true,
1267 ..Match::default()
1268 });
1269 }
1270 InferOperand::Var(_) => return Err(Unapplicable),
1271 InferOperand::Instance(instance) => instance,
1272 };
1273 let generic = &self.specializer.generics[&instance.generic_id];
1274
1275 let ty_operands = InferOperandList {
1276 operands: &generic.def.operands,
1277 all_generic_args: instance.generic_args,
1278 transform: None,
1279 };
1280 let simple = |op, inner_pat| {
1281 if generic.def.class.opcode == op {
1282 self.match_ty_pat(inner_pat, ty_operands.split_first(self).unwrap().0)
1283 } else {
1284 Err(Unapplicable)
1285 }
1286 };
1287 match pat {
1288 TyPat::Any | TyPat::Var(_) | TyPat::Either(..) | TyPat::IndexComposite(_) => {
1289 unreachable!()
1290 }
1291
1292 TyPat::Void => unreachable!(),
1295
1296 TyPat::Pointer(storage_class_pat, pointee_pat) => {
1297 let mut ty_operands = ty_operands.iter(self);
1298 let (storage_class, pointee_ty) =
1299 (ty_operands.next().unwrap(), ty_operands.next().unwrap());
1300 Ok(self
1301 .match_storage_class_pat(storage_class_pat, storage_class)
1302 .and(self.match_ty_pat(pointee_pat, pointee_ty)?))
1303 }
1304 TyPat::Array(pat) => simple(Op::TypeArray, pat),
1305 TyPat::Vector(pat) => simple(Op::TypeVector, pat),
1306 TyPat::Vector4(pat) => match ty_operands.operands {
1307 [_, Operand::LiteralBit32(4)] => simple(Op::TypeVector, pat),
1308 _ => Err(Unapplicable),
1309 },
1310 TyPat::Matrix(pat) => simple(Op::TypeMatrix, pat),
1311 TyPat::Image(pat) => simple(Op::TypeImage, pat),
1312 TyPat::Pipe(_pat) => {
1313 if generic.def.class.opcode == Op::TypePipe {
1314 Ok(Match::default())
1315 } else {
1316 Err(Unapplicable)
1317 }
1318 }
1319 TyPat::SampledImage(pat) => simple(Op::TypeSampledImage, pat),
1320 TyPat::Struct(fields_pat) => {
1321 if generic.def.class.opcode == Op::TypeStruct {
1322 self.match_ty_list_pat(fields_pat, ty_operands)
1323 } else {
1324 Err(Unapplicable)
1325 }
1326 }
1327 TyPat::Function(ret_pat, params_pat) => {
1328 let (ret_ty, params_ty_list) = ty_operands.split_first(self).unwrap();
1329 Ok(self
1330 .match_ty_pat(ret_pat, ret_ty)?
1331 .and(self.match_ty_list_pat(params_pat, params_ty_list)?))
1332 }
1333 }
1334 }
1335 }
1336 }
1337
/// Matches the type list pattern `list_pat` against `ty_list`.
///
/// Each leading `Cons` element consumes one type; running out of types is
/// `Unapplicable`. The pattern's tail then decides what happens to the
/// remaining types:
/// - `Any` accepts anything.
/// - `Var(i)` records the whole remaining list for variable `i`.
/// - `Repeat` matches its sub-pattern at least once and until the list is
///   exhausted; a partial final repetition is `Unapplicable`.
/// - `Nil` requires the list to be empty.
fn match_ty_list_pat(
    &self,
    mut list_pat: &TyListPat<'_>,
    mut ty_list: InferOperandList<'a>,
) -> Result<Match<'a>, Unapplicable> {
    let mut m = Match::default();

    while let TyListPat::Cons { first: pat, suffix } = list_pat {
        list_pat = suffix;

        let (ty, rest) = ty_list.split_first(self).ok_or(Unapplicable)?;
        ty_list = rest;

        m = m.and(self.match_ty_pat(pat, ty)?);
    }

    match list_pat {
        TyListPat::Cons { .. } => unreachable!(),

        TyListPat::Any => {}
        TyListPat::Var(i) => {
            m.ty_list_var_found.get_mut_or_default(*i).push(ty_list);
        }
        TyListPat::Repeat(repeat_list_pat) => {
            // NOTE(review): an empty repeated pattern (`Repeat(Nil)`) against a
            // non-empty list would loop forever here, since nothing is consumed;
            // presumably no signature uses one.
            let mut tys = ty_list.iter(self).peekable();
            loop {
                let mut list_pat = repeat_list_pat;
                while let TyListPat::Cons { first: pat, suffix } = list_pat {
                    m = m.and(self.match_ty_pat(pat, tys.next().ok_or(Unapplicable)?)?);
                    list_pat = suffix;
                }
                assert!(matches!(list_pat, TyListPat::Nil));
                if tys.peek().is_none() {
                    break;
                }
            }
        }
        TyListPat::Nil => {
            if ty_list.split_first(self).is_some() {
                return Err(Unapplicable);
            }
        }
    }

    Ok(m)
}
1385
/// Matches `inst` against a single instruction signature `sig`.
///
/// - If `sig` has a storage class pattern, it's matched against the first
///   `Operand::StorageClass` operand of `inst` (`Unapplicable` if there is none).
/// - `sig.input_types` is matched against the types of `inst`'s ID operands.
/// - `sig.output_type` is matched against `result_type`; exactly one of the two
///   being present is `Unapplicable`.
/// - If any `IndexComposite` match was found, its `indices` are set to the raw
///   operands left after the `Cons` prefix of `sig.input_types` (which must
///   then end in `TyListPat::Any`).
fn match_inst_sig(
    &self,
    sig: &InstSig<'_>,
    inst: &'a Instruction,
    inputs_generic_args: Range<InferVar>,
    result_type: Option<InferOperand>,
) -> Result<Match<'a>, Unapplicable> {
    let mut m = Match::default();

    if let Some(pat) = sig.storage_class {
        // Untransformed, so operands stay aligned with `inst.operands` for the
        // `zip` below.
        let all_operands = InferOperandList {
            operands: &inst.operands,
            all_generic_args: inputs_generic_args.clone(),
            transform: None,
        };
        let storage_class = all_operands
            .iter(self)
            .zip(&inst.operands)
            .filter(|(_, original)| matches!(original, Operand::StorageClass(_)))
            .map(|(operand, _)| operand)
            .next()
            .ok_or(Unapplicable)?;
        m = m.and(self.match_storage_class_pat(pat, storage_class));
    }

    let input_ty_list = InferOperandList {
        operands: &inst.operands,
        all_generic_args: inputs_generic_args,
        transform: Some(InferOperandListTransform::TypeOfId),
    };

    m = m.and(self.match_ty_list_pat(sig.input_types, input_ty_list.clone())?);

    match (sig.output_type, result_type) {
        (Some(pat), Some(result_type)) => {
            m = m.and(self.match_ty_pat(pat, result_type)?);
        }
        (None, None) => {}
        _ => return Err(Unapplicable),
    }

    if !m.index_composite_ty_var_found.0.is_empty() {
        let composite_indices = {
            // Skip the input types consumed by the `Cons` prefix of the
            // pattern; whatever operands remain are the indices.
            let mut ty_list = input_ty_list;
            let mut list_pat = sig.input_types;
            while let TyListPat::Cons { first: _, suffix } = list_pat {
                list_pat = suffix;
                ty_list = ty_list.split_first(self).ok_or(Unapplicable)?.1;
            }

            assert_eq!(
                list_pat,
                &TyListPat::Any,
                "`IndexComposite` must have input types end in `..`"
            );

            ty_list.operands
        };

        for (_, found) in &mut m.index_composite_ty_var_found {
            for index_composite_match in found {
                let empty = mem::replace(&mut index_composite_match.indices, composite_indices);
                assert_eq!(empty, &[]);
            }
        }
    }

    Ok(m)
}
1466
1467 fn match_inst_sigs(
1470 &self,
1471 sigs: &[InstSig<'_>],
1472 inst: &'a Instruction,
1473 inputs_generic_args: Range<InferVar>,
1474 result_type: Option<InferOperand>,
1475 ) -> Result<Match<'a>, Unapplicable> {
1476 let mut result = Err(Unapplicable);
1477 for sig in sigs {
1478 result = match (
1479 result,
1480 self.match_inst_sig(sig, inst, inputs_generic_args.clone(), result_type.clone()),
1481 ) {
1482 (Err(Unapplicable), Ok(m)) if !m.ambiguous => return Ok(m),
1483 (Ok(a), Ok(b)) => Ok(a.or(b)),
1484 (Ok(m), _) | (_, Ok(m)) => Ok(m),
1485 (Err(Unapplicable), Err(Unapplicable)) => Err(Unapplicable),
1486 };
1487 }
1488 result
1489 }
1490}
1491
/// Error produced while unifying inference operands or variables.
enum InferError {
    /// The two operands were required to be equal, but cannot be unified.
    Conflict(InferOperand, InferOperand),
}
1497
1498impl InferError {
1499 fn report(self, inst: &Instruction) {
1500 match self {
1502 Self::Conflict(a, b) => {
1503 error!("inference conflict: {a:?} vs {b:?}");
1504 }
1505 }
1506 error!(" in ");
1507 if let Some(result_id) = inst.result_id {
1509 error!("%{result_id} = ");
1510 }
1511 error!("Op{:?}", inst.class.opcode);
1512 for operand in inst
1513 .result_type
1514 .map(Operand::IdRef)
1515 .iter()
1516 .chain(inst.operands.iter())
1517 {
1518 error!(" {operand}");
1519 }
1520 error!("");
1521
1522 std::process::exit(1);
1523 }
1524}
1525
1526impl<'a, S: Specialization> InferCx<'a, S> {
1527 fn resolve_infer_var(&mut self, v: InferVar) -> InferVar {
1531 match self.infer_var_values[v.0 as usize] {
1532 Value::Unknown | Value::Known(_) => v,
1533 Value::SameAs(next) => {
1534 let resolved = self.resolve_infer_var(next);
1535 if resolved != next {
1536 self.infer_var_values[v.0 as usize] = Value::SameAs(resolved);
1539 }
1540 resolved
1541 }
1542 }
1543 }
1544
    /// Unifies the inference variables `a` and `b`. Returns the resolved variable that
    /// now represents both.
    ///
    /// Both variables are first resolved to their representatives. The larger one
    /// (per `InferVar`'s `Ord`, presumably the newer allocation) is then redirected
    /// with `Value::SameAs` to the smaller ("older") one, and any `Value::Known` it
    /// held moves over too.
    ///
    /// Returns `InferError::Conflict` if both were already known with different
    /// values. Note that the redirection has already happened by then.
    fn equate_infer_vars(&mut self, a: InferVar, b: InferVar) -> Result<InferVar, InferError> {
        let (a, b) = (self.resolve_infer_var(a), self.resolve_infer_var(b));

        if a == b {
            return Ok(a);
        }

        // Always redirect the newer variable to the older one.
        let (older, newer) = (a.min(b), a.max(b));
        let newer_value = mem::replace(
            &mut self.infer_var_values[newer.0 as usize],
            Value::SameAs(older),
        );
        match (self.infer_var_values[older.0 as usize], newer_value) {
            // Both were just resolved, so neither can be a `SameAs`.
            (Value::SameAs(_), _) | (_, Value::SameAs(_)) => unreachable!(),

            // Both already known: the values must agree.
            (Value::Known(x), Value::Known(y)) => {
                if x != y {
                    return Err(InferError::Conflict(
                        InferOperand::Concrete(x),
                        InferOperand::Concrete(y),
                    ));
                }
            }

            // Only `newer` was known: move its value to `older`.
            (Value::Unknown, Value::Known(_)) => {
                self.infer_var_values[older.0 as usize] = newer_value;
            }

            // `newer` had no value to contribute.
            (_, Value::Unknown) => {}
        }

        Ok(older)
    }
1584
1585 fn equate_infer_var_ranges(
1587 &mut self,
1588 a: Range<InferVar>,
1589 b: Range<InferVar>,
1590 ) -> Result<Range<InferVar>, InferError> {
1591 if a == b {
1592 return Ok(a);
1593 }
1594
1595 assert_eq!(a.end.0 - a.start.0, b.end.0 - b.start.0);
1596
1597 for (a, b) in InferVar::range_iter(&a).zip(InferVar::range_iter(&b)) {
1598 self.equate_infer_vars(a, b)?;
1599 }
1600
1601 Ok(if a.start < b.start { a } else { b })
1605 }
1606
    /// Unifies the inference operands `a` and `b`, returning the unified operand.
    ///
    /// Cases, after the trivial `a == b` case:
    /// - `Instance`/`Instance`: the `generic_id`s must match, and the generic args
    ///   are unified pairwise (`equate_infer_var_ranges`).
    /// - `Instance` paired with anything else, including `Unknown`: conflict.
    /// - `Var`/`Var`: the variables are unified (`equate_infer_vars`).
    /// - `Var`/`Concrete`: the variable's value becomes `Known(concrete)`. If it was
    ///   already known to a different value, that's a conflict.
    /// - `Concrete`/`Concrete`: always a conflict, since equal operands returned early.
    /// - `Unknown` with a non-`Instance`: the other operand is returned.
    fn equate_infer_operands(
        &mut self,
        a: InferOperand,
        b: InferOperand,
    ) -> Result<InferOperand, InferError> {
        if a == b {
            return Ok(a);
        }

        #[allow(clippy::match_same_arms)]
        Ok(match (a.clone(), b.clone()) {
            (
                InferOperand::Instance(Instance {
                    generic_id: a_id,
                    generic_args: a_args,
                }),
                InferOperand::Instance(Instance {
                    generic_id: b_id,
                    generic_args: b_args,
                }),
            ) => {
                if a_id != b_id {
                    return Err(InferError::Conflict(a, b));
                }
                InferOperand::Instance(Instance {
                    generic_id: a_id,
                    generic_args: self.equate_infer_var_ranges(a_args, b_args)?,
                })
            }

            // NOTE: this arm comes before the `Unknown` arm below, so `Instance`
            // vs `Unknown` also ends up here (as a conflict).
            (InferOperand::Instance(_), _) | (_, InferOperand::Instance(_)) => {
                return Err(InferError::Conflict(a, b));
            }

            (InferOperand::Var(a), InferOperand::Var(b)) => {
                InferOperand::Var(self.equate_infer_vars(a, b)?)
            }

            (InferOperand::Var(v), InferOperand::Concrete(new))
            | (InferOperand::Concrete(new), InferOperand::Var(v)) => {
                let v = self.resolve_infer_var(v);
                match &mut self.infer_var_values[v.0 as usize] {
                    // `v` was just resolved, so it can't be a `SameAs`.
                    Value::SameAs(_) => unreachable!(),

                    &mut Value::Known(old) => {
                        if new != old {
                            return Err(InferError::Conflict(
                                InferOperand::Concrete(old),
                                InferOperand::Concrete(new),
                            ));
                        }
                    }

                    // First time a value is found for `v`: record it.
                    value @ Value::Unknown => *value = Value::Known(new),
                }
                InferOperand::Var(v)
            }

            (InferOperand::Concrete(_), InferOperand::Concrete(_)) => {
                return Err(InferError::Conflict(a, b));
            }

            (InferOperand::Unknown, x) | (x, InferOperand::Unknown) => x,
        })
    }
1683
1684 fn index_composite(&self, composite_ty: InferOperand, indices: &[Operand]) -> InferOperand {
1689 let mut ty = composite_ty;
1690 for idx in indices {
1691 let instance = match ty {
1692 InferOperand::Unknown | InferOperand::Concrete(_) | InferOperand::Var(_) => {
1693 return InferOperand::Unknown;
1694 }
1695 InferOperand::Instance(instance) => instance,
1696 };
1697 let generic = &self.specializer.generics[&instance.generic_id];
1698
1699 let ty_opcode = generic.def.class.opcode;
1700 let ty_operands = InferOperandList {
1701 operands: &generic.def.operands,
1702 all_generic_args: instance.generic_args,
1703 transform: None,
1704 };
1705
1706 let ty_operands_idx = match ty_opcode {
1707 Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector | Op::TypeMatrix => 0,
1708 Op::TypeStruct => match idx {
1709 Operand::IdRef(id) => {
1710 *self.specializer.int_consts.get(id).unwrap_or_else(|| {
1711 unreachable!("non-constant `OpTypeStruct` field index {}", id);
1712 })
1713 }
1714 &Operand::LiteralBit32(i) => i,
1715 _ => {
1716 unreachable!("invalid `OpTypeStruct` field index operand {:?}", idx);
1717 }
1718 },
1719 _ => unreachable!("indexing non-composite type `Op{:?}`", ty_opcode),
1720 };
1721
1722 ty = ty_operands
1723 .iter(self)
1724 .nth(ty_operands_idx as usize)
1725 .unwrap_or_else(|| {
1726 unreachable!(
1727 "out of bounds index {} for `Op{:?}`",
1728 ty_operands_idx, ty_opcode
1729 );
1730 });
1731 }
1732 ty
1733 }
1734
    /// Applies the bindings collected in the match `m`, by unifying all the operands
    /// found for the same variable.
    ///
    /// - Storage class variables: everything found is equated together.
    /// - Type variables: likewise. Then, for a type variable that also has
    ///   `IndexComposite` findings, the equated type is indexed with each finding's
    ///   indices, and the result is equated with that finding's leaf type.
    ///   (`IndexComposite` findings for variables with no direct findings are unused.)
    /// - Type list variables: all lists found are equated element-wise. Lists of
    ///   different lengths cause a panic.
    fn equate_match_findings(&mut self, m: Match<'_>) -> Result<(), InferError> {
        let Match {
            ambiguous: _,

            storage_class_var_found,
            ty_var_found,
            index_composite_ty_var_found,
            ty_list_var_found,
        } = m;

        for (_, found) in storage_class_var_found {
            let mut found = found.into_iter();
            if let Some(first) = found.next() {
                found.try_fold(first, |a, b| self.equate_infer_operands(a, b))?;
            }
        }

        for (i, found) in ty_var_found {
            let mut found = found.into_iter();
            if let Some(first) = found.next() {
                let equated_ty = found.try_fold(first, |a, b| self.equate_infer_operands(a, b))?;

                // Check every `IndexComposite` finding against the equated type.
                let index_composite_found = index_composite_ty_var_found
                    .get(i)
                    .map_or(&[][..], |xs| &xs[..]);
                for IndexCompositeMatch { indices, leaf } in index_composite_found {
                    let indexing_result_ty = self.index_composite(equated_ty.clone(), indices);
                    self.equate_infer_operands(indexing_result_ty, leaf.clone())?;
                }
            }
        }

        for (_, mut found) in ty_list_var_found {
            if let Some((first_list, other_lists)) = found.split_first_mut() {
                // Advance all the lists in lockstep, equating the elements at each
                // position (chained through `try_fold`).
                while let Some((first, rest)) = first_list.split_first(self) {
                    *first_list = rest;

                    other_lists.iter_mut().try_fold(first, |a, b_list| {
                        let (b, rest) = b_list
                            .split_first(self)
                            .expect("list length mismatch (invalid SPIR-V?)");
                        *b_list = rest;
                        self.equate_infer_operands(a, b)
                    })?;
                }

                // The first list is exhausted, so all the others must be too.
                for other_list in other_lists {
                    assert!(
                        other_list.split_first(self).is_none(),
                        "list length mismatch (invalid SPIR-V?)"
                    );
                }
            }
        }

        Ok(())
    }
1798
1799 fn record_instantiated_operand(&mut self, loc: OperandLocation, operand: InferOperand) {
1802 match operand {
1803 InferOperand::Var(v) => {
1804 self.inferred_operands.push((loc, v));
1805 }
1806 InferOperand::Instance(instance) => {
1807 self.instantiated_operands.push((loc, instance));
1808 }
1809 InferOperand::Unknown | InferOperand::Concrete(_) => {}
1810 }
1811 }
1812
1813 fn instantiate_instruction(&mut self, inst: &'a Instruction, inst_loc: InstructionLocation) {
1817 let mut all_generic_args = {
1818 let next_infer_var = InferVar(self.infer_var_values.len().try_into().unwrap());
1819 next_infer_var..next_infer_var
1820 };
1821
1822 let (instantiate_result_type, record_fn_ret_ty, type_of_result) = match inst.class.opcode {
1829 Op::Function => (
1830 None,
1831 inst.result_type,
1832 Some(inst.operands[1].unwrap_id_ref()),
1833 ),
1834 _ => (inst.result_type, None, inst.result_type),
1835 };
1836
1837 for (operand_idx, operand) in instantiate_result_type
1838 .map(Operand::IdRef)
1839 .iter()
1840 .map(|o| (OperandIdx::ResultType, o))
1841 .chain(
1842 inst.operands
1843 .iter()
1844 .enumerate()
1845 .map(|(i, o)| (OperandIdx::Input(i), o)),
1846 )
1847 {
1848 let (operand, rest) = InferOperand::from_operand_and_generic_args(
1850 operand,
1851 all_generic_args.end..InferVar(u32::MAX),
1852 self,
1853 );
1854 let generic_args = all_generic_args.end..rest.start;
1855 all_generic_args.end = generic_args.end;
1856
1857 let generic = match &operand {
1858 InferOperand::Instance(instance) => {
1859 Some(&self.specializer.generics[&instance.generic_id])
1860 }
1861 _ => None,
1862 };
1863
1864 match generic {
1867 Some(Generic {
1868 param_values: Some(values),
1869 ..
1870 }) => self.infer_var_values.extend(
1871 values
1872 .iter()
1873 .map(|v| v.map_var(|Param(p)| InferVar(generic_args.start.0 + p))),
1874 ),
1875
1876 _ => {
1877 self.infer_var_values
1878 .extend(InferVar::range_iter(&generic_args).map(|_| Value::Unknown));
1879 }
1880 }
1881
1882 self.record_instantiated_operand(
1883 OperandLocation {
1884 inst_loc,
1885 operand_idx,
1886 },
1887 operand,
1888 );
1889 }
1890
1891 if let Some(ret_ty) = record_fn_ret_ty {
1893 let (ret_ty, _) = InferOperand::from_operand_and_generic_args(
1894 &Operand::IdRef(ret_ty),
1895 all_generic_args.clone(),
1896 self,
1897 );
1898 self.record_instantiated_operand(
1899 OperandLocation {
1900 inst_loc,
1901 operand_idx: OperandIdx::ResultType,
1902 },
1903 ret_ty,
1904 );
1905 }
1906
1907 let (type_of_result, inputs_generic_args) = match type_of_result {
1909 Some(type_of_result) => {
1910 let (type_of_result, rest) = InferOperand::from_operand_and_generic_args(
1911 &Operand::IdRef(type_of_result),
1912 all_generic_args.clone(),
1913 self,
1914 );
1915 (
1916 Some(type_of_result),
1917 match inst.class.opcode {
1919 Op::Function => all_generic_args,
1920 _ => rest,
1921 },
1922 )
1923 }
1924 None => (None, all_generic_args),
1925 };
1926
1927 let debug_dump_if_enabled = |cx: &Self, prefix| {
1928 let result_type = match inst.class.opcode {
1929 Op::Function => Some(
1931 InferOperand::from_operand_and_generic_args(
1932 &Operand::IdRef(inst.result_type.unwrap()),
1933 inputs_generic_args.clone(),
1934 cx,
1935 )
1936 .0,
1937 ),
1938 _ => type_of_result.clone(),
1939 };
1940 let inputs = InferOperandList {
1941 operands: &inst.operands,
1942 all_generic_args: inputs_generic_args.clone(),
1943 transform: None,
1944 };
1945
1946 if inst_loc != InstructionLocation::Module {
1947 debug!(" ");
1948 }
1949 debug!("{prefix}");
1950 if let Some(result_id) = inst.result_id {
1951 debug!("%{result_id} = ");
1952 }
1953 debug!("Op{:?}", inst.class.opcode);
1954 for operand in result_type.into_iter().chain(inputs.iter(cx)) {
1955 debug!(" {}", operand.display_with_infer_cx(cx));
1956 }
1957 debug!("");
1958 };
1959
1960 if let Some(sigs) = spirv_type_constraints::instruction_signatures(inst.class.opcode) {
1962 assert_ne!(inst.class.opcode, Op::Function);
1965
1966 debug_dump_if_enabled(self, " -> ");
1967
1968 let m = match self.match_inst_sigs(
1969 sigs,
1970 inst,
1971 inputs_generic_args.clone(),
1972 type_of_result.clone(),
1973 ) {
1974 Ok(m) => m,
1975
1976 Err(Unapplicable) => unreachable!(
1982 "spirv_type_constraints(Op{:?}) = `{:?}` doesn't match `{:?}`",
1983 inst.class.opcode, sigs, inst
1984 ),
1985 };
1986
1987 if inst_loc != InstructionLocation::Module {
1988 debug!(" ");
1989 }
1990 debug!(" found {:?}", m.debug_with_infer_cx(self));
1991
1992 if let Err(e) = self.equate_match_findings(m) {
1993 e.report(inst);
1994 }
1995
1996 debug_dump_if_enabled(self, " <- ");
1997 } else {
1998 debug_dump_if_enabled(self, "");
1999 }
2000
2001 if let Some(type_of_result) = type_of_result {
2002 match type_of_result {
2005 InferOperand::Var(_) | InferOperand::Instance(_) => {
2006 self.type_of_result
2007 .insert(inst.result_id.unwrap(), type_of_result);
2008 }
2009 InferOperand::Unknown | InferOperand::Concrete(_) => {}
2010 }
2011 }
2012 }
2013
2014 fn instantiate_function(&mut self, func: &'a Function) {
2017 let func_id = func.def_id().unwrap();
2018
2019 debug!("");
2020 debug!("specializer::instantiate_function(%{func_id}");
2021 if let Some(name) = self.specializer.debug_names.get(&func_id) {
2022 debug!(" {name}");
2023 }
2024 debug!("):");
2025
2026 assert!(self.infer_var_values.is_empty());
2030 self.instantiate_instruction(func.def.as_ref().unwrap(), InstructionLocation::Module);
2031
2032 debug!("infer body {{");
2033
2034 let ret_ty = match self.type_of_result.get(&func_id).cloned() {
2037 Some(InferOperand::Instance(instance)) => {
2038 let generic = &self.specializer.generics[&instance.generic_id];
2039 assert_eq!(generic.def.class.opcode, Op::TypeFunction);
2040
2041 let (ret_ty, mut params_ty_list) = InferOperandList {
2042 operands: &generic.def.operands,
2043 all_generic_args: instance.generic_args,
2044 transform: None,
2045 }
2046 .split_first(self)
2047 .unwrap();
2048
2049 let mut params = func.parameters.iter().enumerate();
2051 while let Some((param_ty, rest)) = params_ty_list.split_first(self) {
2052 params_ty_list = rest;
2053
2054 let (i, param) = params.next().unwrap();
2055 assert_eq!(param.class.opcode, Op::FunctionParameter);
2056
2057 debug!(
2058 " %{} = Op{:?} {}",
2059 param.result_id.unwrap(),
2060 param.class.opcode,
2061 param_ty.display_with_infer_cx(self)
2062 );
2063
2064 self.record_instantiated_operand(
2065 OperandLocation {
2066 inst_loc: InstructionLocation::FnParam(i),
2067 operand_idx: OperandIdx::ResultType,
2068 },
2069 param_ty.clone(),
2070 );
2071 match param_ty {
2072 InferOperand::Var(_) | InferOperand::Instance(_) => {
2073 self.type_of_result
2074 .insert(param.result_id.unwrap(), param_ty);
2075 }
2076 InferOperand::Unknown | InferOperand::Concrete(_) => {}
2077 }
2078 }
2079 assert_eq!(params.next(), None);
2080
2081 Some(ret_ty)
2082 }
2083
2084 _ => None,
2085 };
2086
2087 for (block_idx, block) in func.blocks.iter().enumerate() {
2088 for (inst_idx, inst) in block.instructions.iter().enumerate() {
2089 match inst.class.opcode {
2092 Op::ReturnValue => {
2093 let ret_val_id = inst.operands[0].unwrap_id_ref();
2094 if let (Some(expected), Some(found)) = (
2095 ret_ty.clone(),
2096 self.type_of_result.get(&ret_val_id).cloned(),
2097 ) && let Err(e) = self.equate_infer_operands(expected, found)
2098 {
2099 e.report(inst);
2100 }
2101 }
2102
2103 Op::Return => {}
2104
2105 _ => self.instantiate_instruction(
2106 inst,
2107 InstructionLocation::FnBody {
2108 block_idx,
2109 inst_idx,
2110 },
2111 ),
2112 }
2113 }
2114 }
2115
2116 debug!("}}");
2117 if let Some(func_ty) = self.type_of_result.get(&func_id) {
2118 debug!(" -> %{}: {}", func_id, func_ty.display_with_infer_cx(self));
2119 }
2120 debug!("");
2121 }
2122
2123 fn resolve_infer_var_to_concrete_or_param(
2129 &mut self,
2130 v: InferVar,
2131 generic_params: RangeTo<Param>,
2132 ) -> ConcreteOrParam {
2133 let v = self.resolve_infer_var(v);
2134 let InferVar(i) = v;
2135 match self.infer_var_values[i as usize] {
2136 Value::SameAs(_) => unreachable!(),
2138
2139 Value::Unknown => {
2140 if i < generic_params.end.0 {
2141 ConcreteOrParam::Param(Param(i))
2142 } else {
2143 ConcreteOrParam::Concrete(
2144 CopyOperand::try_from(&self.specializer.specialization.concrete_fallback())
2145 .unwrap(),
2146 )
2147 }
2148 }
2149 Value::Known(x) => ConcreteOrParam::Concrete(x),
2150 }
2151 }
2152
2153 fn into_replacements(mut self, generic_params: RangeTo<Param>) -> Replacements {
2158 let mut with_instance: IndexMap<_, Vec<_>> = IndexMap::new();
2159 for (loc, instance) in mem::take(&mut self.instantiated_operands) {
2160 with_instance
2161 .entry(Instance {
2162 generic_id: instance.generic_id,
2163 generic_args: InferVar::range_iter(&instance.generic_args)
2164 .map(|v| self.resolve_infer_var_to_concrete_or_param(v, generic_params))
2165 .collect(),
2166 })
2167 .or_default()
2168 .push(loc);
2169 }
2170
2171 let with_concrete_or_param = mem::take(&mut self.inferred_operands)
2172 .into_iter()
2173 .map(|(loc, v)| {
2174 (
2175 loc,
2176 self.resolve_infer_var_to_concrete_or_param(v, generic_params),
2177 )
2178 })
2179 .collect();
2180
2181 Replacements {
2182 with_instance,
2183 with_concrete_or_param,
2184 }
2185 }
2186}
2187
/// Rewrites a module so every "generic" definition is replaced by its concrete
/// instances.
struct Expander<'a, S: Specialization> {
    /// Source of the "generic" definitions (and their replacements) being expanded.
    specializer: &'a Specializer<S>,

    /// Owns the module being rewritten, and allocates fresh IDs (`Builder::id`).
    builder: Builder,

    /// Every instance allocated so far, mapped to its freshly allocated ID.
    ///
    /// Kept ordered so `all_instances_of` can find all instances of one "generic"
    /// with a single range query. That presumably relies on `Instance`s ordering by
    /// `generic_id` first, with empty `generic_args` sorting first.
    instances: BTreeMap<Instance<SmallVec<[CopyOperand; 4]>>, Word>,

    /// Newly allocated instances whose own replacements haven't been walked yet by
    /// `propagate_instances`. Walking them may allocate further instances.
    propagate_instances_queue: VecDeque<Instance<SmallVec<[CopyOperand; 4]>>>,
}
2209
2210impl<'a, S: Specialization> Expander<'a, S> {
2211 fn new(specializer: &'a Specializer<S>, module: Module) -> Self {
2212 Expander {
2213 specializer,
2214
2215 builder: Builder::new_from_module(module),
2216
2217 instances: BTreeMap::new(),
2218 propagate_instances_queue: VecDeque::new(),
2219 }
2220 }
2221
2222 fn all_instances_of(
2227 &self,
2228 generic_id: Word,
2229 ) -> std::collections::btree_map::Range<'_, Instance<SmallVec<[CopyOperand; 4]>>, Word> {
2230 let first_instance_of = |generic_id| Instance {
2231 generic_id,
2232 generic_args: SmallVec::new(),
2233 };
2234 self.instances
2235 .range(first_instance_of(generic_id)..first_instance_of(generic_id + 1))
2236 }
2237
2238 fn alloc_instance_id(&mut self, instance: Instance<SmallVec<[CopyOperand; 4]>>) -> Word {
2242 use std::collections::btree_map::Entry;
2243
2244 match self.instances.entry(instance) {
2245 Entry::Occupied(entry) => *entry.get(),
2246 Entry::Vacant(entry) => {
2247 let instance = entry.key().clone();
2250
2251 self.propagate_instances_queue.push_back(instance);
2252 *entry.insert(self.builder.id())
2253 }
2254 }
2255 }
2256
2257 fn propagate_instances(&mut self) {
2264 while let Some(instance) = self.propagate_instances_queue.pop_back() {
2265 for _ in self.specializer.generics[&instance.generic_id]
2267 .replacements
2268 .to_concrete(&instance.generic_args, |i| self.alloc_instance_id(i))
2269 {}
2270 }
2271 }
2272
2273 fn expand_module(mut self) -> Module {
2277 self.propagate_instances();
2280
2281 let module = self.builder.module_mut();
2283 let mut entry_points = mem::take(&mut module.entry_points);
2284 let debug_names = mem::take(&mut module.debug_names);
2285 let annotations = mem::take(&mut module.annotations);
2286 let types_global_values = mem::take(&mut module.types_global_values);
2287 let functions = mem::take(&mut module.functions);
2288
2289 for inst in &mut entry_points {
2292 let func_id = inst.operands[1].unwrap_id_ref();
2293 assert!(
2294 !self.specializer.generics.contains_key(&func_id),
2295 "entry-point %{func_id} shouldn't be \"generic\""
2296 );
2297
2298 for interface_operand in &mut inst.operands[3..] {
2299 let interface_id = interface_operand.unwrap_id_ref();
2300 let mut instances = self.all_instances_of(interface_id);
2301 match (instances.next(), instances.next()) {
2302 (None, _) => unreachable!(
2303 "entry-point %{} has overly-\"generic\" \
2304 interface variable %{}, with no instances",
2305 func_id, interface_id
2306 ),
2307 (Some(_), Some(_)) => unreachable!(
2308 "entry-point %{} has overly-\"generic\" \
2309 interface variable %{}, with too many instances: {:?}",
2310 func_id,
2311 interface_id,
2312 FmtBy(|f| f
2313 .debug_list()
2314 .entries(self.all_instances_of(interface_id).map(
2315 |(instance, _)| FmtBy(move |f| write!(
2316 f,
2317 "{}",
2318 instance.display(|generic_args| generic_args.iter().copied())
2319 ))
2320 ))
2321 .finish())
2322 ),
2323 (Some((_, &instance_id)), None) => {
2324 *interface_operand = Operand::IdRef(instance_id);
2325 }
2326 }
2327 }
2328 }
2329
2330 let expand_debug_or_annotation = |insts: Vec<Instruction>| {
2335 let mut expanded_insts = Vec::with_capacity(insts.len().next_power_of_two());
2336 for inst in insts {
2337 if let [Operand::IdRef(target), ..] = inst.operands[..]
2338 && self.specializer.generics.contains_key(&target)
2339 {
2340 expanded_insts.extend(self.all_instances_of(target).map(
2341 |(_, &instance_id)| {
2342 let mut expanded_inst = inst.clone();
2343 expanded_inst.operands[0] = Operand::IdRef(instance_id);
2344 expanded_inst
2345 },
2346 ));
2347 continue;
2348 }
2349 expanded_insts.push(inst);
2350 }
2351 expanded_insts
2352 };
2353
2354 let expanded_debug_names = expand_debug_or_annotation(debug_names);
2356
2357 let mut expanded_annotations = expand_debug_or_annotation(annotations);
2359
2360 let mut expanded_types_global_values =
2362 Vec::with_capacity(types_global_values.len().next_power_of_two());
2363 for inst in types_global_values {
2364 if let Some(result_id) = inst.result_id
2365 && let Some(generic) = self.specializer.generics.get(&result_id)
2366 {
2367 expanded_types_global_values.extend(self.all_instances_of(result_id).map(
2368 |(instance, &instance_id)| {
2369 let mut expanded_inst = inst.clone();
2370 expanded_inst.result_id = Some(instance_id);
2371 for (loc, operand) in generic
2372 .replacements
2373 .to_concrete(&instance.generic_args, |i| self.instances[&i])
2374 {
2375 expanded_inst.index_set(loc, operand.into());
2376 }
2377 expanded_inst
2378 },
2379 ));
2380 continue;
2381 }
2382 expanded_types_global_values.push(inst);
2383 }
2384
2385 let mut expanded_functions = Vec::with_capacity(functions.len().next_power_of_two());
2387 for func in functions {
2388 let func_id = func.def_id().unwrap();
2389 if let Some(generic) = self.specializer.generics.get(&func_id) {
2390 let old_expanded_functions_len = expanded_functions.len();
2391 expanded_functions.extend(self.all_instances_of(func_id).map(
2392 |(instance, &instance_id)| {
2393 let mut expanded_func = func.clone();
2394 expanded_func.def.as_mut().unwrap().result_id = Some(instance_id);
2395 for (loc, operand) in generic
2396 .replacements
2397 .to_concrete(&instance.generic_args, |i| self.instances[&i])
2398 {
2399 expanded_func.index_set(loc, operand.into());
2400 }
2401 expanded_func
2402 },
2403 ));
2404
2405 let newly_expanded_functions =
2412 &mut expanded_functions[old_expanded_functions_len..];
2413 if newly_expanded_functions.len() > 1 {
2414 let mut rewrite_rules = FxHashMap::default();
2417
2418 for func in newly_expanded_functions {
2419 rewrite_rules.clear();
2420
2421 rewrite_rules.extend(func.parameters.iter_mut().map(|param| {
2422 let old_id = param.result_id.unwrap();
2423 let new_id = self.builder.id();
2424
2425 param.result_id = Some(new_id);
2429
2430 (old_id, new_id)
2431 }));
2432 rewrite_rules.extend(
2433 func.blocks
2434 .iter()
2435 .flat_map(|b| b.label.iter().chain(b.instructions.iter()))
2436 .filter_map(|inst| inst.result_id)
2437 .map(|old_id| (old_id, self.builder.id())),
2438 );
2439
2440 super::apply_rewrite_rules(&rewrite_rules, &mut func.blocks);
2441
2442 for annotation_idx in 0..expanded_annotations.len() {
2444 let inst = &expanded_annotations[annotation_idx];
2445 if let [Operand::IdRef(target), ..] = inst.operands[..]
2446 && let Some(&rewritten_target) = rewrite_rules.get(&target)
2447 {
2448 let mut expanded_inst = inst.clone();
2449 expanded_inst.operands[0] = Operand::IdRef(rewritten_target);
2450 expanded_annotations.push(expanded_inst);
2451 }
2452 }
2453 }
2454 }
2455
2456 continue;
2457 }
2458 expanded_functions.push(func);
2459 }
2460
2461 assert!(self.propagate_instances_queue.is_empty());
2464
2465 let module = self.builder.module_mut();
2466 module.entry_points = entry_points;
2467 module.debug_names = expanded_debug_names;
2468 module.annotations = expanded_annotations;
2469 module.types_global_values = expanded_types_global_values;
2470 module.functions = expanded_functions;
2471
2472 self.builder.module()
2473 }
2474
    /// Writes a human-readable listing of every "generic" and its instances to `w`.
    ///
    /// For each "generic", it writes:
    /// - its debug name, if any;
    /// - its definition, with each operand shown as a nested instance, a single
    ///   parameter, or the operand itself;
    /// - its known parameter values (the `where` line), if it has any;
    /// - one line per allocated instance, with that instance's ID.
    ///
    /// Returns any I/O error from writing to `w`.
    fn dump_instances(&self, w: &mut impl io::Write) -> io::Result<()> {
        writeln!(w, "; All specializer \"generic\"s and their instances:")?;
        writeln!(w)?;

        for (&generic_id, generic) in &self.specializer.generics {
            if let Some(name) = self.specializer.debug_names.get(&generic_id) {
                writeln!(w, "; {name}")?;
            }

            write!(
                w,
                "{} = Op{:?}",
                Instance {
                    generic_id,
                    generic_args: Param(0)..Param(generic.param_count)
                }
                .display(Param::range_iter),
                generic.def.class.opcode
            )?;
            // Parameters are handed out to the operands in order.
            let mut next_param = Param(0);
            for operand in generic
                .def
                .result_type
                .map(Operand::IdRef)
                .iter()
                .chain(generic.def.operands.iter())
            {
                write!(w, " ")?;
                let (needed, used_generic) = self.specializer.params_needed_by(operand);
                let params = next_param..Param(next_param.0 + needed);

                // For `OpFunction`, parameters aren't advanced, so every operand's
                // params start at the same place. Presumably this is because its
                // operands share the same generic args; TODO confirm.
                if generic.def.class.opcode != Op::Function {
                    next_param = params.end;
                }

                if used_generic.is_some() {
                    write!(
                        w,
                        "{}",
                        Instance {
                            generic_id: operand.unwrap_id_ref(),
                            generic_args: params
                        }
                        .display(Param::range_iter)
                    )?;
                } else if needed == 1 {
                    write!(w, "{}", params.start)?;
                } else {
                    write!(w, "{operand}")?;
                }
            }
            writeln!(w)?;

            if let Some(param_values) = &generic.param_values {
                write!(w, " where")?;
                for (i, v) in param_values.iter().enumerate() {
                    let p = Param(i as u32);
                    match v {
                        Value::Unknown => {}
                        Value::Known(o) => write!(w, " {p} = {o},")?,
                        Value::SameAs(q) => write!(w, " {p} = {q},")?,
                    }
                }
                writeln!(w)?;
            }

            for (instance, instance_id) in self.all_instances_of(generic_id) {
                assert_eq!(instance.generic_id, generic_id);
                writeln!(
                    w,
                    " %{} = {}",
                    instance_id,
                    instance.display(|generic_args| generic_args.iter().copied())
                )?;
            }

            writeln!(w)?;
        }
        Ok(())
    }
2557}