rspirv/lift/
mod.rs

//! Infrastructure for lifting the data representation (DR) into the structured
//! representation (SR).
3
4mod storage;
5
6use self::storage::LiftStorage;
7use crate::{
8    dr,
9    sr::{instructions, module, ops, storage::Token, Constant, StructMember, Type},
10};
11
12use std::{borrow::Borrow, mem};
13
14/// A structure that we associate an <id> with, containing
15/// both the operation token and the result type.
16struct OpInfo {
17    op: Token<ops::Op>,
18    ty: Option<Token<Type>>,
19}
20
21impl Borrow<Token<ops::Op>> for OpInfo {
22    fn borrow(&self) -> &Token<ops::Op> {
23        &self.op
24    }
25}
26
27pub struct LiftContext {
28    //current_block: Option<Token<module::Block>>,
29    types: LiftStorage<Type>,
30    constants: LiftStorage<Constant>,
31    blocks: LiftStorage<module::Block>,
32    ops: LiftStorage<ops::Op, OpInfo>,
33}
34
35include!("autogen_context.rs");
36
/// Error lifting a data representation of an operand into the structured
/// representation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperandError {
    /// Operand has a wrong type.
    WrongType,
    /// Operand is an integer value that corresponds to a specified enum,
    /// but the given integer is not known to have a mapping.
    WrongEnumValue,
    /// Operand is missing from the list.
    Missing,
}

impl std::fmt::Display for OperandError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            OperandError::WrongType => "operand has a wrong type",
            OperandError::WrongEnumValue => "operand is not a known value of the expected enum",
            OperandError::Missing => "operand is missing from the list",
        })
    }
}

impl std::error::Error for OperandError {}

/// Error lifting a data representation of an instruction.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InstructionError {
    /// Instruction has a wrong opcode.
    WrongOpcode,
    /// Instruction is missing a result `<id>` or type.
    MissingResult,
    /// One of the operands can not be lifted.
    Operand(OperandError),
}

impl std::fmt::Display for InstructionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped operand error is reachable through `source()`, so it is
        // deliberately not repeated in this message.
        f.write_str(match self {
            InstructionError::WrongOpcode => "instruction has a wrong opcode",
            InstructionError::MissingResult => "instruction is missing a result <id> or type",
            InstructionError::Operand(_) => "one of the operands can not be lifted",
        })
    }
}

impl std::error::Error for InstructionError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            InstructionError::Operand(error) => Some(error),
            InstructionError::WrongOpcode | InstructionError::MissingResult => None,
        }
    }
}

impl From<OperandError> for InstructionError {
    fn from(error: OperandError) -> Self {
        InstructionError::Operand(error)
    }
}

/// Error that may occur during the conversion from the data representation
/// of a module into a structured representation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConversionError {
    /// The module has no header or no memory model.
    MissingHeader,
    /// A function has no definition instruction.
    MissingFunction,
    /// A function type is missing.
    MissingFunctionType,
    /// A label is missing.
    MissingLabel,
    /// A block has no instructions, and therefore no terminator.
    MissingTerminator,
    /// One of the instructions can not be lifted.
    Instruction(InstructionError),
}

impl std::fmt::Display for ConversionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped instruction error is reachable through `source()`.
        f.write_str(match self {
            ConversionError::MissingHeader => "module is missing a header or memory model",
            ConversionError::MissingFunction => "function is missing its definition",
            ConversionError::MissingFunctionType => "function is missing its type",
            ConversionError::MissingLabel => "block is missing its label",
            ConversionError::MissingTerminator => "block is missing its terminator",
            ConversionError::Instruction(_) => "one of the instructions can not be lifted",
        })
    }
}

impl std::error::Error for ConversionError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ConversionError::Instruction(error) => Some(error),
            ConversionError::MissingHeader
            | ConversionError::MissingFunction
            | ConversionError::MissingFunctionType
            | ConversionError::MissingLabel
            | ConversionError::MissingTerminator => None,
        }
    }
}

