spirt/spv/
read.rs

1//! Low-level parsing of SPIR-V binary form.
2
3use crate::spv::{self, spec};
4use rustc_hash::FxHashMap;
5use smallvec::SmallVec;
6use std::borrow::Cow;
7use std::num::NonZeroU32;
8use std::path::Path;
9use std::{fs, io, iter, slice};
10
11/// Defining instruction of an ID.
12///
13/// Used currently only to help parsing `LiteralContextDependentNumber`.
14enum KnownIdDef {
15    TypeInt(NonZeroU32),
16    TypeFloat(NonZeroU32),
17    Uncategorized { opcode: spec::Opcode, result_type_id: Option<spv::Id> },
18}
19
20impl KnownIdDef {
21    fn result_type_id(&self) -> Option<spv::Id> {
22        match *self {
23            Self::TypeInt(_) | Self::TypeFloat(_) => None,
24            Self::Uncategorized { result_type_id, .. } => result_type_id,
25        }
26    }
27}
28
29// FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything.
30struct InstParser<'a> {
31    /// IDs defined so far in the module.
32    known_ids: &'a FxHashMap<spv::Id, KnownIdDef>,
33
34    /// Input words of an instruction.
35    words: iter::Copied<slice::Iter<'a, u32>>,
36
37    /// Output instruction, being parsed.
38    inst: spv::InstWithIds,
39}
40
41enum InstParseError {
42    /// Ran out of words while parsing an instruction's operands.
43    NotEnoughWords,
44
45    /// Extra words were left over, after parsing an instruction's operands.
46    TooManyWords,
47
48    /// An illegal ID of `0`.
49    IdZero,
50
51    /// Unsupported enumerand value.
52    UnsupportedEnumerand(spec::OperandKind, u32),
53
54    /// An `IdResultType` ID referring to an ID not already defined.
55    UnknownResultTypeId(spv::Id),
56
57    /// The type of a `LiteralContextDependentNumber` could not be determined.
58    MissingContextSensitiveLiteralType,
59
60    /// The type of a `LiteralContextDependentNumber` was not a supported type
61    /// (one of either `OpTypeInt` or `OpTypeFloat`).
62    UnsupportedContextSensitiveLiteralType { type_opcode: spec::Opcode },
63}
64
65impl InstParseError {
66    // FIXME(eddyb) improve messages and add more contextual information.
67    fn message(&self) -> Cow<'static, str> {
68        match *self {
69            Self::NotEnoughWords => "truncated instruction".into(),
70            Self::TooManyWords => "overlong instruction".into(),
71            Self::IdZero => "ID %0 is illegal".into(),
72            // FIXME(eddyb) deduplicate this with `spv::write`.
73            Self::UnsupportedEnumerand(kind, word) => {
74                let (name, def) = kind.name_and_def();
75                match def {
76                    spec::OperandKindDef::BitEnum { bits, .. } => {
77                        let unsupported = spec::BitIdx::of_all_set_bits(word)
78                            .filter(|&bit_idx| bits.get(bit_idx).is_none())
79                            .fold(0u32, |x, i| x | (1 << i.0));
80                        format!("unsupported {name} bit-pattern 0x{unsupported:08x}").into()
81                    }
82
83                    spec::OperandKindDef::ValueEnum { .. } => {
84                        format!("unsupported {name} value {word}").into()
85                    }
86
87                    _ => unreachable!(),
88                }
89            }
90            Self::UnknownResultTypeId(id) => {
91                format!("ID %{id} used as result type before definition").into()
92            }
93            Self::MissingContextSensitiveLiteralType => "missing type for literal".into(),
94            Self::UnsupportedContextSensitiveLiteralType { type_opcode } => {
95                format!("{} is not a supported literal type", type_opcode.name()).into()
96            }
97        }
98    }
99}
100
101impl InstParser<'_> {
102    fn is_exhausted(&self) -> bool {
103        // FIXME(eddyb) use `self.words.is_empty()` when that is stabilized.
104        self.words.len() == 0
105    }
106
107    fn enumerant_params(&mut self, enumerant: &spec::Enumerant) -> Result<(), InstParseError> {
108        for (mode, kind) in enumerant.all_params() {
109            if mode == spec::OperandMode::Optional && self.is_exhausted() {
110                break;
111            }
112            self.operand(kind)?;
113        }
114
115        Ok(())
116    }
117
118    fn operand(&mut self, kind: spec::OperandKind) -> Result<(), InstParseError> {
119        use InstParseError as Error;
120
121        let word = self.words.next().ok_or(Error::NotEnoughWords)?;
122        match kind.def() {
123            spec::OperandKindDef::BitEnum { bits, .. } => {
124                self.inst.imms.push(spv::Imm::Short(kind, word));
125
126                for bit_idx in spec::BitIdx::of_all_set_bits(word) {
127                    let bit_def =
128                        bits.get(bit_idx).ok_or(Error::UnsupportedEnumerand(kind, word))?;
129                    self.enumerant_params(bit_def)?;
130                }
131            }
132
133            spec::OperandKindDef::ValueEnum { variants } => {
134                self.inst.imms.push(spv::Imm::Short(kind, word));
135
136                let variant_def = u16::try_from(word)
137                    .ok()
138                    .and_then(|v| variants.get(v))
139                    .ok_or(Error::UnsupportedEnumerand(kind, word))?;
140                self.enumerant_params(variant_def)?;
141            }
142
143            spec::OperandKindDef::Id => {
144                let id = word.try_into().ok().ok_or(Error::IdZero)?;
145                self.inst.ids.push(id);
146            }
147
148            spec::OperandKindDef::Literal { size: spec::LiteralSize::Word } => {
149                self.inst.imms.push(spv::Imm::Short(kind, word));
150            }
151            spec::OperandKindDef::Literal { size: spec::LiteralSize::NulTerminated } => {
152                let has_nul = |word: u32| word.to_le_bytes().contains(&0);
153                if has_nul(word) {
154                    self.inst.imms.push(spv::Imm::Short(kind, word));
155                } else {
156                    self.inst.imms.push(spv::Imm::LongStart(kind, word));
157                    for word in &mut self.words {
158                        self.inst.imms.push(spv::Imm::LongCont(kind, word));
159                        if has_nul(word) {
160                            break;
161                        }
162                    }
163                }
164            }
165            spec::OperandKindDef::Literal { size: spec::LiteralSize::FromContextualType } => {
166                let contextual_type = self
167                    .inst
168                    .result_type_id
169                    .or_else(|| {
170                        // `OpSwitch` takes its literal type from the first operand.
171                        let &id = self.inst.ids.first()?;
172                        self.known_ids.get(&id)?.result_type_id()
173                    })
174                    .and_then(|id| self.known_ids.get(&id))
175                    .ok_or(Error::MissingContextSensitiveLiteralType)?;
176
177                let extra_word_count = match *contextual_type {
178                    KnownIdDef::TypeInt(width) | KnownIdDef::TypeFloat(width) => {
179                        // HACK(eddyb) `(width + 31) / 32 - 1` but without overflow.
180                        (width.get() - 1) / 32
181                    }
182                    KnownIdDef::Uncategorized { opcode, .. } => {
183                        return Err(Error::UnsupportedContextSensitiveLiteralType {
184                            type_opcode: opcode,
185                        });
186                    }
187                };
188
189                if extra_word_count == 0 {
190                    self.inst.imms.push(spv::Imm::Short(kind, word));
191                } else {
192                    self.inst.imms.push(spv::Imm::LongStart(kind, word));
193                    for _ in 0..extra_word_count {
194                        let word = self.words.next().ok_or(Error::NotEnoughWords)?;
195                        self.inst.imms.push(spv::Imm::LongCont(kind, word));
196                    }
197                }
198            }
199        }
200
201        Ok(())
202    }
203
204    fn inst(mut self, def: &spec::InstructionDef) -> Result<spv::InstWithIds, InstParseError> {
205        use InstParseError as Error;
206
207        {
208            // FIXME(eddyb) should this be a method?
209            let mut id = || {
210                self.words.next().ok_or(Error::NotEnoughWords)?.try_into().ok().ok_or(Error::IdZero)
211            };
212            self.inst.result_type_id = def.has_result_type_id.then(&mut id).transpose()?;
213            self.inst.result_id = def.has_result_id.then(&mut id).transpose()?;
214        }
215
216        if let Some(type_id) = self.inst.result_type_id {
217            if !self.known_ids.contains_key(&type_id) {
218                // FIXME(eddyb) also check that the ID is a valid type.
219                return Err(Error::UnknownResultTypeId(type_id));
220            }
221        }
222
223        for (mode, kind) in def.all_operands() {
224            if mode == spec::OperandMode::Optional && self.is_exhausted() {
225                break;
226            }
227            self.operand(kind)?;
228        }
229
230        // The instruction must consume its entire word count.
231        if !self.is_exhausted() {
232            return Err(Error::TooManyWords);
233        }
234
235        Ok(self.inst)
236    }
237}
238
239pub struct ModuleParser {
240    /// Copy of the header words (for convenience).
241    // FIXME(eddyb) add a `spec::Header` or `spv::Header` struct with named fields.
242    pub header: [u32; spec::HEADER_LEN],
243
244    /// The entire module's bytes, representing "native endian" SPIR-V words.
245    // FIXME(eddyb) could this be allocated as `Vec<u32>` in the first place?
246    word_bytes: Vec<u8>,
247
248    /// Next (instructions') word position in the module.
249    next_word: usize,
250
251    /// IDs defined so far in the module.
252    known_ids: FxHashMap<spv::Id, KnownIdDef>,
253}
254
// FIXME(eddyb) stop abusing `io::Error` for error reporting.
/// Builds an [`io::ErrorKind::InvalidData`] error with the message
/// `malformed SPIR-V (<reason>)`.
fn invalid(reason: &str) -> io::Error {
    let message = format!("malformed SPIR-V ({reason})");
    io::Error::new(io::ErrorKind::InvalidData, message)
}
259
260impl ModuleParser {
261    pub fn read_from_spv_file(path: impl AsRef<Path>) -> io::Result<Self> {
262        Self::read_from_spv_bytes(fs::read(path)?)
263    }
264
265    // FIXME(eddyb) also add `from_spv_words`.
266    pub fn read_from_spv_bytes(spv_bytes: Vec<u8>) -> io::Result<Self> {
267        let spv_spec = spec::Spec::get();
268
269        if spv_bytes.len() % 4 != 0 {
270            return Err(invalid("not a multiple of 4 bytes"));
271        }
272        // May need to mutate the bytes (to normalize endianness) later below.
273        let mut spv_bytes = spv_bytes;
274        let spv_words = bytemuck::cast_slice_mut::<u8, u32>(&mut spv_bytes);
275
276        if spv_words.len() < spec::HEADER_LEN {
277            return Err(invalid("truncated header"));
278        }
279
280        // Check the magic, and swap endianness of all words if we have to.
281        {
282            let magic = spv_words[0];
283            if magic == spv_spec.magic {
284                // Nothing to do, all words already match native endianness.
285            } else if magic.swap_bytes() == spv_spec.magic {
286                for word in &mut spv_words[..] {
287                    *word = word.swap_bytes();
288                }
289            } else {
290                return Err(invalid("incorrect magic number"));
291            }
292        }
293
294        Ok(Self {
295            header: spv_words[..spec::HEADER_LEN].try_into().unwrap(),
296            word_bytes: spv_bytes,
297            next_word: spec::HEADER_LEN,
298
299            known_ids: FxHashMap::default(),
300        })
301    }
302}
303
304impl Iterator for ModuleParser {
305    type Item = io::Result<spv::InstWithIds>;
306    fn next(&mut self) -> Option<Self::Item> {
307        let spv_spec = spec::Spec::get();
308        let wk = &spv_spec.well_known;
309
310        let words = &bytemuck::cast_slice::<u8, u32>(&self.word_bytes)[self.next_word..];
311        let &opcode = words.first()?;
312
313        let (inst_len, opcode) = ((opcode >> 16) as usize, opcode as u16);
314
315        let (opcode, inst_name, def) = match spec::Opcode::try_from_u16_with_name_and_def(opcode) {
316            Some(opcode_name_and_def) => opcode_name_and_def,
317            None => return Some(Err(invalid(&format!("unsupported opcode {opcode}")))),
318        };
319
320        let invalid = |msg: &str| invalid(&format!("in {inst_name}: {msg}"));
321
322        if words.len() < inst_len {
323            return Some(Err(invalid("truncated instruction")));
324        }
325
326        let parser = InstParser {
327            known_ids: &self.known_ids,
328            words: words[1..inst_len].iter().copied(),
329            inst: spv::InstWithIds {
330                without_ids: opcode.into(),
331                result_type_id: None,
332                result_id: None,
333                ids: SmallVec::new(),
334            },
335        };
336
337        let inst = match parser.inst(def) {
338            Ok(inst) => inst,
339            Err(e) => return Some(Err(invalid(&e.message()))),
340        };
341
342        // HACK(eddyb) `Option::map` allows using `?` for `Result` in the closure.
343        let maybe_known_id_result = inst.result_id.map(|id| {
344            let known_id_def = if opcode == wk.OpTypeInt {
345                KnownIdDef::TypeInt(match inst.imms[0] {
346                    spv::Imm::Short(kind, n) => {
347                        assert_eq!(kind, wk.LiteralInteger);
348                        n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))?
349                    }
350                    _ => unreachable!(),
351                })
352            } else if opcode == wk.OpTypeFloat {
353                KnownIdDef::TypeFloat(match inst.imms[0] {
354                    spv::Imm::Short(kind, n) => {
355                        assert_eq!(kind, wk.LiteralInteger);
356                        n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))?
357                    }
358                    _ => unreachable!(),
359                })
360            } else {
361                KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id }
362            };
363
364            let old = self.known_ids.insert(id, known_id_def);
365            if old.is_some() {
366                return Err(invalid(&format!("ID %{id} is a result of multiple instructions")));
367            }
368
369            Ok(())
370        });
371        if let Some(Err(e)) = maybe_known_id_result {
372            return Some(Err(e));
373        }
374
375        self.next_word += inst_len;
376
377        Some(Ok(inst))
378    }
379}