Skip to main content

rspirv/dr/
loader.rs

1use crate::binary;
2use crate::dr;
3use crate::grammar;
4use crate::spirv;
5
6use crate::binary::{ParseAction, ParseResult};
7use std::{borrow::Cow, error, fmt};
8
9/// Data representation loading errors.
10#[derive(Debug)]
11pub enum Error {
12    NestedFunction,
13    UnclosedFunction,
14    MismatchedFunctionEnd,
15    DetachedFunctionParameter,
16    DetachedBlock,
17    NestedBlock,
18    UnclosedBlock,
19    MismatchedTerminator,
20    DetachedInstruction(Option<dr::Instruction>),
21    EmptyInstructionList,
22    WrongOpCapabilityOperand,
23    WrongOpExtensionOperand,
24    WrongOpExtInstImportOperand,
25    WrongOpMemoryModelOperand,
26    WrongOpNameOperand,
27    FunctionNotFound,
28    BlockNotFound,
29}
30
31impl Error {
32    /// Gives an descriptive string for each error.
33    ///
34    /// This method is intended to be used by fmt::Display and error::Error to
35    /// avoid duplication in implementation. So it's private.
36    fn describe(&self) -> Cow<'static, str> {
37        match self {
38            Error::NestedFunction => Cow::Borrowed("found nested function"),
39            Error::UnclosedFunction => Cow::Borrowed("found unclosed function"),
40            Error::MismatchedFunctionEnd => Cow::Borrowed("found mismatched OpFunctionEnd"),
41            Error::DetachedFunctionParameter => {
42                Cow::Borrowed("found function OpFunctionParameter not inside function")
43            }
44            Error::DetachedBlock => Cow::Borrowed("found block not inside function"),
45            Error::NestedBlock => Cow::Borrowed("found nested block"),
46            Error::UnclosedBlock => Cow::Borrowed("found block without terminator"),
47            Error::MismatchedTerminator => Cow::Borrowed("found mismatched terminator"),
48            Error::DetachedInstruction(Some(inst)) => Cow::Owned(format!(
49                "found instruction `{:?}` not inside block",
50                inst.class.opname
51            )),
52            Error::DetachedInstruction(None) => {
53                Cow::Borrowed("found unknown instruction not inside block")
54            }
55            Error::EmptyInstructionList => Cow::Borrowed("list of instructions is empty"),
56            Error::WrongOpCapabilityOperand => Cow::Borrowed("wrong OpCapability operand"),
57            Error::WrongOpExtensionOperand => Cow::Borrowed("wrong OpExtension operand"),
58            Error::WrongOpExtInstImportOperand => Cow::Borrowed("wrong OpExtInstImport operand"),
59            Error::WrongOpMemoryModelOperand => Cow::Borrowed("wrong OpMemoryModel operand"),
60            Error::WrongOpNameOperand => Cow::Borrowed("wrong OpName operand"),
61            Error::FunctionNotFound => Cow::Borrowed("can't find the function"),
62            Error::BlockNotFound => Cow::Borrowed("can't find the block"),
63        }
64    }
65}
66
67impl error::Error for Error {}
68
69impl fmt::Display for Error {
70    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71        write!(f, "{}", self.describe())
72    }
73}
74
75/// The data representation loader.
76///
77/// Constructs a [`Module`](struct.Module.html) from the module header and
78/// instructions.
79///
80/// It implements the [`Consumer`](../binary/trait.Consumer.html) trait and
81/// works with the [`Parser`](../binary/struct.Parser.html).
82#[derive(Default)]
83pub struct Loader {
84    module: dr::Module,
85    function: Option<dr::Function>,
86    block: Option<dr::Block>,
87}
88
89impl Loader {
90    /// Creates a new empty loader.
91    pub fn new() -> Loader {
92        Loader {
93            module: dr::Module::new(),
94            function: None,
95            block: None,
96        }
97    }
98
99    /// Returns the `Module` under construction.
100    pub fn module(self) -> dr::Module {
101        self.module
102    }
103}
104
/// Returns early from the enclosing function with
/// `ParseAction::Error(Box::new(Error::$error))` if `$condition` evaluates
/// to true; otherwise execution falls through.
///
/// `$error` must name a variant of `Error` that carries no payload, since it
/// is used as a bare identifier.
macro_rules! if_ret_err {
    ($condition: expr, $error: ident) => {
        if $condition {
            return ParseAction::Error(Box::new(Error::$error));
        }
    };
}
113
114impl binary::Consumer for Loader {
115    fn initialize(&mut self) -> ParseAction {
116        ParseAction::Continue
117    }
118
119    fn finalize(&mut self) -> ParseAction {
120        if_ret_err!(self.block.is_some(), UnclosedBlock);
121        if_ret_err!(self.function.is_some(), UnclosedFunction);
122        ParseAction::Continue
123    }
124
125    fn consume_header(&mut self, header: dr::ModuleHeader) -> ParseAction {
126        self.module.header = Some(header);
127        ParseAction::Continue
128    }
129
130    fn consume_instruction(&mut self, inst: dr::Instruction) -> ParseAction {
131        let opcode = inst.class.opcode;
132        match opcode {
133            spirv::Op::Capability => self.module.capabilities.push(inst),
134            spirv::Op::Extension => self.module.extensions.push(inst),
135            spirv::Op::ExtInstImport => self.module.ext_inst_imports.push(inst),
136            spirv::Op::MemoryModel => self.module.memory_model = Some(inst),
137            spirv::Op::EntryPoint => self.module.entry_points.push(inst),
138            spirv::Op::ExecutionMode => self.module.execution_modes.push(inst),
139            spirv::Op::String
140            | spirv::Op::SourceExtension
141            | spirv::Op::Source
142            | spirv::Op::SourceContinued => self.module.debug_string_source.push(inst),
143            spirv::Op::Name | spirv::Op::MemberName => self.module.debug_names.push(inst),
144            spirv::Op::ModuleProcessed => self.module.debug_module_processed.push(inst),
145            opcode if grammar::reflect::is_location_debug(opcode) => {
146                match &mut self.block {
147                    Some(block) => block.instructions.push(inst),
148                    // types_global_values is the only valid section (other than functions) that
149                    // OpLine/OpNoLine can be placed in, so put it there.
150                    None => self.module.types_global_values.push(inst),
151                }
152            }
153            opcode if opcode.is_annotation() => self.module.annotations.push(inst),
154            opcode if opcode.is_type() || opcode.is_constant() => {
155                self.module.types_global_values.push(inst)
156            }
157            spirv::Op::Variable if self.function.is_none() => {
158                self.module.types_global_values.push(inst)
159            }
160            spirv::Op::Undef if self.function.is_none() => {
161                self.module.types_global_values.push(inst)
162            }
163            spirv::Op::Function => {
164                if_ret_err!(self.function.is_some(), NestedFunction);
165                let mut f = dr::Function::new();
166                f.def = Some(inst);
167                self.function = Some(f)
168            }
169            spirv::Op::FunctionEnd => {
170                if_ret_err!(self.function.is_none(), MismatchedFunctionEnd);
171                if_ret_err!(self.block.is_some(), UnclosedBlock);
172                self.function.as_mut().unwrap().end = Some(inst);
173                self.module.functions.push(self.function.take().unwrap())
174            }
175            spirv::Op::FunctionParameter => {
176                if_ret_err!(self.function.is_none(), DetachedFunctionParameter);
177                self.function.as_mut().unwrap().parameters.push(inst);
178            }
179            spirv::Op::Label => {
180                if_ret_err!(self.function.is_none(), DetachedBlock);
181                if_ret_err!(self.block.is_some(), NestedBlock);
182                let mut block = dr::Block::new();
183                block.label = Some(inst);
184                self.block = Some(block)
185            }
186            opcode if grammar::reflect::is_block_terminator(opcode) => {
187                // Make sure the block exists here. Once the block exists,
188                // we are certain the function exists because the above checks.
189                if_ret_err!(self.block.is_none(), MismatchedTerminator);
190                self.block.as_mut().unwrap().instructions.push(inst);
191                self.function
192                    .as_mut()
193                    .unwrap()
194                    .blocks
195                    .push(self.block.take().unwrap())
196            }
197            _ => {
198                if self.block.is_none() {
199                    return ParseAction::Error(Box::new(Error::DetachedInstruction(Some(inst))));
200                }
201                self.block.as_mut().unwrap().instructions.push(inst)
202            }
203        }
204        ParseAction::Continue
205    }
206}
207
208/// Loads the SPIR-V `binary` into memory and returns a `Module`.
209///
210/// # Examples
211///
212/// ```
213/// use rspirv;
214/// use rspirv::binary::Disassemble;
215///
216/// let buffer: Vec<u8> = vec![
217///     // Magic number.           Version number: 1.0.
218///     0x03, 0x02, 0x23, 0x07,    0x00, 0x00, 0x01, 0x00,
219///     // Generator number: 0.    Bound: 0.
220///     0x00, 0x00, 0x00, 0x00,    0x00, 0x00, 0x00, 0x00,
221///     // Reserved word: 0.
222///     0x00, 0x00, 0x00, 0x00,
223///     // OpMemoryModel.          Logical.
224///     0x0e, 0x00, 0x03, 0x00,    0x00, 0x00, 0x00, 0x00,
225///     // GLSL450.
226///     0x01, 0x00, 0x00, 0x00];
227///
228/// let dis = match rspirv::dr::load_bytes(buffer) {
229///     Ok(module) => module.disassemble(),
230///     Err(err) => format!("{}", err),
231/// };
232///
233/// assert_eq!(dis,
234///            "; SPIR-V\n\
235///             ; Version: 1.0\n\
236///             ; Generator: rspirv\n\
237///             ; Bound: 0\n\
238///             OpMemoryModel Logical GLSL450");
239/// ```
240pub fn load_bytes(binary: impl AsRef<[u8]>) -> ParseResult<dr::Module> {
241    let mut loader = Loader::new();
242    binary::parse_bytes(binary, &mut loader)?;
243    Ok(loader.module())
244}
245
246/// Loads the SPIR-V `binary` into memory and returns a `Module`.
247///
248/// # Examples
249///
250/// ```
251/// use rspirv;
252/// use rspirv::binary::Disassemble;
253///
254/// let buffer: Vec<u32> = vec![
255///     0x07230203,  // Magic number
256///     0x00010000,  // Version number: 1.0
257///     0x00000000,  // Generator number: 0
258///     0x00000000,  // Bound: 0
259///     0x00000000,  // Reserved word: 0
260///     0x0003000e,  // OpMemoryModel
261///     0x00000000,  // Logical
262///     0x00000001,  // GLSL450
263/// ];
264///
265/// let dis = match rspirv::dr::load_words(buffer) {
266///     Ok(module) => module.disassemble(),
267///     Err(err) => format!("{}", err),
268/// };
269///
270/// assert_eq!(dis,
271///            "; SPIR-V\n\
272///             ; Version: 1.0\n\
273///             ; Generator: rspirv\n\
274///             ; Bound: 0\n\
275///             OpMemoryModel Logical GLSL450");
276/// ```
277pub fn load_words(binary: impl AsRef<[u32]>) -> ParseResult<dr::Module> {
278    let mut loader = Loader::new();
279    binary::parse_words(binary, &mut loader)?;
280    Ok(loader.module())
281}
282
#[cfg(test)]
mod tests {
    use crate::dr;
    use crate::spirv;

