rspirv/dr/
loader.rs

1use crate::binary;
2use crate::dr;
3use crate::grammar;
4use crate::spirv;
5
6use crate::binary::{ParseAction, ParseResult};
7use std::{borrow::Cow, error, fmt};
8
9/// Data representation loading errors.
10#[derive(Debug)]
11pub enum Error {
12    NestedFunction,
13    UnclosedFunction,
14    MismatchedFunctionEnd,
15    DetachedFunctionParameter,
16    DetachedBlock,
17    NestedBlock,
18    UnclosedBlock,
19    MismatchedTerminator,
20    DetachedInstruction(Option<dr::Instruction>),
21    EmptyInstructionList,
22    WrongOpCapabilityOperand,
23    WrongOpExtensionOperand,
24    WrongOpExtInstImportOperand,
25    WrongOpMemoryModelOperand,
26    WrongOpNameOperand,
27    FunctionNotFound,
28    BlockNotFound,
29}
30
31impl Error {
32    /// Gives an descriptive string for each error.
33    ///
34    /// This method is intended to be used by fmt::Display and error::Error to
35    /// avoid duplication in implementation. So it's private.
36    fn describe(&self) -> Cow<'static, str> {
37        match self {
38            Error::NestedFunction => Cow::Borrowed("found nested function"),
39            Error::UnclosedFunction => Cow::Borrowed("found unclosed function"),
40            Error::MismatchedFunctionEnd => Cow::Borrowed("found mismatched OpFunctionEnd"),
41            Error::DetachedFunctionParameter => {
42                Cow::Borrowed("found function OpFunctionParameter not inside function")
43            }
44            Error::DetachedBlock => Cow::Borrowed("found block not inside function"),
45            Error::NestedBlock => Cow::Borrowed("found nested block"),
46            Error::UnclosedBlock => Cow::Borrowed("found block without terminator"),
47            Error::MismatchedTerminator => Cow::Borrowed("found mismatched terminator"),
48            Error::DetachedInstruction(Some(inst)) => Cow::Owned(format!(
49                "found instruction `{:?}` not inside block",
50                inst.class.opname
51            )),
52            Error::DetachedInstruction(None) => {
53                Cow::Borrowed("found unknown instruction not inside block")
54            }
55            Error::EmptyInstructionList => Cow::Borrowed("list of instructions is empty"),
56            Error::WrongOpCapabilityOperand => Cow::Borrowed("wrong OpCapability operand"),
57            Error::WrongOpExtensionOperand => Cow::Borrowed("wrong OpExtension operand"),
58            Error::WrongOpExtInstImportOperand => Cow::Borrowed("wrong OpExtInstImport operand"),
59            Error::WrongOpMemoryModelOperand => Cow::Borrowed("wrong OpMemoryModel operand"),
60            Error::WrongOpNameOperand => Cow::Borrowed("wrong OpName operand"),
61            Error::FunctionNotFound => Cow::Borrowed("can't find the function"),
62            Error::BlockNotFound => Cow::Borrowed("can't find the block"),
63        }
64    }
65}
66
67impl error::Error for Error {}
68
69impl fmt::Display for Error {
70    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71        write!(f, "{}", self.describe())
72    }
73}
74
75/// The data representation loader.
76///
77/// Constructs a [`Module`](struct.Module.html) from the module header and
78/// instructions.
79///
80/// It implements the [`Consumer`](../binary/trait.Consumer.html) trait and
81/// works with the [`Parser`](../binary/struct.Parser.html).
82#[derive(Default)]
83pub struct Loader {
84    module: dr::Module,
85    function: Option<dr::Function>,
86    block: Option<dr::Block>,
87}
88
89impl Loader {
90    /// Creates a new empty loader.
91    pub fn new() -> Loader {
92        Loader {
93            module: dr::Module::new(),
94            function: None,
95            block: None,
96        }
97    }
98
99    /// Returns the `Module` under construction.
100    pub fn module(self) -> dr::Module {
101        self.module
102    }
103}
104
/// Returns early with `ParseAction::Error` wrapping `Error::$error` if
/// `$condition` evaluates to true; otherwise does nothing.
///
/// `$error` is the bare name of an `Error` variant without a payload, e.g.
/// `if_ret_err!(self.block.is_some(), UnclosedBlock)`.
macro_rules! if_ret_err {
    ($condition:expr, $error:ident) => {
        if $condition {
            return ParseAction::Error(Box::new(Error::$error));
        }
    };
}
113
114impl binary::Consumer for Loader {
115    fn initialize(&mut self) -> ParseAction {
116        ParseAction::Continue
117    }
118
119    fn finalize(&mut self) -> ParseAction {
120        if_ret_err!(self.block.is_some(), UnclosedBlock);
121        if_ret_err!(self.function.is_some(), UnclosedFunction);
122        ParseAction::Continue
123    }
124
125    fn consume_header(&mut self, header: dr::ModuleHeader) -> ParseAction {
126        self.module.header = Some(header);
127        ParseAction::Continue
128    }
129
130    fn consume_instruction(&mut self, inst: dr::Instruction) -> ParseAction {
131        let opcode = inst.class.opcode;
132        match opcode {
133            spirv::Op::Capability => self.module.capabilities.push(inst),
134            spirv::Op::Extension => self.module.extensions.push(inst),
135            spirv::Op::ExtInstImport => self.module.ext_inst_imports.push(inst),
136            spirv::Op::MemoryModel => self.module.memory_model = Some(inst),
137            spirv::Op::EntryPoint => self.module.entry_points.push(inst),
138            spirv::Op::ExecutionMode => self.module.execution_modes.push(inst),
139            spirv::Op::String
140            | spirv::Op::SourceExtension
141            | spirv::Op::Source
142            | spirv::Op::SourceContinued => self.module.debug_string_source.push(inst),
143            spirv::Op::Name | spirv::Op::MemberName => self.module.debug_names.push(inst),
144            spirv::Op::ModuleProcessed => self.module.debug_module_processed.push(inst),
145            opcode if grammar::reflect::is_location_debug(opcode) => {
146                match &mut self.block {
147                    Some(block) => block.instructions.push(inst),
148                    // types_global_values is the only valid section (other than functions) that
149                    // OpLine/OpNoLine can be placed in, so put it there.
150                    None => self.module.types_global_values.push(inst),
151                }
152            }
153            opcode if grammar::reflect::is_annotation(opcode) => self.module.annotations.push(inst),
154            opcode
155                if grammar::reflect::is_type(opcode) || grammar::reflect::is_constant(opcode) =>
156            {
157                self.module.types_global_values.push(inst)
158            }
159            spirv::Op::Variable if self.function.is_none() => {
160                self.module.types_global_values.push(inst)
161            }
162            spirv::Op::Undef if self.function.is_none() => {
163                self.module.types_global_values.push(inst)
164            }
165            spirv::Op::Function => {
166                if_ret_err!(self.function.is_some(), NestedFunction);
167                let mut f = dr::Function::new();
168                f.def = Some(inst);
169                self.function = Some(f)
170            }
171            spirv::Op::FunctionEnd => {
172                if_ret_err!(self.function.is_none(), MismatchedFunctionEnd);
173                if_ret_err!(self.block.is_some(), UnclosedBlock);
174                self.function.as_mut().unwrap().end = Some(inst);
175                self.module.functions.push(self.function.take().unwrap())
176            }
177            spirv::Op::FunctionParameter => {
178                if_ret_err!(self.function.is_none(), DetachedFunctionParameter);
179                self.function.as_mut().unwrap().parameters.push(inst);
180            }
181            spirv::Op::Label => {
182                if_ret_err!(self.function.is_none(), DetachedBlock);
183                if_ret_err!(self.block.is_some(), NestedBlock);
184                let mut block = dr::Block::new();
185                block.label = Some(inst);
186                self.block = Some(block)
187            }
188            opcode if grammar::reflect::is_block_terminator(opcode) => {
189                // Make sure the block exists here. Once the block exists,
190                // we are certain the function exists because the above checks.
191                if_ret_err!(self.block.is_none(), MismatchedTerminator);
192                self.block.as_mut().unwrap().instructions.push(inst);
193                self.function
194                    .as_mut()
195                    .unwrap()
196                    .blocks
197                    .push(self.block.take().unwrap())
198            }
199            _ => {
200                if self.block.is_none() {
201                    return ParseAction::Error(Box::new(Error::DetachedInstruction(Some(inst))));
202                }
203                self.block.as_mut().unwrap().instructions.push(inst)
204            }
205        }
206        ParseAction::Continue
207    }
208}
209
210/// Loads the SPIR-V `binary` into memory and returns a `Module`.
211///
212/// # Examples
213///
214/// ```
215/// use rspirv;
216/// use rspirv::binary::Disassemble;
217///
218/// let buffer: Vec<u8> = vec![
219///     // Magic number.           Version number: 1.0.
220///     0x03, 0x02, 0x23, 0x07,    0x00, 0x00, 0x01, 0x00,
221///     // Generator number: 0.    Bound: 0.
222///     0x00, 0x00, 0x00, 0x00,    0x00, 0x00, 0x00, 0x00,
223///     // Reserved word: 0.
224///     0x00, 0x00, 0x00, 0x00,
225///     // OpMemoryModel.          Logical.
226///     0x0e, 0x00, 0x03, 0x00,    0x00, 0x00, 0x00, 0x00,
227///     // GLSL450.
228///     0x01, 0x00, 0x00, 0x00];
229///
230/// let dis = match rspirv::dr::load_bytes(buffer) {
231///     Ok(module) => module.disassemble(),
232///     Err(err) => format!("{}", err),
233/// };
234///
235/// assert_eq!(dis,
236///            "; SPIR-V\n\
237///             ; Version: 1.0\n\
238///             ; Generator: rspirv\n\
239///             ; Bound: 0\n\
240///             OpMemoryModel Logical GLSL450");
241/// ```
242pub fn load_bytes(binary: impl AsRef<[u8]>) -> ParseResult<dr::Module> {
243    let mut loader = Loader::new();
244    binary::parse_bytes(binary, &mut loader)?;
245    Ok(loader.module())
246}
247
248/// Loads the SPIR-V `binary` into memory and returns a `Module`.
249///
250/// # Examples
251///
252/// ```
253/// use rspirv;
254/// use rspirv::binary::Disassemble;
255///
256/// let buffer: Vec<u32> = vec![
257///     0x07230203,  // Magic number
258///     0x00010000,  // Version number: 1.0
259///     0x00000000,  // Generator number: 0
260///     0x00000000,  // Bound: 0
261///     0x00000000,  // Reserved word: 0
262///     0x0003000e,  // OpMemoryModel
263///     0x00000000,  // Logical
264///     0x00000001,  // GLSL450
265/// ];
266///
267/// let dis = match rspirv::dr::load_words(buffer) {
268///     Ok(module) => module.disassemble(),
269///     Err(err) => format!("{}", err),
270/// };
271///
272/// assert_eq!(dis,
273///            "; SPIR-V\n\
274///             ; Version: 1.0\n\
275///             ; Generator: rspirv\n\
276///             ; Bound: 0\n\
277///             OpMemoryModel Logical GLSL450");
278/// ```
279pub fn load_words(binary: impl AsRef<[u32]>) -> ParseResult<dr::Module> {
280    let mut loader = Loader::new();
281    binary::parse_words(binary, &mut loader)?;
282    Ok(loader.module())
283}
284
#[cfg(test)]
mod tests {
    use crate::dr;
    use crate::spirv;

