1use crate::binary;
2use crate::dr;
3use crate::grammar;
4use crate::spirv;
5
6use crate::binary::{ParseAction, ParseResult};
7use std::{borrow::Cow, error, fmt};
8
/// Errors reported while building a data-representation module.
///
/// The human-readable text for each variant is produced by `Error::describe`
/// and exposed through the `Display` implementation.
#[derive(Debug)]
pub enum Error {
    /// `OpFunction` appeared while another function was still open.
    NestedFunction,
    /// The module ended while a function was still open (no `OpFunctionEnd`).
    UnclosedFunction,
    /// `OpFunctionEnd` appeared while no function was open.
    MismatchedFunctionEnd,
    /// `OpFunctionParameter` appeared outside of any function.
    DetachedFunctionParameter,
    /// `OpLabel` appeared outside of any function.
    DetachedBlock,
    /// `OpLabel` appeared while another block was still open.
    NestedBlock,
    /// A block was still open at `OpFunctionEnd` or at the end of the module.
    UnclosedBlock,
    /// A block terminator appeared while no block was open.
    MismatchedTerminator,
    /// An instruction that must live inside a block appeared outside one.
    /// Carries the offending instruction when it is known.
    DetachedInstruction(Option<dr::Instruction>),
    /// A list of instructions was empty.
    /// NOTE(review): not raised by the loader in this file; presumably used by other `dr` code.
    EmptyInstructionList,
    /// An `OpCapability` instruction had an unexpected operand.
    /// NOTE(review): not raised by the loader in this file.
    WrongOpCapabilityOperand,
    /// An `OpExtension` instruction had an unexpected operand.
    /// NOTE(review): not raised by the loader in this file.
    WrongOpExtensionOperand,
    /// An `OpExtInstImport` instruction had an unexpected operand.
    /// NOTE(review): not raised by the loader in this file.
    WrongOpExtInstImportOperand,
    /// An `OpMemoryModel` instruction had an unexpected operand.
    /// NOTE(review): not raised by the loader in this file.
    WrongOpMemoryModelOperand,
    /// An `OpName` instruction had an unexpected operand.
    /// NOTE(review): not raised by the loader in this file.
    WrongOpNameOperand,
    /// A requested function could not be found.
    /// NOTE(review): not raised by the loader in this file.
    FunctionNotFound,
    /// A requested block could not be found.
    /// NOTE(review): not raised by the loader in this file.
    BlockNotFound,
}
30
31impl Error {
32 fn describe(&self) -> Cow<'static, str> {
37 match self {
38 Error::NestedFunction => Cow::Borrowed("found nested function"),
39 Error::UnclosedFunction => Cow::Borrowed("found unclosed function"),
40 Error::MismatchedFunctionEnd => Cow::Borrowed("found mismatched OpFunctionEnd"),
41 Error::DetachedFunctionParameter => {
42 Cow::Borrowed("found function OpFunctionParameter not inside function")
43 }
44 Error::DetachedBlock => Cow::Borrowed("found block not inside function"),
45 Error::NestedBlock => Cow::Borrowed("found nested block"),
46 Error::UnclosedBlock => Cow::Borrowed("found block without terminator"),
47 Error::MismatchedTerminator => Cow::Borrowed("found mismatched terminator"),
48 Error::DetachedInstruction(Some(inst)) => Cow::Owned(format!(
49 "found instruction `{:?}` not inside block",
50 inst.class.opname
51 )),
52 Error::DetachedInstruction(None) => {
53 Cow::Borrowed("found unknown instruction not inside block")
54 }
55 Error::EmptyInstructionList => Cow::Borrowed("list of instructions is empty"),
56 Error::WrongOpCapabilityOperand => Cow::Borrowed("wrong OpCapability operand"),
57 Error::WrongOpExtensionOperand => Cow::Borrowed("wrong OpExtension operand"),
58 Error::WrongOpExtInstImportOperand => Cow::Borrowed("wrong OpExtInstImport operand"),
59 Error::WrongOpMemoryModelOperand => Cow::Borrowed("wrong OpMemoryModel operand"),
60 Error::WrongOpNameOperand => Cow::Borrowed("wrong OpName operand"),
61 Error::FunctionNotFound => Cow::Borrowed("can't find the function"),
62 Error::BlockNotFound => Cow::Borrowed("can't find the block"),
63 }
64 }
65}
66
// No variant wraps an underlying error, so the default `source()` (returning
// `None`) is sufficient; the message itself comes from the `Display` impl.
impl error::Error for Error {}
68
69impl fmt::Display for Error {
70 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71 write!(f, "{}", self.describe())
72 }
73}
74
/// Incrementally builds a `dr::Module` from the instructions handed to it by the
/// binary parser (it implements `binary::Consumer`).
///
/// `function` and `block` track the function and basic block that are currently
/// open while the instruction stream is consumed.
#[derive(Default)]
pub struct Loader {
    /// Module being populated; handed out by `Loader::module`.
    module: dr::Module,
    /// Function opened by `OpFunction` and not yet closed by `OpFunctionEnd`.
    function: Option<dr::Function>,
    /// Block opened by `OpLabel` and not yet closed by a block terminator.
    block: Option<dr::Block>,
}
88
89impl Loader {
90 pub fn new() -> Loader {
92 Loader {
93 module: dr::Module::new(),
94 function: None,
95 block: None,
96 }
97 }
98
99 pub fn module(self) -> dr::Module {
101 self.module
102 }
103}
104
// Early-returns `ParseAction::Error(Box::new(Error::$error))` from the enclosing
// function when `$condition` evaluates to true. `$error` must name a field-less
// variant of this module's `Error` enum, because it is used directly as a value.
macro_rules! if_ret_err {
    ($condition: expr, $error: ident) => {
        if $condition {
            return ParseAction::Error(Box::new(Error::$error));
        }
    };
}
113
114impl binary::Consumer for Loader {
115 fn initialize(&mut self) -> ParseAction {
116 ParseAction::Continue
117 }
118
119 fn finalize(&mut self) -> ParseAction {
120 if_ret_err!(self.block.is_some(), UnclosedBlock);
121 if_ret_err!(self.function.is_some(), UnclosedFunction);
122 ParseAction::Continue
123 }
124
125 fn consume_header(&mut self, header: dr::ModuleHeader) -> ParseAction {
126 self.module.header = Some(header);
127 ParseAction::Continue
128 }
129
130 fn consume_instruction(&mut self, inst: dr::Instruction) -> ParseAction {
131 let opcode = inst.class.opcode;
132 match opcode {
133 spirv::Op::Capability => self.module.capabilities.push(inst),
134 spirv::Op::Extension => self.module.extensions.push(inst),
135 spirv::Op::ExtInstImport => self.module.ext_inst_imports.push(inst),
136 spirv::Op::MemoryModel => self.module.memory_model = Some(inst),
137 spirv::Op::EntryPoint => self.module.entry_points.push(inst),
138 spirv::Op::ExecutionMode => self.module.execution_modes.push(inst),
139 spirv::Op::String
140 | spirv::Op::SourceExtension
141 | spirv::Op::Source
142 | spirv::Op::SourceContinued => self.module.debug_string_source.push(inst),
143 spirv::Op::Name | spirv::Op::MemberName => self.module.debug_names.push(inst),
144 spirv::Op::ModuleProcessed => self.module.debug_module_processed.push(inst),
145 opcode if grammar::reflect::is_location_debug(opcode) => {
146 match &mut self.block {
147 Some(block) => block.instructions.push(inst),
148 None => self.module.types_global_values.push(inst),
151 }
152 }
153 opcode if grammar::reflect::is_annotation(opcode) => self.module.annotations.push(inst),
154 opcode
155 if grammar::reflect::is_type(opcode) || grammar::reflect::is_constant(opcode) =>
156 {
157 self.module.types_global_values.push(inst)
158 }
159 spirv::Op::Variable if self.function.is_none() => {
160 self.module.types_global_values.push(inst)
161 }
162 spirv::Op::Undef if self.function.is_none() => {
163 self.module.types_global_values.push(inst)
164 }
165 spirv::Op::Function => {
166 if_ret_err!(self.function.is_some(), NestedFunction);
167 let mut f = dr::Function::new();
168 f.def = Some(inst);
169 self.function = Some(f)
170 }
171 spirv::Op::FunctionEnd => {
172 if_ret_err!(self.function.is_none(), MismatchedFunctionEnd);
173 if_ret_err!(self.block.is_some(), UnclosedBlock);
174 self.function.as_mut().unwrap().end = Some(inst);
175 self.module.functions.push(self.function.take().unwrap())
176 }
177 spirv::Op::FunctionParameter => {
178 if_ret_err!(self.function.is_none(), DetachedFunctionParameter);
179 self.function.as_mut().unwrap().parameters.push(inst);
180 }
181 spirv::Op::Label => {
182 if_ret_err!(self.function.is_none(), DetachedBlock);
183 if_ret_err!(self.block.is_some(), NestedBlock);
184 let mut block = dr::Block::new();
185 block.label = Some(inst);
186 self.block = Some(block)
187 }
188 opcode if grammar::reflect::is_block_terminator(opcode) => {
189 if_ret_err!(self.block.is_none(), MismatchedTerminator);
192 self.block.as_mut().unwrap().instructions.push(inst);
193 self.function
194 .as_mut()
195 .unwrap()
196 .blocks
197 .push(self.block.take().unwrap())
198 }
199 _ => {
200 if self.block.is_none() {
201 return ParseAction::Error(Box::new(Error::DetachedInstruction(Some(inst))));
202 }
203 self.block.as_mut().unwrap().instructions.push(inst)
204 }
205 }
206 ParseAction::Continue
207 }
208}
209
210pub fn load_bytes(binary: impl AsRef<[u8]>) -> ParseResult<dr::Module> {
243 let mut loader = Loader::new();
244 binary::parse_bytes(binary, &mut loader)?;
245 Ok(loader.module())
246}
247
248pub fn load_words(binary: impl AsRef<[u32]>) -> ParseResult<dr::Module> {
280 let mut loader = Loader::new();
281 binary::parse_words(binary, &mut loader)?;
282 Ok(loader.module())
283}
284
#[cfg(test)]
mod tests {
    use crate::binary::Assemble;
    use crate::dr;
    use crate::spirv;