    // NOTE(review): both tests inspect the module produced by `dr::Builder`
    // directly; neither passes it through `Loader`. Confirm this is intended.

    /// A variable declared outside any function lands in
    /// `types_global_values`; one declared inside a block stays in that block.
    #[test]
    fn test_load_variable() {
        let mut builder = dr::Builder::new();

        let void_ty = builder.type_void();
        let float_ty = builder.type_float(32, None);
        let fn_ty = builder.type_function(void_ty, vec![void_ty]);

        // Global variable: declared before any function is begun.
        let global_var = builder.variable(float_ty, None, spirv::StorageClass::Input, None);

        builder
            .begin_function(void_ty, None, spirv::FunctionControl::NONE, fn_ty)
            .unwrap();
        builder.begin_block(None).unwrap();
        // Local variable: declared inside the function's only block.
        let local_var = builder.variable(float_ty, None, spirv::StorageClass::Function, None);
        builder.ret().unwrap();
        builder.end_function().unwrap();

        let module = builder.module();

        // Expected contents: the three types followed by the global variable.
        let globals = &module.types_global_values;
        assert_eq!(globals.len(), 4);
        let last_global = &globals[3];
        assert_eq!(last_global.class.opcode, spirv::Op::Variable);
        assert_eq!(last_global.result_id.unwrap(), global_var);

        assert_eq!(module.functions.len(), 1);
        let function = &module.functions[0];
        assert_eq!(function.blocks.len(), 1);
        let block = &function.blocks[0];
        assert!(block.instructions.len() > 1);
        let first = &block.instructions[0];
        assert_eq!(first.class.opcode, spirv::Op::Variable);
        assert_eq!(first.result_id.unwrap(), local_var);
    }