    // NOTE(review): despite their names, both tests below inspect the module
    // produced by `dr::Builder` directly; nothing is passed through `Loader`.

    /// An `OpVariable` created before any function is begun ends up in
    /// `types_global_values`; one created inside a block stays in the block.
    #[test]
    fn test_load_variable() {
        let mut builder = dr::Builder::new();

        let void_ty = builder.type_void();
        let float_ty = builder.type_float(32);
        let fn_ty = builder.type_function(void_ty, vec![void_ty]);

        // Global variable: created before any function is begun.
        let global_id = builder.variable(float_ty, None, spirv::StorageClass::Input, None);

        builder
            .begin_function(void_ty, None, spirv::FunctionControl::NONE, fn_ty)
            .unwrap();
        builder.begin_block(None).unwrap();
        // Local variable: created inside the open block.
        let local_id = builder.variable(float_ty, None, spirv::StorageClass::Function, None);
        builder.ret().unwrap();
        builder.end_function().unwrap();

        let module = builder.module();

        // Three types followed by the global variable.
        assert_eq!(module.types_global_values.len(), 4);
        let global = &module.types_global_values[3];
        assert_eq!(global.class.opcode, spirv::Op::Variable);
        assert_eq!(global.result_id.unwrap(), global_id);

        assert_eq!(module.functions.len(), 1);
        let function = &module.functions[0];
        assert_eq!(function.blocks.len(), 1);
        let block = &function.blocks[0];
        assert!(block.instructions.len() > 1);
        let local = &block.instructions[0];
        assert_eq!(local.class.opcode, spirv::Op::Variable);
        assert_eq!(local.result_id.unwrap(), local_id);
    }