    /// Assembles `module` to SPIR-V words and parses them back through the loader.
    ///
    /// This way the assertions exercise `Loader`'s instruction routing instead of
    /// only the in-memory layout produced by `dr::Builder`.
    fn round_trip(module: dr::Module) -> dr::Module {
        let words = module.assemble();
        super::load_words(words).expect("module produced by the builder should load")
    }

    #[test]
    fn test_load_variable() {
        let mut b = dr::Builder::new();

        let void = b.type_void();
        let float = b.type_float(32);
        let voidfvoid = b.type_function(void, vec![void]);

        let global = b.variable(float, None, spirv::StorageClass::Input, None);

        b.begin_function(void, None, spirv::FunctionControl::NONE, voidfvoid)
            .unwrap();
        b.begin_block(None).unwrap();
        let local = b.variable(float, None, spirv::StorageClass::Function, None);
        b.ret().unwrap();
        b.end_function().unwrap();

        let m = round_trip(b.module());

        // A module-scope OpVariable is loaded into the global section after the
        // three type declarations.
        assert_eq!(m.types_global_values.len(), 4);
        let inst = &m.types_global_values[3];
        assert_eq!(inst.class.opcode, spirv::Op::Variable);
        assert_eq!(inst.result_id.unwrap(), global);

        // A function-scope OpVariable stays inside its block.
        assert_eq!(m.functions.len(), 1);
        let f = &m.functions[0];
        assert_eq!(f.blocks.len(), 1);
        let bb = &f.blocks[0];
        assert!(bb.instructions.len() > 1);
        let inst = &bb.instructions[0];
        assert_eq!(inst.class.opcode, spirv::Op::Variable);
        assert_eq!(inst.result_id.unwrap(), local);
    }

