spirt/spv/
write.rs

1//! Low-level emission of SPIR-V binary form.
2
3use crate::spv::{self, spec};
4use std::borrow::Cow;
5use std::path::Path;
6use std::{fs, io, iter, slice};
7
8// FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything.
9struct OperandEmitter<'a> {
10    /// Input immediate operands of an instruction.
11    imms: iter::Copied<slice::Iter<'a, spv::Imm>>,
12
13    /// Input ID operands of an instruction.
14    ids: iter::Copied<slice::Iter<'a, spv::Id>>,
15
16    /// Output SPIR-V words.
17    out: &'a mut Vec<u32>,
18}
19
20enum OperandEmitError {
21    /// Ran out of immediates while emitting an instruction's operands.
22    NotEnoughImms,
23
24    /// Ran out of IDs while emitting an instruction's operands.
25    NotEnoughIds,
26
27    /// Extra immediates were left over, after emitting an instruction's operands.
28    TooManyImms,
29
30    /// Extra IDs were left over, after emitting an instruction's operands.
31    TooManyIds,
32
33    /// Unsupported enumerand value.
34    UnsupportedEnumerand(spec::OperandKind, u32),
35}
36
37impl OperandEmitError {
38    // FIXME(eddyb) improve messages and add more contextual information.
39    fn message(&self) -> Cow<'static, str> {
40        match *self {
41            Self::NotEnoughImms => "truncated instruction (immediates)".into(),
42            Self::NotEnoughIds => "truncated instruction (IDs)".into(),
43            Self::TooManyImms => "overlong instruction (immediates)".into(),
44            Self::TooManyIds => "overlong instruction (IDs)".into(),
45            // FIXME(eddyb) deduplicate this with `spv::read`.
46            Self::UnsupportedEnumerand(kind, word) => {
47                let (name, def) = kind.name_and_def();
48                match def {
49                    spec::OperandKindDef::BitEnum { bits, .. } => {
50                        let unsupported = spec::BitIdx::of_all_set_bits(word)
51                            .filter(|&bit_idx| bits.get(bit_idx).is_none())
52                            .fold(0u32, |x, i| x | (1 << i.0));
53                        format!("unsupported {name} bit-pattern 0x{unsupported:08x}").into()
54                    }
55
56                    spec::OperandKindDef::ValueEnum { .. } => {
57                        format!("unsupported {name} value {word}").into()
58                    }
59
60                    _ => unreachable!(),
61                }
62            }
63        }
64    }
65}
66
67impl OperandEmitter<'_> {
68    fn is_exhausted(&mut self) -> bool {
69        // FIXME(eddyb) use `self.imms.is_empty() && self.ids.is_empty()` when
70        // that is stabilized.
71        self.imms.len() == 0 && self.ids.len() == 0
72    }
73
74    fn enumerant_params(&mut self, enumerant: &spec::Enumerant) -> Result<(), OperandEmitError> {
75        for (mode, kind) in enumerant.all_params() {
76            if mode == spec::OperandMode::Optional && self.is_exhausted() {
77                break;
78            }
79            self.operand(kind)?;
80        }
81
82        Ok(())
83    }
84
85    fn operand(&mut self, kind: spec::OperandKind) -> Result<(), OperandEmitError> {
86        use OperandEmitError as Error;
87
88        let mut get_enum_word = || match self.imms.next() {
89            Some(spv::Imm::Short(found_kind, word)) => {
90                assert_eq!(kind, found_kind);
91                Ok(word)
92            }
93            Some(spv::Imm::LongStart(..) | spv::Imm::LongCont(..)) => unreachable!(),
94            None => Err(Error::NotEnoughImms),
95        };
96
97        match kind.def() {
98            spec::OperandKindDef::BitEnum { bits, .. } => {
99                let word = get_enum_word()?;
100                self.out.push(word);
101
102                for bit_idx in spec::BitIdx::of_all_set_bits(word) {
103                    let bit_def =
104                        bits.get(bit_idx).ok_or(Error::UnsupportedEnumerand(kind, word))?;
105                    self.enumerant_params(bit_def)?;
106                }
107            }
108            spec::OperandKindDef::ValueEnum { variants } => {
109                let word = get_enum_word()?;
110                self.out.push(word);
111
112                let variant_def = u16::try_from(word)
113                    .ok()
114                    .and_then(|v| variants.get(v))
115                    .ok_or(Error::UnsupportedEnumerand(kind, word))?;
116                self.enumerant_params(variant_def)?;
117            }
118            spec::OperandKindDef::Id => {
119                self.out.push(self.ids.next().ok_or(Error::NotEnoughIds)?.get());
120            }
121            spec::OperandKindDef::Literal { .. } => {
122                match self.imms.next().ok_or(Error::NotEnoughImms)? {
123                    spv::Imm::Short(found_kind, word) => {
124                        assert_eq!(kind, found_kind);
125                        self.out.push(word);
126                    }
127                    spv::Imm::LongStart(found_kind, word) => {
128                        assert_eq!(kind, found_kind);
129                        self.out.push(word);
130                        while let Some(spv::Imm::LongCont(cont_kind, word)) =
131                            self.imms.clone().next()
132                        {
133                            self.imms.next();
134                            assert_eq!(kind, cont_kind);
135                            self.out.push(word);
136                        }
137                    }
138                    spv::Imm::LongCont(..) => unreachable!(),
139                }
140            }
141        }
142
143        Ok(())
144    }
145
146    fn inst_operands(mut self, def: &spec::InstructionDef) -> Result<(), OperandEmitError> {
147        use OperandEmitError as Error;
148
149        for (mode, kind) in def.all_operands() {
150            if mode == spec::OperandMode::Optional && self.is_exhausted() {
151                break;
152            }
153            self.operand(kind)?;
154        }
155
156        // The instruction must consume all of its operands.
157        if !self.is_exhausted() {
158            return Err(
159                // FIXME(eddyb) use `!self.imms.is_empty()` when that is stabilized.
160                if self.imms.len() != 0 {
161                    Error::TooManyImms
162                } else {
163                    // FIXME(eddyb) use `!self.ids.is_empty()` when that is stabilized.
164                    assert!(self.ids.len() != 0);
165                    Error::TooManyIds
166                },
167            );
168        }
169
170        Ok(())
171    }
172}
173
/// Builds the binary form of a SPIR-V module, accumulated in memory as `u32` words.
pub struct ModuleEmitter {
    // FIXME(eddyb) try to write bytes to an `impl io::Write` directly.
    /// All SPIR-V words emitted so far, starting with the module header.
    pub words: Vec<u32>,
}
179
// FIXME(eddyb) stop abusing `io::Error` for error reporting.
/// Build an `InvalidData` I/O error carrying the message `malformed SPIR-V (<reason>)`.
fn invalid(reason: &str) -> io::Error {
    let message = format!("malformed SPIR-V ({reason})");
    io::Error::new(io::ErrorKind::InvalidData, message)
}
184
185impl ModuleEmitter {
186    pub fn with_header(header: [u32; spec::HEADER_LEN]) -> Self {
187        // FIXME(eddyb) sanity-check the provided header words.
188        Self { words: header.into() }
189    }
190
191    // FIXME(eddyb) sanity-check the operands against the definition of `inst.opcode`.
192    pub fn push_inst(&mut self, inst: &spv::InstWithIds) -> io::Result<()> {
193        let (inst_name, def) = inst.opcode.name_and_def();
194        let invalid = |msg: &str| invalid(&format!("in {inst_name}: {msg}"));
195
196        // FIXME(eddyb) make these errors clearer (or turn them into asserts?).
197        if inst.result_type_id.is_some() != def.has_result_type_id {
198            return Err(invalid("result type ID (`IdResultType`) mismatch"));
199        }
200        if inst.result_id.is_some() != def.has_result_id {
201            return Err(invalid("result ID (`IdResult`) mismatch"));
202        }
203
204        let total_word_count = 1
205            + (inst.result_type_id.is_some() as usize)
206            + (inst.result_id.is_some() as usize)
207            + inst.imms.len()
208            + inst.ids.len();
209
210        self.words.reserve(total_word_count);
211        let expected_final_pos = self.words.len() + total_word_count;
212
213        let opcode = u32::from(inst.opcode.as_u16())
214            | u32::from(u16::try_from(total_word_count).ok().ok_or_else(|| {
215                invalid("word count of SPIR-V instruction doesn't fit in 16 bits")
216            })?) << 16;
217        self.words.extend(
218            iter::once(opcode)
219                .chain(inst.result_type_id.map(|id| id.get()))
220                .chain(inst.result_id.map(|id| id.get())),
221        );
222
223        OperandEmitter {
224            imms: inst.imms.iter().copied(),
225            ids: inst.ids.iter().copied(),
226            out: &mut self.words,
227        }
228        .inst_operands(def)
229        .map_err(|e| invalid(&e.message()))?;
230
231        // If no error was produced so far, `OperandEmitter` should've pushed
232        // the exact number of words.
233        assert_eq!(self.words.len(), expected_final_pos);
234
235        Ok(())
236    }
237
238    pub fn write_to_spv_file(&self, path: impl AsRef<Path>) -> io::Result<()> {
239        fs::write(path, bytemuck::cast_slice::<u32, u8>(&self.words))
240    }
241}