rustc_codegen_spirv/
symbols.rs

1use crate::attr::{Entry, ExecutionModeExtra, IntrinsicType, SpecConstant, SpirvAttribute};
2use crate::builder::libm_intrinsics;
3use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
4use rustc_ast::ast::{LitIntType, LitKind, MetaItemInner, MetaItemLit};
5use rustc_data_structures::fx::FxHashMap;
6use rustc_hir::Attribute;
7use rustc_span::Span;
8use rustc_span::symbol::{Ident, Symbol};
9use std::rc::Rc;
10
11/// Various places in the codebase (mostly attribute parsing) need to compare rustc Symbols to particular keywords.
12/// Symbols are interned, as in, they don't actually store the string itself inside them, but rather an index into a
13/// global table of strings. Then, whenever a new Symbol is created, the global table is checked to see if the string
14/// already exists, deduplicating it if so. This makes things like comparison and cloning really cheap. So, this struct
15/// is to allocate all our keywords up front and intern them all, so we can do comparisons really easily and fast.
16pub struct Symbols {
17    pub discriminant: Symbol,
18    pub rust_gpu: Symbol,
19    pub spirv: Symbol,
20    pub libm: Symbol,
21    pub entry_point_name: Symbol,
22    pub spv_khr_vulkan_memory_model: Symbol,
23
24    descriptor_set: Symbol,
25    binding: Symbol,
26    input_attachment_index: Symbol,
27
28    spec_constant: Symbol,
29    id: Symbol,
30    default: Symbol,
31
32    attributes: FxHashMap<Symbol, SpirvAttribute>,
33    execution_modes: FxHashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
34    pub libm_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
35}
36
37const BUILTINS: &[(&str, BuiltIn)] = {
38    use BuiltIn::*;
39    &[
40        ("position", Position),
41        ("point_size", PointSize),
42        ("clip_distance", ClipDistance),
43        ("cull_distance", CullDistance),
44        ("vertex_id", VertexId),
45        ("instance_id", InstanceId),
46        ("primitive_id", PrimitiveId),
47        ("invocation_id", InvocationId),
48        ("layer", Layer),
49        ("viewport_index", ViewportIndex),
50        ("tess_level_outer", TessLevelOuter),
51        ("tess_level_inner", TessLevelInner),
52        ("tess_coord", TessCoord),
53        ("patch_vertices", PatchVertices),
54        ("frag_coord", FragCoord),
55        ("point_coord", PointCoord),
56        ("front_facing", FrontFacing),
57        ("sample_id", SampleId),
58        ("sample_position", SamplePosition),
59        ("sample_mask", SampleMask),
60        ("frag_depth", FragDepth),
61        ("helper_invocation", HelperInvocation),
62        ("num_workgroups", NumWorkgroups),
63        // ("workgroup_size", WorkgroupSize), -- constant
64        ("workgroup_id", WorkgroupId),
65        ("local_invocation_id", LocalInvocationId),
66        ("global_invocation_id", GlobalInvocationId),
67        ("local_invocation_index", LocalInvocationIndex),
68        // ("work_dim", WorkDim), -- Kernel-only
69        // ("global_size", GlobalSize), -- Kernel-only
70        // ("enqueued_workgroup_size", EnqueuedWorkgroupSize), -- Kernel-only
71        // ("global_offset", GlobalOffset), -- Kernel-only
72        // ("global_linear_id", GlobalLinearId), -- Kernel-only
73        ("subgroup_size", SubgroupSize),
74        // ("subgroup_max_size", SubgroupMaxSize), -- Kernel-only
75        ("num_subgroups", NumSubgroups),
76        // ("num_enqueued_subgroups", NumEnqueuedSubgroups), -- Kernel-only
77        ("subgroup_id", SubgroupId),
78        ("subgroup_local_invocation_id", SubgroupLocalInvocationId),
79        ("vertex_index", VertexIndex),
80        ("instance_index", InstanceIndex),
81        ("subgroup_eq_mask", SubgroupEqMask),
82        ("subgroup_ge_mask", SubgroupGeMask),
83        ("subgroup_gt_mask", SubgroupGtMask),
84        ("subgroup_le_mask", SubgroupLeMask),
85        ("subgroup_lt_mask", SubgroupLtMask),
86        ("base_vertex", BaseVertex),
87        ("base_instance", BaseInstance),
88        ("draw_index", DrawIndex),
89        ("device_index", DeviceIndex),
90        ("view_index", ViewIndex),
91        ("bary_coord_no_persp_amd", BaryCoordNoPerspAMD),
92        (
93            "bary_coord_no_persp_centroid_amd",
94            BaryCoordNoPerspCentroidAMD,
95        ),
96        ("bary_coord_no_persp_sample_amd", BaryCoordNoPerspSampleAMD),
97        ("bary_coord_smooth_amd", BaryCoordSmoothAMD),
98        ("bary_coord_smooth_centroid_amd", BaryCoordSmoothCentroidAMD),
99        ("bary_coord_smooth_sample_amd", BaryCoordSmoothSampleAMD),
100        ("bary_coord_pull_model_amd", BaryCoordPullModelAMD),
101        ("frag_stencil_ref_ext", FragStencilRefEXT),
102        ("viewport_mask_nv", ViewportMaskNV),
103        ("secondary_position_nv", SecondaryPositionNV),
104        ("secondary_viewport_mask_nv", SecondaryViewportMaskNV),
105        ("position_per_view_nv", PositionPerViewNV),
106        ("viewport_mask_per_view_nv", ViewportMaskPerViewNV),
107        ("fully_covered_ext", FullyCoveredEXT),
108        ("task_count_nv", TaskCountNV),
109        ("primitive_count_nv", PrimitiveCountNV),
110        ("primitive_indices_nv", PrimitiveIndicesNV),
111        ("clip_distance_per_view_nv", ClipDistancePerViewNV),
112        ("cull_distance_per_view_nv", CullDistancePerViewNV),
113        ("layer_per_view_nv", LayerPerViewNV),
114        ("mesh_view_count_nv", MeshViewCountNV),
115        ("mesh_view_indices_nv", MeshViewIndicesNV),
116        ("bary_coord_nv", BuiltIn::BaryCoordNV),
117        ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
118        ("bary_coord", BaryCoordKHR),
119        ("bary_coord_no_persp", BaryCoordNoPerspKHR),
120        ("primitive_point_indices_ext", PrimitivePointIndicesEXT),
121        ("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
122        (
123            "primitive_triangle_indices_ext",
124            PrimitiveTriangleIndicesEXT,
125        ),
126        ("cull_primitive_ext", CullPrimitiveEXT),
127        ("frag_size_ext", FragSizeEXT),
128        ("frag_invocation_count_ext", FragInvocationCountEXT),
129        ("launch_id", BuiltIn::LaunchIdKHR),
130        ("launch_size", BuiltIn::LaunchSizeKHR),
131        ("instance_custom_index", BuiltIn::InstanceCustomIndexKHR),
132        ("ray_geometry_index", BuiltIn::RayGeometryIndexKHR),
133        ("world_ray_origin", BuiltIn::WorldRayOriginKHR),
134        ("world_ray_direction", BuiltIn::WorldRayDirectionKHR),
135        ("object_ray_origin", BuiltIn::ObjectRayOriginKHR),
136        ("object_ray_direction", BuiltIn::ObjectRayDirectionKHR),
137        ("ray_tmin", BuiltIn::RayTminKHR),
138        ("ray_tmax", BuiltIn::RayTmaxKHR),
139        ("object_to_world", BuiltIn::ObjectToWorldKHR),
140        ("world_to_object", BuiltIn::WorldToObjectKHR),
141        ("hit_kind", BuiltIn::HitKindKHR),
142        ("incoming_ray_flags", BuiltIn::IncomingRayFlagsKHR),
143        ("warps_per_sm_nv", WarpsPerSMNV),
144        ("sm_count_nv", SMCountNV),
145        ("warp_id_nv", WarpIDNV),
146        ("SMIDNV", SMIDNV),
147    ]
148};
149
150const STORAGE_CLASSES: &[(&str, StorageClass)] = {
151    use StorageClass::*;
152    &[
153        ("uniform_constant", UniformConstant),
154        ("input", Input),
155        ("uniform", Uniform),
156        ("output", Output),
157        ("workgroup", Workgroup),
158        ("cross_workgroup", CrossWorkgroup),
159        ("private", Private),
160        ("function", Function),
161        ("generic", Generic),
162        ("push_constant", PushConstant),
163        ("atomic_counter", AtomicCounter),
164        ("image", Image),
165        ("storage_buffer", StorageBuffer),
166        ("callable_data", StorageClass::CallableDataKHR),
167        (
168            "incoming_callable_data",
169            StorageClass::IncomingCallableDataKHR,
170        ),
171        ("ray_payload", StorageClass::RayPayloadKHR),
172        ("hit_attribute", StorageClass::HitAttributeKHR),
173        ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
174        ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
175        ("physical_storage_buffer", PhysicalStorageBuffer),
176        ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
177    ]
178};
179
180const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
181    use ExecutionModel::*;
182    &[
183        ("vertex", Vertex),
184        ("tessellation_control", TessellationControl),
185        ("tessellation_evaluation", TessellationEvaluation),
186        ("geometry", Geometry),
187        ("fragment", Fragment),
188        ("compute", GLCompute),
189        ("task_nv", TaskNV),
190        ("mesh_nv", MeshNV),
191        ("task_ext", TaskEXT),
192        ("mesh_ext", MeshEXT),
193        ("ray_generation", ExecutionModel::RayGenerationKHR),
194        ("intersection", ExecutionModel::IntersectionKHR),
195        ("any_hit", ExecutionModel::AnyHitKHR),
196        ("closest_hit", ExecutionModel::ClosestHitKHR),
197        ("miss", ExecutionModel::MissKHR),
198        ("callable", ExecutionModel::CallableKHR),
199    ]
200};
201
/// How an execution-mode argument inside an entry-point attribute supplies its operand(s).
#[derive(Clone, Copy, Debug)]
enum ExecutionModeExtraDim {
    /// No operand: the argument is a bare identifier, e.g. `origin_lower_left`.
    None,
    /// A single integer operand, written `name = N`.
    Value,
    /// First component of a multi-component mode, written `name_x = N`.
    X,
    /// Second component of a multi-component mode, written `name_y = N`.
    Y,
    /// Third component of a multi-component mode, written `name_z = N`.
    Z,
    /// All components at once as a list, e.g. `threads(x, y, z)`.
    Tuple,
}
211
212const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
213    use ExecutionMode::*;
214    use ExecutionModeExtraDim::*;
215    &[
216        ("invocations", Invocations, Value),
217        ("spacing_equal", SpacingEqual, None),
218        ("spacing_fraction_even", SpacingFractionalEven, None),
219        ("spacing_fraction_odd", SpacingFractionalOdd, None),
220        ("vertex_order_cw", VertexOrderCw, None),
221        ("vertex_order_ccw", VertexOrderCcw, None),
222        ("pixel_center_integer", PixelCenterInteger, None),
223        ("origin_upper_left", OriginUpperLeft, None),
224        ("origin_lower_left", OriginLowerLeft, None),
225        ("early_fragment_tests", EarlyFragmentTests, None),
226        ("point_mode", PointMode, None),
227        ("xfb", Xfb, None),
228        ("depth_replacing", DepthReplacing, None),
229        ("depth_greater", DepthGreater, None),
230        ("depth_less", DepthLess, None),
231        ("depth_unchanged", DepthUnchanged, None),
232        ("threads", LocalSize, Tuple),
233        ("local_size_hint_x", LocalSizeHint, X),
234        ("local_size_hint_y", LocalSizeHint, Y),
235        ("local_size_hint_z", LocalSizeHint, Z),
236        ("input_points", InputPoints, None),
237        ("input_lines", InputLines, None),
238        ("input_lines_adjacency", InputLinesAdjacency, None),
239        ("triangles", Triangles, None),
240        ("input_triangles_adjacency", InputTrianglesAdjacency, None),
241        ("quads", Quads, None),
242        ("isolines", Isolines, None),
243        ("output_vertices", OutputVertices, Value),
244        ("output_points", OutputPoints, None),
245        ("output_line_strip", OutputLineStrip, None),
246        ("output_triangle_strip", OutputTriangleStrip, None),
247        ("vec_type_hint", VecTypeHint, Value),
248        ("contraction_off", ContractionOff, None),
249        ("initializer", Initializer, None),
250        ("finalizer", Finalizer, None),
251        ("subgroup_size", SubgroupSize, Value),
252        ("subgroups_per_workgroup", SubgroupsPerWorkgroup, Value),
253        ("subgroups_per_workgroup_id", SubgroupsPerWorkgroupId, Value),
254        ("local_size_id_x", LocalSizeId, X),
255        ("local_size_id_y", LocalSizeId, Y),
256        ("local_size_id_z", LocalSizeId, Z),
257        ("local_size_hint_id", LocalSizeHintId, Value),
258        ("post_depth_coverage", PostDepthCoverage, None),
259        ("denorm_preserve", DenormPreserve, None),
260        ("denorm_flush_to_zero", DenormFlushToZero, Value),
261        (
262            "signed_zero_inf_nan_preserve",
263            SignedZeroInfNanPreserve,
264            Value,
265        ),
266        ("rounding_mode_rte", RoundingModeRTE, Value),
267        ("rounding_mode_rtz", RoundingModeRTZ, Value),
268        ("stencil_ref_replacing_ext", StencilRefReplacingEXT, None),
269        ("output_lines_nv", OutputLinesNV, None),
270        ("output_primitives_nv", OutputPrimitivesNV, Value),
271        ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
272        ("output_triangles_nv", OutputTrianglesNV, None),
273        ("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
274        (
275            "output_triangles_ext",
276            ExecutionMode::OutputTrianglesEXT,
277            None,
278        ),
279        (
280            "output_primitives_ext",
281            ExecutionMode::OutputPrimitivesEXT,
282            Value,
283        ),
284        (
285            "pixel_interlock_ordered_ext",
286            PixelInterlockOrderedEXT,
287            None,
288        ),
289        (
290            "pixel_interlock_unordered_ext",
291            PixelInterlockUnorderedEXT,
292            None,
293        ),
294        (
295            "sample_interlock_ordered_ext",
296            SampleInterlockOrderedEXT,
297            None,
298        ),
299        (
300            "sample_interlock_unordered_ext",
301            SampleInterlockUnorderedEXT,
302            None,
303        ),
304        (
305            "shading_rate_interlock_ordered_ext",
306            ShadingRateInterlockOrderedEXT,
307            None,
308        ),
309        (
310            "shading_rate_interlock_unordered_ext",
311            ShadingRateInterlockUnorderedEXT,
312            None,
313        ),
314        // Reserved
315        /*("max_workgroup_size_intel_x", MaxWorkgroupSizeINTEL, X),
316        ("max_workgroup_size_intel_y", MaxWorkgroupSizeINTEL, Y),
317        ("max_workgroup_size_intel_z", MaxWorkgroupSizeINTEL, Z),
318        ("max_work_dim_intel", MaxWorkDimINTEL, Value),
319        ("no_global_offset_intel", NoGlobalOffsetINTEL, None),
320        ("num_simd_workitems_intel", NumSIMDWorkitemsINTEL, Value),*/
321    ]
322};
323
324impl Symbols {
325    fn new() -> Self {
326        let builtins = BUILTINS
327            .iter()
328            .map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
329        let storage_classes = STORAGE_CLASSES
330            .iter()
331            .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
332        let execution_models = EXECUTION_MODELS
333            .iter()
334            .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into())));
335        let custom_attributes = [
336            (
337                "sampler",
338                SpirvAttribute::IntrinsicType(IntrinsicType::Sampler),
339            ),
340            (
341                "generic_image_type",
342                SpirvAttribute::IntrinsicType(IntrinsicType::GenericImageType),
343            ),
344            (
345                "acceleration_structure",
346                SpirvAttribute::IntrinsicType(IntrinsicType::AccelerationStructureKhr),
347            ),
348            (
349                "ray_query",
350                SpirvAttribute::IntrinsicType(IntrinsicType::RayQueryKhr),
351            ),
352            ("block", SpirvAttribute::Block),
353            ("flat", SpirvAttribute::Flat),
354            ("invariant", SpirvAttribute::Invariant),
355            ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
356            (
357                "sampled_image",
358                SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
359            ),
360            (
361                "runtime_array",
362                SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
363            ),
364            (
365                "typed_buffer",
366                SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
367            ),
368            (
369                "matrix",
370                SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
371            ),
372            ("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
373            (
374                "buffer_store_intrinsic",
375                SpirvAttribute::BufferStoreIntrinsic,
376            ),
377        ]
378        .iter()
379        .cloned();
380        let attributes_iter = builtins
381            .chain(storage_classes)
382            .chain(execution_models)
383            .chain(custom_attributes)
384            .map(|(a, b)| (Symbol::intern(a), b));
385        let mut attributes = FxHashMap::default();
386        for (a, b) in attributes_iter {
387            let old = attributes.insert(a, b);
388            // `.collect()` into a FxHashMap does not error on duplicates, so manually write out the
389            // loop here to error on duplicates.
390            assert!(old.is_none());
391        }
392        let mut execution_modes = FxHashMap::default();
393        for &(key, mode, dim) in EXECUTION_MODES {
394            let old = execution_modes.insert(Symbol::intern(key), (mode, dim));
395            assert!(old.is_none());
396        }
397
398        let mut libm_intrinsics = FxHashMap::default();
399        for &(a, b) in libm_intrinsics::TABLE {
400            let old = libm_intrinsics.insert(Symbol::intern(a), b);
401            assert!(old.is_none());
402        }
403        Self {
404            discriminant: Symbol::intern("discriminant"),
405            rust_gpu: Symbol::intern("rust_gpu"),
406            spirv: Symbol::intern("spirv"),
407            libm: Symbol::intern("libm"),
408            entry_point_name: Symbol::intern("entry_point_name"),
409            spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
410
411            descriptor_set: Symbol::intern("descriptor_set"),
412            binding: Symbol::intern("binding"),
413            input_attachment_index: Symbol::intern("input_attachment_index"),
414
415            spec_constant: Symbol::intern("spec_constant"),
416            id: Symbol::intern("id"),
417            default: Symbol::intern("default"),
418
419            attributes,
420            execution_modes,
421            libm_intrinsics,
422        }
423    }
424
425    /// Obtain an `Rc` handle to the current thread's `Symbols` instance, which
426    /// will be shared between all `Symbols::get()` calls on the same thread.
427    ///
428    /// While this is relatively cheap, prefer caching it in e.g. `CodegenCx`,
429    /// rather than calling `get()` every time a field of `Symbols` is needed.
430    pub fn get() -> Rc<Self> {
431        thread_local!(static SYMBOLS: Rc<Symbols> = Rc::new(Symbols::new()));
432        SYMBOLS.with(Rc::clone)
433    }
434}
435
436// FIXME(eddyb) find something nicer for the error type.
437type ParseAttrError = (Span, String);
438
439// FIXME(eddyb) maybe move this to `attr`?
440pub(crate) fn parse_attrs_for_checking<'a>(
441    sym: &'a Symbols,
442    attrs: &'a [Attribute],
443) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
444    attrs.iter().flat_map(move |attr| {
445        let (whole_attr_error, args) = match attr {
446            Attribute::Unparsed(item) => {
447                // #[...]
448                let s = &item.path.segments;
449                if s.len() > 1 && s[0].name == sym.rust_gpu {
450                    // #[rust_gpu ...]
451                    if s.len() != 2 || s[1].name != sym.spirv {
452                        // #[rust_gpu::...] but not #[rust_gpu::spirv]
453                        (
454                            Some(Err((
455                                attr.span(),
456                                "unknown `rust_gpu` attribute, expected `rust_gpu::spirv`"
457                                    .to_string(),
458                            ))),
459                            Default::default(),
460                        )
461                    } else if let Some(args) = attr.meta_item_list() {
462                        // #[rust_gpu::spirv(...)]
463                        (None, args)
464                    } else {
465                        // #[rust_gpu::spirv]
466                        (
467                            Some(Err((
468                                attr.span(),
469                                "#[rust_gpu::spirv(..)] attribute must have at least one argument"
470                                    .to_string(),
471                            ))),
472                            Default::default(),
473                        )
474                    }
475                } else {
476                    // #[...] but not #[rust_gpu ...]
477                    (None, Default::default())
478                }
479            }
480            Attribute::Parsed(_) => (None, Default::default()),
481        };
482
483        whole_attr_error
484            .into_iter()
485            .chain(args.into_iter().map(move |ref arg| {
486                let span = arg.span();
487                let parsed_attr = if arg.has_name(sym.descriptor_set) {
488                    SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
489                } else if arg.has_name(sym.binding) {
490                    SpirvAttribute::Binding(parse_attr_int_value(arg)?)
491                } else if arg.has_name(sym.input_attachment_index) {
492                    SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
493                } else if arg.has_name(sym.spec_constant) {
494                    SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
495                } else {
496                    let name = match arg.ident() {
497                        Some(i) => i,
498                        None => {
499                            return Err((
500                                span,
501                                "#[spirv(..)] attribute argument must be single identifier"
502                                    .to_string(),
503                            ));
504                        }
505                    };
506                    sym.attributes.get(&name.name).map_or_else(
507                        || Err((name.span, "unknown argument to spirv attribute".to_string())),
508                        |a| {
509                            Ok(match a {
510                                SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
511                                    parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
512                                ),
513                                _ => a.clone(),
514                            })
515                        },
516                    )?
517                };
518                Ok((span, parsed_attr))
519            }))
520    })
521}
522
523fn parse_spec_constant_attr(
524    sym: &Symbols,
525    arg: &MetaItemInner,
526) -> Result<SpecConstant, ParseAttrError> {
527    let mut id = None;
528    let mut default = None;
529
530    if let Some(attrs) = arg.meta_item_list() {
531        for attr in attrs {
532            if attr.has_name(sym.id) {
533                if id.is_none() {
534                    id = Some(parse_attr_int_value(attr)?);
535                } else {
536                    return Err((attr.span(), "`id` may only be specified once".into()));
537                }
538            } else if attr.has_name(sym.default) {
539                if default.is_none() {
540                    default = Some(parse_attr_int_value(attr)?);
541                } else {
542                    return Err((attr.span(), "`default` may only be specified once".into()));
543                }
544            } else {
545                return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
546            }
547        }
548    }
549    Ok(SpecConstant {
550        id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
551        default,
552    })
553}
554
555fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
556    let arg = match arg.meta_item() {
557        Some(arg) => arg,
558        None => return Err((arg.span(), "attribute must have value".to_string())),
559    };
560    match arg.name_value_literal() {
561        Some(&MetaItemLit {
562            kind: LitKind::Int(x, LitIntType::Unsuffixed),
563            ..
564        }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
565        _ => Err((arg.span, "attribute value must be integer".to_string())),
566    }
567}
568
569fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
570    let arg = match arg.meta_item() {
571        Some(arg) => arg,
572        None => return Err((arg.span(), "attribute must have value".to_string())),
573    };
574    match arg.meta_item_list() {
575        Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
576            let mut local_size = [1; 3];
577            for (idx, lit) in tuple.iter().enumerate() {
578                match lit {
579                    MetaItemInner::Lit(MetaItemLit {
580                        kind: LitKind::Int(x, LitIntType::Unsuffixed),
581                        ..
582                    }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
583                    _ => return Err((lit.span(), "must be a u32 literal".to_string())),
584                }
585            }
586            Ok(local_size)
587        }
588        Some([]) => Err((
589            arg.span,
590            "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
591        )),
592        Some(tuple) if tuple.len() > 3 => Err((
593            arg.span,
594            "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
595        )),
596        _ => Err((
597            arg.span,
598            "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
599        )),
600    }
601}
602
603// for a given entry, gather up the additional attributes
604// in this case ExecutionMode's, some have extra arguments
605// others are specified with x, y, or z components
606// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))]
607fn parse_entry_attrs(
608    sym: &Symbols,
609    arg: &MetaItemInner,
610    name: &Ident,
611    execution_model: ExecutionModel,
612) -> Result<Entry, ParseAttrError> {
613    use ExecutionMode::*;
614    use ExecutionModel::*;
615    let mut entry = Entry::from(execution_model);
616    let mut origin_mode: Option<ExecutionMode> = None;
617    let mut local_size: Option<[u32; 3]> = None;
618    let mut local_size_hint: Option<[u32; 3]> = None;
619    // Reserved
620    //let mut max_workgroup_size_intel: Option<[u32; 3]> = None;
621    if let Some(attrs) = arg.meta_item_list() {
622        for attr in attrs {
623            if let Some(attr_name) = attr.ident() {
624                if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
625                {
626                    use ExecutionModeExtraDim::*;
627                    let val = match extra_dim {
628                        None | Tuple => Option::None,
629                        _ => Some(parse_attr_int_value(attr)?),
630                    };
631                    match execution_mode {
632                        OriginUpperLeft | OriginLowerLeft => {
633                            origin_mode.replace(*execution_mode);
634                        }
635                        LocalSize => {
636                            if local_size.is_none() {
637                                local_size.replace(parse_local_size_attr(attr)?);
638                            } else {
639                                return Err((
640                                    attr_name.span,
641                                    String::from(
642                                        "`#[spirv(compute(threads))]` may only be specified once",
643                                    ),
644                                ));
645                            }
646                        }
647                        LocalSizeHint => {
648                            let val = val.unwrap();
649                            if local_size_hint.is_none() {
650                                local_size_hint.replace([1, 1, 1]);
651                            }
652                            let local_size_hint = local_size_hint.as_mut().unwrap();
653                            match extra_dim {
654                                X => {
655                                    local_size_hint[0] = val;
656                                }
657                                Y => {
658                                    local_size_hint[1] = val;
659                                }
660                                Z => {
661                                    local_size_hint[2] = val;
662                                }
663                                _ => unreachable!(),
664                            }
665                        }
666                        // Reserved
667                        /*MaxWorkgroupSizeINTEL => {
668                            let val = val.unwrap();
669                            if max_workgroup_size_intel.is_none() {
670                                max_workgroup_size_intel.replace([1, 1, 1]);
671                            }
672                            let max_workgroup_size_intel = max_workgroup_size_intel.as_mut()
673                                .unwrap();
674                            match extra_dim {
675                                X => {
676                                    max_workgroup_size_intel[0] = val;
677                                },
678                                Y => {
679                                    max_workgroup_size_intel[1] = val;
680                                },
681                                Z => {
682                                    max_workgroup_size_intel[2] = val;
683                                },
684                                _ => unreachable!(),
685                            }
686                        },*/
687                        _ => {
688                            if let Some(val) = val {
689                                entry
690                                    .execution_modes
691                                    .push((*execution_mode, ExecutionModeExtra::new([val])));
692                            } else {
693                                entry
694                                    .execution_modes
695                                    .push((*execution_mode, ExecutionModeExtra::new([])));
696                            }
697                        }
698                    }
699                } else if attr_name.name == sym.entry_point_name {
700                    match attr.value_str() {
701                        Some(sym) => {
702                            entry.name = Some(sym);
703                        }
704                        None => {
705                            return Err((
706                                attr_name.span,
707                                format!(
708                                    "#[spirv({name}(..))] unknown attribute argument {attr_name}"
709                                ),
710                            ));
711                        }
712                    }
713                } else {
714                    return Err((
715                        attr_name.span,
716                        format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
717                    ));
718                }
719            } else {
720                return Err((
721                    arg.span(),
722                    format!("#[spirv({name}(..))] attribute argument must be single identifier"),
723                ));
724            }
725        }
726    }
727    match entry.execution_model {
728        Fragment => {
729            let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
730            entry
731                .execution_modes
732                .push((origin_mode, ExecutionModeExtra::new([])));
733        }
734        GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
735            if let Some(local_size) = local_size {
736                entry
737                    .execution_modes
738                    .push((LocalSize, ExecutionModeExtra::new(local_size)));
739            } else {
740                return Err((
741                    arg.span(),
742                    String::from(
743                        "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
744                    ),
745                ));
746            }
747        }
748        //TODO: Cover more defaults
749        _ => {}
750    }
751    Ok(entry)
752}