    /// An `OpUndef` created before any function is begun ends up in
    /// `types_global_values`; one created inside a block stays in the block.
    #[test]
    fn test_load_undef() {
        let mut builder = dr::Builder::new();

        let void_ty = builder.type_void();
        let float_ty = builder.type_float(32);
        let fn_ty = builder.type_function(void_ty, vec![void_ty]);

        // Global undef: created before any function is begun.
        let global_id = builder.undef(float_ty, None);

        builder
            .begin_function(void_ty, None, spirv::FunctionControl::NONE, fn_ty)
            .unwrap();
        builder.begin_block(None).unwrap();
        // Local undef: created inside the open block.
        let local_id = builder.undef(float_ty, None);
        builder.ret().unwrap();
        builder.end_function().unwrap();

        let module = builder.module();

        // Three types followed by the global undef.
        assert_eq!(module.types_global_values.len(), 4);
        let global = &module.types_global_values[3];
        assert_eq!(global.class.opcode, spirv::Op::Undef);
        assert_eq!(global.result_id.unwrap(), global_id);

        assert_eq!(module.functions.len(), 1);
        let function = &module.functions[0];
        assert_eq!(function.blocks.len(), 1);
        let block = &function.blocks[0];
        assert!(block.instructions.len() > 1);
        let local = &block.instructions[0];
        assert_eq!(local.class.opcode, spirv::Op::Undef);
        assert_eq!(local.result_id.unwrap(), local_id);
    }
}