1use crate::spv::{self, spec};
4use std::borrow::Cow;
5use std::path::Path;
6use std::{fs, io, iter, slice};
7
8struct OperandEmitter<'a> {
10 imms: iter::Copied<slice::Iter<'a, spv::Imm>>,
12
13 ids: iter::Copied<slice::Iter<'a, spv::Id>>,
15
16 out: &'a mut Vec<u32>,
18}
19
20enum OperandEmitError {
21 NotEnoughImms,
23
24 NotEnoughIds,
26
27 TooManyImms,
29
30 TooManyIds,
32
33 UnsupportedEnumerand(spec::OperandKind, u32),
35}
36
37impl OperandEmitError {
38 fn message(&self) -> Cow<'static, str> {
40 match *self {
41 Self::NotEnoughImms => "truncated instruction (immediates)".into(),
42 Self::NotEnoughIds => "truncated instruction (IDs)".into(),
43 Self::TooManyImms => "overlong instruction (immediates)".into(),
44 Self::TooManyIds => "overlong instruction (IDs)".into(),
45 Self::UnsupportedEnumerand(kind, word) => {
47 let (name, def) = kind.name_and_def();
48 match def {
49 spec::OperandKindDef::BitEnum { bits, .. } => {
50 let unsupported = spec::BitIdx::of_all_set_bits(word)
51 .filter(|&bit_idx| bits.get(bit_idx).is_none())
52 .fold(0u32, |x, i| x | (1 << i.0));
53 format!("unsupported {name} bit-pattern 0x{unsupported:08x}").into()
54 }
55
56 spec::OperandKindDef::ValueEnum { .. } => {
57 format!("unsupported {name} value {word}").into()
58 }
59
60 _ => unreachable!(),
61 }
62 }
63 }
64 }
65}
66
67impl OperandEmitter<'_> {
68 fn is_exhausted(&mut self) -> bool {
69 self.imms.len() == 0 && self.ids.len() == 0
72 }
73
74 fn enumerant_params(&mut self, enumerant: &spec::Enumerant) -> Result<(), OperandEmitError> {
75 for (mode, kind) in enumerant.all_params() {
76 if mode == spec::OperandMode::Optional && self.is_exhausted() {
77 break;
78 }
79 self.operand(kind)?;
80 }
81
82 Ok(())
83 }
84
85 fn operand(&mut self, kind: spec::OperandKind) -> Result<(), OperandEmitError> {
86 use OperandEmitError as Error;
87
88 let mut get_enum_word = || match self.imms.next() {
89 Some(spv::Imm::Short(found_kind, word)) => {
90 assert_eq!(kind, found_kind);
91 Ok(word)
92 }
93 Some(spv::Imm::LongStart(..) | spv::Imm::LongCont(..)) => unreachable!(),
94 None => Err(Error::NotEnoughImms),
95 };
96
97 match kind.def() {
98 spec::OperandKindDef::BitEnum { bits, .. } => {
99 let word = get_enum_word()?;
100 self.out.push(word);
101
102 for bit_idx in spec::BitIdx::of_all_set_bits(word) {
103 let bit_def =
104 bits.get(bit_idx).ok_or(Error::UnsupportedEnumerand(kind, word))?;
105 self.enumerant_params(bit_def)?;
106 }
107 }
108 spec::OperandKindDef::ValueEnum { variants } => {
109 let word = get_enum_word()?;
110 self.out.push(word);
111
112 let variant_def = u16::try_from(word)
113 .ok()
114 .and_then(|v| variants.get(v))
115 .ok_or(Error::UnsupportedEnumerand(kind, word))?;
116 self.enumerant_params(variant_def)?;
117 }
118 spec::OperandKindDef::Id => {
119 self.out.push(self.ids.next().ok_or(Error::NotEnoughIds)?.get());
120 }
121 spec::OperandKindDef::Literal { .. } => {
122 match self.imms.next().ok_or(Error::NotEnoughImms)? {
123 spv::Imm::Short(found_kind, word) => {
124 assert_eq!(kind, found_kind);
125 self.out.push(word);
126 }
127 spv::Imm::LongStart(found_kind, word) => {
128 assert_eq!(kind, found_kind);
129 self.out.push(word);
130 while let Some(spv::Imm::LongCont(cont_kind, word)) =
131 self.imms.clone().next()
132 {
133 self.imms.next();
134 assert_eq!(kind, cont_kind);
135 self.out.push(word);
136 }
137 }
138 spv::Imm::LongCont(..) => unreachable!(),
139 }
140 }
141 }
142
143 Ok(())
144 }
145
146 fn inst_operands(mut self, def: &spec::InstructionDef) -> Result<(), OperandEmitError> {
147 use OperandEmitError as Error;
148
149 for (mode, kind) in def.all_operands() {
150 if mode == spec::OperandMode::Optional && self.is_exhausted() {
151 break;
152 }
153 self.operand(kind)?;
154 }
155
156 if !self.is_exhausted() {
158 return Err(
159 if self.imms.len() != 0 {
161 Error::TooManyImms
162 } else {
163 assert!(self.ids.len() != 0);
165 Error::TooManyIds
166 },
167 );
168 }
169
170 Ok(())
171 }
172}
173
/// Incrementally builds the 32-bit word stream of a SPIR-V module.
pub struct ModuleEmitter {
    /// Every word emitted so far: the header words, then each pushed instruction.
    pub words: Vec<u32>,
}
179
/// Builds an `InvalidData` I/O error reporting malformed SPIR-V, with `reason`
/// embedded in the message.
fn invalid(reason: &str) -> io::Error {
    let message = format!("malformed SPIR-V ({reason})");
    io::Error::new(io::ErrorKind::InvalidData, message)
}
184
185impl ModuleEmitter {
186 pub fn with_header(header: [u32; spec::HEADER_LEN]) -> Self {
187 Self { words: header.into() }
189 }
190
191 pub fn push_inst(&mut self, inst: &spv::InstWithIds) -> io::Result<()> {
193 let (inst_name, def) = inst.opcode.name_and_def();
194 let invalid = |msg: &str| invalid(&format!("in {inst_name}: {msg}"));
195
196 if inst.result_type_id.is_some() != def.has_result_type_id {
198 return Err(invalid("result type ID (`IdResultType`) mismatch"));
199 }
200 if inst.result_id.is_some() != def.has_result_id {
201 return Err(invalid("result ID (`IdResult`) mismatch"));
202 }
203
204 let total_word_count = 1
205 + (inst.result_type_id.is_some() as usize)
206 + (inst.result_id.is_some() as usize)
207 + inst.imms.len()
208 + inst.ids.len();
209
210 self.words.reserve(total_word_count);
211 let expected_final_pos = self.words.len() + total_word_count;
212
213 let opcode = u32::from(inst.opcode.as_u16())
214 | u32::from(u16::try_from(total_word_count).ok().ok_or_else(|| {
215 invalid("word count of SPIR-V instruction doesn't fit in 16 bits")
216 })?) << 16;
217 self.words.extend(
218 iter::once(opcode)
219 .chain(inst.result_type_id.map(|id| id.get()))
220 .chain(inst.result_id.map(|id| id.get())),
221 );
222
223 OperandEmitter {
224 imms: inst.imms.iter().copied(),
225 ids: inst.ids.iter().copied(),
226 out: &mut self.words,
227 }
228 .inst_operands(def)
229 .map_err(|e| invalid(&e.message()))?;
230
231 assert_eq!(self.words.len(), expected_final_pos);
234
235 Ok(())
236 }
237
238 pub fn write_to_spv_file(&self, path: impl AsRef<Path>) -> io::Result<()> {
239 fs::write(path, bytemuck::cast_slice::<u32, u8>(&self.words))
240 }
241}