// rspirv/binary/assemble.rs

1use crate::dr;
2use std::convert::TryInto;
3
/// Types that can be encoded as a sequence of SPIR-V binary words.
pub trait Assemble {
    /// Writes the SPIR-V words encoding `self` onto the end of `result`.
    ///
    /// Implementations append rather than overwrite, so one buffer can be reused to
    /// assemble many objects without allocating a fresh vector for each of them.
    fn assemble_into(&self, result: &mut Vec<u32>);

    /// Returns a newly allocated vector holding the SPIR-V words encoding `self`.
    ///
    /// Kept for backwards compatibility: it forwards to [`Assemble::assemble_into`]
    /// using an initially empty buffer.
    fn assemble(&self) -> Vec<u32> {
        let mut words = Vec::new();
        self.assemble_into(&mut words);
        words
    }
}
17
18impl Assemble for dr::ModuleHeader {
19    fn assemble_into(&self, result: &mut Vec<u32>) {
20        result.extend([
21            self.magic_number,
22            self.version,
23            self.generator,
24            self.bound,
25            self.reserved_word,
26        ])
27    }
28}
29
/// Appends `s` as a SPIR-V literal string: its UTF-8 bytes packed four per word in
/// little-endian order, followed by a final word holding the leftover bytes padded
/// with zeros.
///
/// The final word is always emitted. When `s.len()` is a multiple of four it is all
/// zeros, so the encoded string always ends with a NUL byte.
fn assemble_str(s: &str, result: &mut Vec<u32>) {
    let mut full_words = s.as_bytes().chunks_exact(4);
    for bytes in full_words.by_ref() {
        let bytes: [u8; 4] = bytes
            .try_into()
            .expect("chunks_exact(4) only yields 4-byte chunks");
        result.push(u32::from_le_bytes(bytes));
    }
    let leftover = full_words.remainder();
    let mut padded = [0u8; 4];
    padded[..leftover.len()].copy_from_slice(leftover);
    result.push(u32::from_le_bytes(padded));
}
38
39impl Assemble for dr::Operand {
40    fn assemble_into(&self, result: &mut Vec<u32>) {
41        match *self {
42            Self::ImageOperands(v) => result.push(v.bits()),
43            Self::FPFastMathMode(v) => result.push(v.bits()),
44            Self::SelectionControl(v) => result.push(v.bits()),
45            Self::LoopControl(v) => result.push(v.bits()),
46            Self::FunctionControl(v) => result.push(v.bits()),
47            Self::MemorySemantics(v) => result.push(v.bits()),
48            Self::MemoryAccess(v) => result.push(v.bits()),
49            Self::KernelProfilingInfo(v) => result.push(v.bits()),
50            Self::CooperativeMatrixOperands(v) => result.push(v.bits()),
51            Self::SourceLanguage(v) => result.push(v as u32),
52            Self::ExecutionModel(v) => result.push(v as u32),
53            Self::AddressingModel(v) => result.push(v as u32),
54            Self::MemoryModel(v) => result.push(v as u32),
55            Self::ExecutionMode(v) => result.push(v as u32),
56            Self::StorageClass(v) => result.push(v as u32),
57            Self::Dim(v) => result.push(v as u32),
58            Self::SamplerAddressingMode(v) => result.push(v as u32),
59            Self::SamplerFilterMode(v) => result.push(v as u32),
60            Self::ImageFormat(v) => result.push(v as u32),
61            Self::ImageChannelOrder(v) => result.push(v as u32),
62            Self::ImageChannelDataType(v) => result.push(v as u32),
63            Self::FPRoundingMode(v) => result.push(v as u32),
64            Self::LinkageType(v) => result.push(v as u32),
65            Self::AccessQualifier(v) => result.push(v as u32),
66            Self::FunctionParameterAttribute(v) => result.push(v as u32),
67            Self::Decoration(v) => result.push(v as u32),
68            Self::BuiltIn(v) => result.push(v as u32),
69            Self::Scope(v) => result.push(v as u32),
70            Self::GroupOperation(v) => result.push(v as u32),
71            Self::KernelEnqueueFlags(v) => result.push(v as u32),
72            Self::Capability(v) => result.push(v as u32),
73            Self::IdMemorySemantics(v)
74            | Self::IdScope(v)
75            | Self::IdRef(v)
76            | Self::LiteralBit32(v)
77            | Self::LiteralExtInstInteger(v) => result.push(v),
78            Self::LiteralBit64(v) => result.extend([v as u32, (v >> 32) as u32]),
79            Self::LiteralSpecConstantOpInteger(v) => result.push(v as u32),
80            Self::LiteralString(ref v) => assemble_str(v, result),
81            Self::RayFlags(ref v) => result.push(v.bits()),
82            Self::RayQueryIntersection(v) => result.push(v as u32),
83            Self::RayQueryCommittedIntersectionType(v) => result.push(v as u32),
84            Self::RayQueryCandidateIntersectionType(v) => result.push(v as u32),
85            Self::FragmentShadingRate(v) => result.push(v.bits()),
86            Self::FPDenormMode(v) => result.push(v as u32),
87            Self::QuantizationModes(v) => result.push(v as u32),
88            Self::FPOperationMode(v) => result.push(v as u32),
89            Self::OverflowModes(v) => result.push(v as u32),
90            Self::PackedVectorFormat(v) => result.push(v as u32),
91            Self::HostAccessQualifier(v) => result.push(v as u32),
92            Self::CooperativeMatrixLayout(v) => result.push(v as u32),
93            Self::CooperativeMatrixUse(v) => result.push(v as u32),
94            Self::InitializationModeQualifier(v) => result.push(v as u32),
95            Self::LoadCacheControl(v) => result.push(v as u32),
96            Self::StoreCacheControl(v) => result.push(v as u32),
97        }
98    }
99}
100
101impl Assemble for dr::Instruction {
102    fn assemble_into(&self, result: &mut Vec<u32>) {
103        let start = result.len();
104        result.push(self.class.opcode as u32);
105        if let Some(r) = self.result_type {
106            result.push(r);
107        }
108        if let Some(r) = self.result_id {
109            result.push(r);
110        }
111        for operand in &self.operands {
112            operand.assemble_into(result);
113        }
114        let end = result.len() - start;
115        result[start] |= (end as u32) << 16;
116    }
117}
118
119impl Assemble for dr::Block {
120    fn assemble_into(&self, result: &mut Vec<u32>) {
121        if let Some(ref l) = self.label {
122            l.assemble_into(result);
123        }
124        for inst in &self.instructions {
125            inst.assemble_into(result);
126        }
127    }
128}
129
130impl Assemble for dr::Function {
131    fn assemble_into(&self, result: &mut Vec<u32>) {
132        if let Some(ref d) = self.def {
133            d.assemble_into(result);
134        }
135        for param in &self.parameters {
136            param.assemble_into(result);
137        }
138        for bb in &self.blocks {
139            bb.assemble_into(result);
140        }
141        if let Some(ref e) = self.end {
142            e.assemble_into(result);
143        }
144    }
145}
146
147impl Assemble for dr::Module {
148    fn assemble_into(&self, result: &mut Vec<u32>) {
149        if let Some(ref h) = self.header {
150            h.assemble_into(result);
151        }
152
153        for inst in self.global_inst_iter() {
154            inst.assemble_into(result);
155        }
156
157        for f in &self.functions {
158            f.assemble_into(result);
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use crate::dr;
    use crate::spirv;

    use super::assemble_str;
    use crate::binary::Assemble;

    #[test]
    fn test_assemble_str() {
        fn assemble_str_helper(s: &str) -> Vec<u32> {
            let mut v = vec![];
            assemble_str(s, &mut v);
            v
        }
        assert_eq!(vec![0u32], assemble_str_helper(""));
        assert_eq!(
            vec![u32::from_le_bytes(*b"h\0\0\0")],
            assemble_str_helper("h")
        );
        assert_eq!(
            vec![u32::from_le_bytes(*b"hell"), 0u32],
            assemble_str_helper("hell")
        );
        assert_eq!(
            vec![
                u32::from_le_bytes(*b"hell"),
                u32::from_le_bytes(*b"o\0\0\0")
            ],
            assemble_str_helper("hello")
        );
    }

    #[test]
    fn test_assemble_operand_bitmask() {
        let v = spirv::FunctionControl::DONT_INLINE;
        assert_eq!(vec![v.bits()], dr::Operand::FunctionControl(v).assemble());
        let v = spirv::FunctionControl::PURE;
        assert_eq!(vec![v.bits()], dr::Operand::FunctionControl(v).assemble());
        let v = spirv::FunctionControl::CONST;
        assert_eq!(vec![v.bits()], dr::Operand::FunctionControl(v).assemble());
        let v = spirv::FunctionControl::DONT_INLINE | spirv::FunctionControl::CONST;
        assert_eq!(vec![v.bits()], dr::Operand::FunctionControl(v).assemble());
        let v = spirv::FunctionControl::DONT_INLINE
            | spirv::FunctionControl::PURE
            | spirv::FunctionControl::CONST;
        assert_eq!(vec![v.bits()], dr::Operand::FunctionControl(v).assemble());
    }

    #[test]
    fn test_assemble_operand_enum() {
        assert_eq!(
            vec![spirv::BuiltIn::Position as u32],
            dr::Operand::BuiltIn(spirv::BuiltIn::Position).assemble()
        );
        assert_eq!(
            vec![spirv::BuiltIn::PointSize as u32],
            dr::Operand::BuiltIn(spirv::BuiltIn::PointSize).assemble()
        );
        assert_eq!(
            vec![spirv::BuiltIn::InstanceId as u32],
            dr::Operand::BuiltIn(spirv::BuiltIn::InstanceId).assemble()
        );
    }

    // 64-bit literals are split into two words, low-order word first.
    #[test]
    fn test_assemble_operand_literal_bit64() {
        assert_eq!(
            vec![0x89ab_cdefu32, 0x0123_4567],
            dr::Operand::LiteralBit64(0x0123_4567_89ab_cdef).assemble()
        );
    }

    // String operands go through the same padding/NUL-termination as `assemble_str`.
    #[test]
    fn test_assemble_operand_literal_string() {
        assert_eq!(
            vec![u32::from_le_bytes(*b"main"), 0u32],
            dr::Operand::LiteralString(String::from("main")).assemble()
        );
    }

    // `assemble_into` must append to, not overwrite, the caller's buffer.
    #[test]
    fn test_assemble_into_appends() {
        let mut words = vec![0xdead_beefu32];
        dr::Operand::LiteralBit32(7).assemble_into(&mut words);
        dr::Operand::IdRef(42).assemble_into(&mut words);
        assert_eq!(vec![0xdead_beefu32, 7, 42], words);
    }

    /// First word of an encoded instruction: word count in the high 16 bits,
    /// opcode in the low 16 bits.
    fn wc_op(wc: u32, op: spirv::Op) -> u32 {
        (wc << 16) | op as u32
    }

    // No operands
    #[test]
    fn test_assemble_inst_nop() {
        assert_eq!(
            vec![wc_op(1, spirv::Op::Nop)],
            dr::Instruction::new(spirv::Op::Nop, None, None, vec![]).assemble()
        );
    }

    // No result type and result id
    #[test]
    fn test_assemble_inst_memory_model() {
        let operands = vec![
            dr::Operand::AddressingModel(spirv::AddressingModel::Physical32),
            dr::Operand::MemoryModel(spirv::MemoryModel::OpenCL),
        ];
        assert_eq!(
            vec![
                wc_op(3, spirv::Op::MemoryModel),
                spirv::AddressingModel::Physical32 as u32,
                spirv::MemoryModel::OpenCL as u32
            ],
            dr::Instruction::new(spirv::Op::MemoryModel, None, None, operands).assemble()
        );
    }

    // No result type, having result id
    #[test]
    fn test_assemble_inst_type_int() {
        let operands = vec![dr::Operand::LiteralBit32(32), dr::Operand::LiteralBit32(1)];
        assert_eq!(
            vec![wc_op(4, spirv::Op::TypeInt), 42, 32, 1],
            dr::Instruction::new(spirv::Op::TypeInt, None, Some(42), operands).assemble()
        );
    }

    // Having result type and id
    #[test]
    fn test_assemble_inst_iadd() {
        let operands = vec![dr::Operand::IdRef(0xef), dr::Operand::IdRef(0x78)];
        assert_eq!(
            vec![wc_op(5, spirv::Op::IAdd), 0xab, 0xcd, 0xef, 0x78],
            dr::Instruction::new(spirv::Op::IAdd, Some(0xab), Some(0xcd), operands).assemble()
        );
    }

    // Multi-word string operands count towards the instruction's word count.
    #[test]
    fn test_assemble_inst_string() {
        let operands = vec![dr::Operand::LiteralString(String::from("hello"))];
        assert_eq!(
            vec![
                wc_op(4, spirv::Op::String),
                7,
                u32::from_le_bytes(*b"hell"),
                u32::from_le_bytes(*b"o\0\0\0")
            ],
            dr::Instruction::new(spirv::Op::String, None, Some(7), operands).assemble()
        );
    }

    // A 64-bit literal adds two words to the instruction, low-order word first.
    #[test]
    fn test_assemble_inst_constant_64bit() {
        let operands = vec![dr::Operand::LiteralBit64(0x0000_0001_0000_0002)];
        assert_eq!(
            vec![wc_op(5, spirv::Op::Constant), 3, 4, 2, 1],
            dr::Instruction::new(spirv::Op::Constant, Some(3), Some(4), operands).assemble()
        );
    }

    #[test]
    fn test_assemble_function_void() {
        let mut b = dr::Builder::new();
        b.memory_model(spirv::AddressingModel::Logical, spirv::MemoryModel::Simple);
        let void = b.type_void();
        let voidfvoid = b.type_function(void, vec![void]);
        b.begin_function(void, None, spirv::FunctionControl::CONST, voidfvoid)
            .unwrap();
        b.begin_block(None).unwrap();
        b.ret().unwrap();
        b.end_function().unwrap();

        assert_eq!(
            vec![
                spirv::MAGIC_NUMBER,
                (u32::from(spirv::MAJOR_VERSION) << 16) | (u32::from(spirv::MINOR_VERSION) << 8),
                0x000f0000,
                5,
                0,
                wc_op(3, spirv::Op::MemoryModel),
                spirv::AddressingModel::Logical as u32,
                spirv::MemoryModel::Simple as u32,
                wc_op(2, spirv::Op::TypeVoid),
                1,
                wc_op(4, spirv::Op::TypeFunction),
                2,
                1,
                1,
                wc_op(5, spirv::Op::Function),
                1,
                3,
                spirv::FunctionControl::CONST.bits(),
                2,
                wc_op(2, spirv::Op::Label),
                4,
                wc_op(1, spirv::Op::Return),
                wc_op(1, spirv::Op::FunctionEnd)
            ],
            b.module().assemble()
        );
    }

    #[test]
    fn test_assemble_function_parameters() {
        let mut b = dr::Builder::new();
        b.memory_model(spirv::AddressingModel::Logical, spirv::MemoryModel::Simple);
        let float = b.type_float(32);
        let ptr = b.type_pointer(None, spirv::StorageClass::Function, float);
        let fff = b.type_function(float, vec![float, float]);
        b.begin_function(float, None, spirv::FunctionControl::CONST, fff)
            .unwrap();
        let param1 = b.function_parameter(ptr).unwrap();
        let param2 = b.function_parameter(ptr).unwrap();
        b.begin_block(None).unwrap();
        let v1 = b.load(float, None, param1, None, vec![]).unwrap();
        let v2 = b.load(float, None, param2, None, vec![]).unwrap();
        let v = b.f_add(float, None, v1, v2).unwrap();
        b.ret_value(v).unwrap();
        b.end_function().unwrap();

        assert_eq!(
            vec![
                // Header
                spirv::MAGIC_NUMBER,
                (u32::from(spirv::MAJOR_VERSION) << 16) | (u32::from(spirv::MINOR_VERSION) << 8),
                0x000f0000,
                11, // bound
                0,
                // Instructions
                wc_op(3, spirv::Op::MemoryModel),
                spirv::AddressingModel::Logical as u32,
                spirv::MemoryModel::Simple as u32,
                wc_op(3, spirv::Op::TypeFloat),
                1,  // result id
                32, // bitwidth
                wc_op(4, spirv::Op::TypePointer),
                2, // result id
                spirv::StorageClass::Function as u32,
                1, // float result id
                wc_op(5, spirv::Op::TypeFunction),
                3, // result id
                1, // result type
                1, // parameter type
                1, // parameter type
                wc_op(5, spirv::Op::Function),
                1, // result type id
                4, // result id
                spirv::FunctionControl::CONST.bits(),
                3, // function type id
                wc_op(3, spirv::Op::FunctionParameter),
                2, // result type id
                5, // result id
                wc_op(3, spirv::Op::FunctionParameter),
                2, // result type id
                6, // result id
                wc_op(2, spirv::Op::Label),
                7, // result id
                wc_op(4, spirv::Op::Load),
                1, // result type id
                8, // result id
                5, // parameter id
                wc_op(4, spirv::Op::Load),
                1, // result type id
                9, // result id
                6, // parameter id
                wc_op(5, spirv::Op::FAdd),
                1,  // result type id
                10, // result id
                8,  // operand
                9,  // operand
                wc_op(2, spirv::Op::ReturnValue),
                10,
                wc_op(1, spirv::Op::FunctionEnd)
            ],
            b.module().assemble()
        );
    }
}