1use crate::spv::{self, spec};
4use rustc_hash::FxHashMap;
5use smallvec::SmallVec;
6use std::borrow::Cow;
7use std::num::NonZeroU32;
8use std::path::Path;
9use std::{fs, io, iter, slice};
10
11enum KnownIdDef {
15 TypeInt(NonZeroU32),
16 TypeFloat(NonZeroU32),
17 Uncategorized { opcode: spec::Opcode, result_type_id: Option<spv::Id> },
18}
19
20impl KnownIdDef {
21 fn result_type_id(&self) -> Option<spv::Id> {
22 match *self {
23 Self::TypeInt(_) | Self::TypeFloat(_) => None,
24 Self::Uncategorized { result_type_id, .. } => result_type_id,
25 }
26 }
27}
28
29struct InstParser<'a> {
31 known_ids: &'a FxHashMap<spv::Id, KnownIdDef>,
33
34 words: iter::Copied<slice::Iter<'a, u32>>,
36
37 inst: spv::InstWithIds,
39}
40
41enum InstParseError {
42 NotEnoughWords,
44
45 TooManyWords,
47
48 IdZero,
50
51 UnsupportedEnumerand(spec::OperandKind, u32),
53
54 UnknownResultTypeId(spv::Id),
56
57 MissingContextSensitiveLiteralType,
59
60 UnsupportedContextSensitiveLiteralType { type_opcode: spec::Opcode },
63}
64
65impl InstParseError {
66 fn message(&self) -> Cow<'static, str> {
68 match *self {
69 Self::NotEnoughWords => "truncated instruction".into(),
70 Self::TooManyWords => "overlong instruction".into(),
71 Self::IdZero => "ID %0 is illegal".into(),
72 Self::UnsupportedEnumerand(kind, word) => {
74 let (name, def) = kind.name_and_def();
75 match def {
76 spec::OperandKindDef::BitEnum { bits, .. } => {
77 let unsupported = spec::BitIdx::of_all_set_bits(word)
78 .filter(|&bit_idx| bits.get(bit_idx).is_none())
79 .fold(0u32, |x, i| x | (1 << i.0));
80 format!("unsupported {name} bit-pattern 0x{unsupported:08x}").into()
81 }
82
83 spec::OperandKindDef::ValueEnum { .. } => {
84 format!("unsupported {name} value {word}").into()
85 }
86
87 _ => unreachable!(),
88 }
89 }
90 Self::UnknownResultTypeId(id) => {
91 format!("ID %{id} used as result type before definition").into()
92 }
93 Self::MissingContextSensitiveLiteralType => "missing type for literal".into(),
94 Self::UnsupportedContextSensitiveLiteralType { type_opcode } => {
95 format!("{} is not a supported literal type", type_opcode.name()).into()
96 }
97 }
98 }
99}
100
101impl InstParser<'_> {
102 fn is_exhausted(&self) -> bool {
103 self.words.len() == 0
105 }
106
107 fn enumerant_params(&mut self, enumerant: &spec::Enumerant) -> Result<(), InstParseError> {
108 for (mode, kind) in enumerant.all_params() {
109 if mode == spec::OperandMode::Optional && self.is_exhausted() {
110 break;
111 }
112 self.operand(kind)?;
113 }
114
115 Ok(())
116 }
117
118 fn operand(&mut self, kind: spec::OperandKind) -> Result<(), InstParseError> {
119 use InstParseError as Error;
120
121 let word = self.words.next().ok_or(Error::NotEnoughWords)?;
122 match kind.def() {
123 spec::OperandKindDef::BitEnum { bits, .. } => {
124 self.inst.imms.push(spv::Imm::Short(kind, word));
125
126 for bit_idx in spec::BitIdx::of_all_set_bits(word) {
127 let bit_def =
128 bits.get(bit_idx).ok_or(Error::UnsupportedEnumerand(kind, word))?;
129 self.enumerant_params(bit_def)?;
130 }
131 }
132
133 spec::OperandKindDef::ValueEnum { variants } => {
134 self.inst.imms.push(spv::Imm::Short(kind, word));
135
136 let variant_def = u16::try_from(word)
137 .ok()
138 .and_then(|v| variants.get(v))
139 .ok_or(Error::UnsupportedEnumerand(kind, word))?;
140 self.enumerant_params(variant_def)?;
141 }
142
143 spec::OperandKindDef::Id => {
144 let id = word.try_into().ok().ok_or(Error::IdZero)?;
145 self.inst.ids.push(id);
146 }
147
148 spec::OperandKindDef::Literal { size: spec::LiteralSize::Word } => {
149 self.inst.imms.push(spv::Imm::Short(kind, word));
150 }
151 spec::OperandKindDef::Literal { size: spec::LiteralSize::NulTerminated } => {
152 let has_nul = |word: u32| word.to_le_bytes().contains(&0);
153 if has_nul(word) {
154 self.inst.imms.push(spv::Imm::Short(kind, word));
155 } else {
156 self.inst.imms.push(spv::Imm::LongStart(kind, word));
157 for word in &mut self.words {
158 self.inst.imms.push(spv::Imm::LongCont(kind, word));
159 if has_nul(word) {
160 break;
161 }
162 }
163 }
164 }
165 spec::OperandKindDef::Literal { size: spec::LiteralSize::FromContextualType } => {
166 let contextual_type = self
167 .inst
168 .result_type_id
169 .or_else(|| {
170 let &id = self.inst.ids.first()?;
172 self.known_ids.get(&id)?.result_type_id()
173 })
174 .and_then(|id| self.known_ids.get(&id))
175 .ok_or(Error::MissingContextSensitiveLiteralType)?;
176
177 let extra_word_count = match *contextual_type {
178 KnownIdDef::TypeInt(width) | KnownIdDef::TypeFloat(width) => {
179 (width.get() - 1) / 32
181 }
182 KnownIdDef::Uncategorized { opcode, .. } => {
183 return Err(Error::UnsupportedContextSensitiveLiteralType {
184 type_opcode: opcode,
185 });
186 }
187 };
188
189 if extra_word_count == 0 {
190 self.inst.imms.push(spv::Imm::Short(kind, word));
191 } else {
192 self.inst.imms.push(spv::Imm::LongStart(kind, word));
193 for _ in 0..extra_word_count {
194 let word = self.words.next().ok_or(Error::NotEnoughWords)?;
195 self.inst.imms.push(spv::Imm::LongCont(kind, word));
196 }
197 }
198 }
199 }
200
201 Ok(())
202 }
203
204 fn inst(mut self, def: &spec::InstructionDef) -> Result<spv::InstWithIds, InstParseError> {
205 use InstParseError as Error;
206
207 {
208 let mut id = || {
210 self.words.next().ok_or(Error::NotEnoughWords)?.try_into().ok().ok_or(Error::IdZero)
211 };
212 self.inst.result_type_id = def.has_result_type_id.then(&mut id).transpose()?;
213 self.inst.result_id = def.has_result_id.then(&mut id).transpose()?;
214 }
215
216 if let Some(type_id) = self.inst.result_type_id {
217 if !self.known_ids.contains_key(&type_id) {
218 return Err(Error::UnknownResultTypeId(type_id));
220 }
221 }
222
223 for (mode, kind) in def.all_operands() {
224 if mode == spec::OperandMode::Optional && self.is_exhausted() {
225 break;
226 }
227 self.operand(kind)?;
228 }
229
230 if !self.is_exhausted() {
232 return Err(Error::TooManyWords);
233 }
234
235 Ok(self.inst)
236 }
237}
238
239pub struct ModuleParser {
240 pub header: [u32; spec::HEADER_LEN],
243
244 word_bytes: Vec<u8>,
247
248 next_word: usize,
250
251 known_ids: FxHashMap<spv::Id, KnownIdDef>,
253}
254
/// Builds an `InvalidData` I/O error describing malformed SPIR-V input.
fn invalid(reason: &str) -> io::Error {
    let message = format!("malformed SPIR-V ({reason})");
    io::Error::new(io::ErrorKind::InvalidData, message)
}
259
260impl ModuleParser {
261 pub fn read_from_spv_file(path: impl AsRef<Path>) -> io::Result<Self> {
262 Self::read_from_spv_bytes(fs::read(path)?)
263 }
264
265 pub fn read_from_spv_bytes(spv_bytes: Vec<u8>) -> io::Result<Self> {
267 let spv_spec = spec::Spec::get();
268
269 if spv_bytes.len() % 4 != 0 {
270 return Err(invalid("not a multiple of 4 bytes"));
271 }
272 let mut spv_bytes = spv_bytes;
274 let spv_words = bytemuck::cast_slice_mut::<u8, u32>(&mut spv_bytes);
275
276 if spv_words.len() < spec::HEADER_LEN {
277 return Err(invalid("truncated header"));
278 }
279
280 {
282 let magic = spv_words[0];
283 if magic == spv_spec.magic {
284 } else if magic.swap_bytes() == spv_spec.magic {
286 for word in &mut spv_words[..] {
287 *word = word.swap_bytes();
288 }
289 } else {
290 return Err(invalid("incorrect magic number"));
291 }
292 }
293
294 Ok(Self {
295 header: spv_words[..spec::HEADER_LEN].try_into().unwrap(),
296 word_bytes: spv_bytes,
297 next_word: spec::HEADER_LEN,
298
299 known_ids: FxHashMap::default(),
300 })
301 }
302}
303
304impl Iterator for ModuleParser {
305 type Item = io::Result<spv::InstWithIds>;
306 fn next(&mut self) -> Option<Self::Item> {
307 let spv_spec = spec::Spec::get();
308 let wk = &spv_spec.well_known;
309
310 let words = &bytemuck::cast_slice::<u8, u32>(&self.word_bytes)[self.next_word..];
311 let &opcode = words.first()?;
312
313 let (inst_len, opcode) = ((opcode >> 16) as usize, opcode as u16);
314
315 let (opcode, inst_name, def) = match spec::Opcode::try_from_u16_with_name_and_def(opcode) {
316 Some(opcode_name_and_def) => opcode_name_and_def,
317 None => return Some(Err(invalid(&format!("unsupported opcode {opcode}")))),
318 };
319
320 let invalid = |msg: &str| invalid(&format!("in {inst_name}: {msg}"));
321
322 if words.len() < inst_len {
323 return Some(Err(invalid("truncated instruction")));
324 }
325
326 let parser = InstParser {
327 known_ids: &self.known_ids,
328 words: words[1..inst_len].iter().copied(),
329 inst: spv::InstWithIds {
330 without_ids: opcode.into(),
331 result_type_id: None,
332 result_id: None,
333 ids: SmallVec::new(),
334 },
335 };
336
337 let inst = match parser.inst(def) {
338 Ok(inst) => inst,
339 Err(e) => return Some(Err(invalid(&e.message()))),
340 };
341
342 let maybe_known_id_result = inst.result_id.map(|id| {
344 let known_id_def = if opcode == wk.OpTypeInt {
345 KnownIdDef::TypeInt(match inst.imms[0] {
346 spv::Imm::Short(kind, n) => {
347 assert_eq!(kind, wk.LiteralInteger);
348 n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))?
349 }
350 _ => unreachable!(),
351 })
352 } else if opcode == wk.OpTypeFloat {
353 KnownIdDef::TypeFloat(match inst.imms[0] {
354 spv::Imm::Short(kind, n) => {
355 assert_eq!(kind, wk.LiteralInteger);
356 n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))?
357 }
358 _ => unreachable!(),
359 })
360 } else {
361 KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id }
362 };
363
364 let old = self.known_ids.insert(id, known_id_def);
365 if old.is_some() {
366 return Err(invalid(&format!("ID %{id} is a result of multiple instructions")));
367 }
368
369 Ok(())
370 });
371 if let Some(Err(e)) = maybe_known_id_result {
372 return Some(Err(e));
373 }
374
375 self.next_word += inst_len;
376
377 Some(Ok(inst))
378 }
379}