impl From<InstructionError> for ConversionError {
    fn from(error: InstructionError) -> Self {
        ConversionError::Instruction(error)
    }
}
84
85impl LiftContext {
86    /// Convert a module from the data representation into structured representation.
87    pub fn convert(module: &dr::Module) -> Result<module::Module, ConversionError> {
88        let mut context = LiftContext {
89            types: LiftStorage::new(),
90            constants: LiftStorage::new(),
91            blocks: LiftStorage::new(),
92            ops: LiftStorage::new(),
93        };
94        let mut functions = Vec::new();
95        let entry_points = Vec::new();
96
97        for inst in module.types_global_values.iter() {
98            match context.lift_type(inst) {
99                Ok(value) => {
100                    if let Some(id) = inst.result_id {
101                        context.types.append_id(id, value);
102                    }
103                    continue;
104                }
105                Err(InstructionError::WrongOpcode) => {}
106                Err(e) => panic!("Type lift error: {:?}", e),
107            }
108            match context.lift_constant(inst) {
109                Ok(value) => {
110                    if let Some(id) = inst.result_id {
111                        context.constants.append_id(id, value);
112                    }
113                    continue;
114                }
115                Err(InstructionError::WrongOpcode) => {}
116                Err(e) => panic!("Constant lift error: {:?}", e),
117            }
118        }
119
120        for fun in module.functions.iter() {
121            let def =
122                context.lift_function(fun.def.as_ref().ok_or(ConversionError::MissingFunction)?)?;
123            //TODO: lift function type instruction
124
125            for block in fun.blocks.iter() {
126                let mut arguments = Vec::new();
127                for inst in &block.instructions {
128                    match inst.class.opcode {
129                        spirv::Op::Line => {} // skip line decorations
130                        spirv::Op::Phi => {
131                            let ty = context.types.lookup_token(
132                                inst.result_type.ok_or(InstructionError::MissingResult)?,
133                            );
134                            arguments.push(ty);
135
136                            // Sanity-check if all source variables are of the same type
137                            for op in inst.operands.iter().step_by(2) {
138                                match op {
139                                    dr::Operand::IdRef(id) => {
140                                        if let Some((_, info)) = context.ops.lookup_safe(*id) {
141                                            assert_eq!(Some(ty), info.ty);
142                                        } else {
143                                            // let (v, info) =
144                                            //     context.constants.lookup_safe(*id).unwrap();
145                                            // TODO: Can't convert Constant back to their lowered type yet!
146                                            // assert_eq!(Some(ty), info.ty.as_ref());
147                                        }
148                                    }
149                                    _ => {
150                                        return Err(ConversionError::Instruction(
151                                            InstructionError::Operand(OperandError::Missing),
152                                        ))
153                                    }
154                                };
155                            }
156                        }
157                        _ => {
158                            if let Some(id) = inst.result_id {
159                                let op = context.lift_op(inst)?;
160                                let types = &context.types;
161                                let (token, entry) = context.ops.append(id, op);
162                                entry.insert(OpInfo {
163                                    op: token,
164                                    ty: inst.result_type.map(|ty| *types.lookup(ty).1),
165                                });
166                            }
167                        }
168                    }
169                }
170
171                let terminator = context.lift_terminator(
172                    block
173                        .instructions
174                        .last()
175                        .ok_or(ConversionError::MissingTerminator)?,
176                )?;
177
178                context.blocks.append_id(
179                    block.label.as_ref().unwrap().result_id.unwrap(),
180                    module::Block {
181                        arguments,
182                        ops: Vec::new(),
183                        terminator,
184                    },
185                );
186            }
187
188            let start_label = fun.blocks[0].label.as_ref().unwrap().result_id.unwrap();
189            let start_block = context.blocks.lookup_token(start_label);
190            let blocks = mem::replace(&mut context.blocks, LiftStorage::new()).unwrap();
191            let fun_ret = fun
192                .def
193                .as_ref()
194                .and_then(|d| d.result_type)
195                .expect("functions must have a result type");
196
197            functions.push(module::Function {
198                control: def.function_control,
199                result: context.types.lookup_token(fun_ret),
200                parameters: Vec::new(),
201                blocks,
202                start_block,
203            });
204        }
205
206        Ok(module::Module {
207            version: match module.header {
208                Some(ref header) => header.version,
209                None => return Err(ConversionError::MissingHeader),
210            },
211            capabilities: module
212                .capabilities
213                .iter()
214                .map(|cap| context.lift_capability(cap).map(|cap| cap.capability))
215                .collect::<Result<_, InstructionError>>()?,
216            extensions: Vec::new(),
217            ext_inst_imports: Vec::new(),
218            memory_model: match module.memory_model {
219                Some(ref mm) => context.lift_memory_model(mm)?,
220                None => return Err(ConversionError::MissingHeader),
221            },
222            entry_points,
223            types: context.types.unwrap(),
224            constants: context.constants.unwrap(),
225            ops: context.ops.unwrap(),
226            functions,
227        })
228    }
229
230    fn lookup_jump(&self, destination: spirv::Word) -> module::Jump {
231        let (_, block) = self.blocks.lookup(destination);
232        module::Jump {
233            block: *block,
234            arguments: Vec::new(), //TODO
235        }
236    }
237
238    fn lift_constant(&self, inst: &dr::Instruction) -> Result<Constant, InstructionError> {
239        match inst.class.opcode {
240            spirv::Op::ConstantTrue => Ok(Constant::Bool(true)),
241            spirv::Op::ConstantFalse => Ok(Constant::Bool(false)),
242            spirv::Op::Constant => {
243                match inst.result_type {
244                    Some(id) => {
245                        let oper = inst
246                            .operands
247                            .first()
248                            .ok_or(InstructionError::Operand(OperandError::Missing))?;
249                        let (value, width) = match *self.types.lookup(id).0 {
250                            Type::Int {
251                                signedness: 0,
252                                width,
253                            } => match *oper {
254                                dr::Operand::LiteralBit32(v) => (Constant::UInt(v), width),
255                                _ => {
256                                    return Err(InstructionError::Operand(OperandError::WrongType))
257                                }
258                            },
259                            Type::Int { width, .. } => match *oper {
260                                dr::Operand::LiteralBit32(v) => (Constant::Int(v as i32), width),
261                                _ => {
262                                    return Err(InstructionError::Operand(OperandError::WrongType))
263                                }
264                            },
265                            Type::Float { width } => match *oper {
266                                dr::Operand::LiteralBit32(v) => {
267                                    (Constant::Float(f32::from_bits(v)), width)
268                                }
269                                _ => {
270                                    return Err(InstructionError::Operand(OperandError::WrongType))
271                                }
272                            },
273                            _ => return Err(InstructionError::MissingResult),
274                        };
275                        if width > 32 {
276                            //log::warn!("Constant <id> {} doesn't fit in 32 bits", id);
277                        }
278                        Ok(value)
279                    }
280                    _ => Err(InstructionError::MissingResult),
281                }
282            }
283            spirv::Op::ConstantComposite => {
284                let mut vec = Vec::with_capacity(inst.operands.len());
285                for oper in inst.operands.iter() {
286                    let token = match *oper {
287                        dr::Operand::IdRef(v) => self.constants.lookup_token(v),
288                        _ => return Err(InstructionError::Operand(OperandError::WrongType)),
289                    };
290                    vec.push(token);
291                }
292                Ok(Constant::Composite(vec))
293            }
294            spirv::Op::ConstantSampler => {
295                if inst.operands.len() < 3 {
296                    return Err(InstructionError::Operand(OperandError::Missing));
297                }
298                Ok(Constant::Sampler {
299                    addressing_mode: match inst.operands[0] {
300                        dr::Operand::SamplerAddressingMode(v) => v,
301                        _ => return Err(InstructionError::Operand(OperandError::WrongType)),
302                    },
303                    normalized: match inst.operands[1] {
304                        dr::Operand::LiteralBit32(v) => v != 0,
305                        _ => return Err(InstructionError::Operand(OperandError::WrongType)),
306                    },
307                    filter_mode: match inst.operands[2] {
308                        dr::Operand::SamplerFilterMode(v) => v,
309                        _ => return Err(InstructionError::Operand(OperandError::WrongType)),
310                    },
311                })
312            }
313            spirv::Op::ConstantNull => Ok(Constant::Null),
314            spirv::Op::ConstantCompositeContinuedINTEL
315            | spirv::Op::SpecConstantCompositeContinuedINTEL => todo!(),
316            _ => Err(InstructionError::WrongOpcode),
317        }
318    }
319}