    #[test]
    fn test_load_undef() {
        let mut b = dr::Builder::new();

        let void = b.type_void();
        let float = b.type_float(32);
        let voidfvoid = b.type_function(void, vec![void]);

        let global = b.undef(float, None);

        b.begin_function(void, None, spirv::FunctionControl::NONE, voidfvoid)
            .unwrap();
        b.begin_block(None).unwrap();
        let local = b.undef(float, None);
        b.ret().unwrap();
        b.end_function().unwrap();

        let m = round_trip(b.module());

        // A module-scope OpUndef is loaded into the global section after the
        // three type declarations.
        assert_eq!(m.types_global_values.len(), 4);
        let inst = &m.types_global_values[3];
        assert_eq!(inst.class.opcode, spirv::Op::Undef);
        assert_eq!(inst.result_id.unwrap(), global);

        // A function-scope OpUndef stays inside its block.
        assert_eq!(m.functions.len(), 1);
        let f = &m.functions[0];
        assert_eq!(f.blocks.len(), 1);
        let bb = &f.blocks[0];
        assert!(bb.instructions.len() > 1);
        let inst = &bb.instructions[0];
        assert_eq!(inst.class.opcode, spirv::Op::Undef);
        assert_eq!(inst.result_id.unwrap(), local);
    }
}