    /// An `OpUndef` created outside any function lands in
    /// `types_global_values`; one created inside a block stays in that block.
    #[test]
    fn test_load_undef() {
        let mut builder = dr::Builder::new();

        let void_ty = builder.type_void();
        let float_ty = builder.type_float(32, None);
        let fn_ty = builder.type_function(void_ty, vec![void_ty]);

        // Global undef: created before any function is begun.
        let global_undef = builder.undef(float_ty, None);

        builder
            .begin_function(void_ty, None, spirv::FunctionControl::NONE, fn_ty)
            .unwrap();
        builder.begin_block(None).unwrap();
        // Local undef: created inside the function's only block.
        let local_undef = builder.undef(float_ty, None);
        builder.ret().unwrap();
        builder.end_function().unwrap();

        let module = builder.module();

        // Expected contents: the three types followed by the global undef.
        let globals = &module.types_global_values;
        assert_eq!(globals.len(), 4);
        let last_global = &globals[3];
        assert_eq!(last_global.class.opcode, spirv::Op::Undef);
        assert_eq!(last_global.result_id.unwrap(), global_undef);

        assert_eq!(module.functions.len(), 1);
        let function = &module.functions[0];
        assert_eq!(function.blocks.len(), 1);
        let block = &function.blocks[0];
        assert!(block.instructions.len() > 1);
        let first = &block.instructions[0];
        assert_eq!(first.class.opcode, spirv::Op::Undef);
        assert_eq!(first.result_id.unwrap(), local_undef);
    }
}