rustc_codegen_spirv/
symbols.rs

1use crate::attr::{IntrinsicType, SpirvAttribute};
2use crate::builder::libm_intrinsics;
3use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
4use rustc_data_structures::fx::FxHashMap;
5use rustc_span::symbol::Symbol;
6use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
7use std::rc::Rc;
8
9/// Various places in the codebase (mostly attribute parsing) need to compare rustc Symbols to particular keywords.
10/// Symbols are interned, as in, they don't actually store the string itself inside them, but rather an index into a
11/// global table of strings. Then, whenever a new Symbol is created, the global table is checked to see if the string
12/// already exists, deduplicating it if so. This makes things like comparison and cloning really cheap. So, this struct
13/// is to allocate all our keywords up front and intern them all, so we can do comparisons really easily and fast.
14pub struct Symbols {
15    pub discriminant: Symbol,
16    pub rust_gpu: Symbol,
17    pub spirv_attr_with_version: Symbol,
18    pub vector: Symbol,
19    pub v1: Symbol,
20    pub libm: Symbol,
21    pub entry_point_name: Symbol,
22    pub spv_khr_vulkan_memory_model: Symbol,
23
24    pub descriptor_set: Symbol,
25    pub binding: Symbol,
26    pub input_attachment_index: Symbol,
27
28    pub spec_constant: Symbol,
29    pub id: Symbol,
30    pub default: Symbol,
31
32    pub attributes: FxHashMap<Symbol, SpirvAttribute>,
33    pub execution_modes: FxHashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
34    pub libm_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
35}
36
37const BUILTINS: &[(&str, BuiltIn)] = {
38    use BuiltIn::*;
39    &[
40        ("position", Position),
41        ("point_size", PointSize),
42        ("clip_distance", ClipDistance),
43        ("cull_distance", CullDistance),
44        ("vertex_id", VertexId),
45        ("instance_id", InstanceId),
46        ("primitive_id", PrimitiveId),
47        ("invocation_id", InvocationId),
48        ("layer", Layer),
49        ("viewport_index", ViewportIndex),
50        ("tess_level_outer", TessLevelOuter),
51        ("tess_level_inner", TessLevelInner),
52        ("tess_coord", TessCoord),
53        ("patch_vertices", PatchVertices),
54        ("frag_coord", FragCoord),
55        ("point_coord", PointCoord),
56        ("front_facing", FrontFacing),
57        ("sample_id", SampleId),
58        ("sample_position", SamplePosition),
59        ("sample_mask", SampleMask),
60        ("frag_depth", FragDepth),
61        ("helper_invocation", HelperInvocation),
62        ("num_workgroups", NumWorkgroups),
63        // ("workgroup_size", WorkgroupSize), -- constant
64        ("workgroup_id", WorkgroupId),
65        ("local_invocation_id", LocalInvocationId),
66        ("global_invocation_id", GlobalInvocationId),
67        ("local_invocation_index", LocalInvocationIndex),
68        // ("work_dim", WorkDim), -- Kernel-only
69        // ("global_size", GlobalSize), -- Kernel-only
70        // ("enqueued_workgroup_size", EnqueuedWorkgroupSize), -- Kernel-only
71        // ("global_offset", GlobalOffset), -- Kernel-only
72        // ("global_linear_id", GlobalLinearId), -- Kernel-only
73        ("subgroup_size", SubgroupSize),
74        // ("subgroup_max_size", SubgroupMaxSize), -- Kernel-only
75        ("num_subgroups", NumSubgroups),
76        // ("num_enqueued_subgroups", NumEnqueuedSubgroups), -- Kernel-only
77        ("subgroup_id", SubgroupId),
78        ("subgroup_local_invocation_id", SubgroupLocalInvocationId),
79        ("vertex_index", VertexIndex),
80        ("instance_index", InstanceIndex),
81        ("subgroup_eq_mask", SubgroupEqMask),
82        ("subgroup_ge_mask", SubgroupGeMask),
83        ("subgroup_gt_mask", SubgroupGtMask),
84        ("subgroup_le_mask", SubgroupLeMask),
85        ("subgroup_lt_mask", SubgroupLtMask),
86        ("base_vertex", BaseVertex),
87        ("base_instance", BaseInstance),
88        ("draw_index", DrawIndex),
89        ("device_index", DeviceIndex),
90        ("view_index", ViewIndex),
91        ("bary_coord_no_persp_amd", BaryCoordNoPerspAMD),
92        (
93            "bary_coord_no_persp_centroid_amd",
94            BaryCoordNoPerspCentroidAMD,
95        ),
96        ("bary_coord_no_persp_sample_amd", BaryCoordNoPerspSampleAMD),
97        ("bary_coord_smooth_amd", BaryCoordSmoothAMD),
98        ("bary_coord_smooth_centroid_amd", BaryCoordSmoothCentroidAMD),
99        ("bary_coord_smooth_sample_amd", BaryCoordSmoothSampleAMD),
100        ("bary_coord_pull_model_amd", BaryCoordPullModelAMD),
101        ("frag_stencil_ref_ext", FragStencilRefEXT),
102        ("viewport_mask_nv", ViewportMaskNV),
103        ("secondary_position_nv", SecondaryPositionNV),
104        ("secondary_viewport_mask_nv", SecondaryViewportMaskNV),
105        ("position_per_view_nv", PositionPerViewNV),
106        ("viewport_mask_per_view_nv", ViewportMaskPerViewNV),
107        ("fully_covered_ext", FullyCoveredEXT),
108        ("task_count_nv", TaskCountNV),
109        ("primitive_count_nv", PrimitiveCountNV),
110        ("primitive_indices_nv", PrimitiveIndicesNV),
111        ("clip_distance_per_view_nv", ClipDistancePerViewNV),
112        ("cull_distance_per_view_nv", CullDistancePerViewNV),
113        ("layer_per_view_nv", LayerPerViewNV),
114        ("mesh_view_count_nv", MeshViewCountNV),
115        ("mesh_view_indices_nv", MeshViewIndicesNV),
116        ("bary_coord_nv", BuiltIn::BaryCoordNV),
117        ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
118        ("bary_coord", BaryCoordKHR),
119        ("bary_coord_no_persp", BaryCoordNoPerspKHR),
120        ("primitive_point_indices_ext", PrimitivePointIndicesEXT),
121        ("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
122        (
123            "primitive_triangle_indices_ext",
124            PrimitiveTriangleIndicesEXT,
125        ),
126        ("cull_primitive_ext", CullPrimitiveEXT),
127        ("frag_size_ext", FragSizeEXT),
128        ("frag_invocation_count_ext", FragInvocationCountEXT),
129        ("launch_id", BuiltIn::LaunchIdKHR),
130        ("launch_size", BuiltIn::LaunchSizeKHR),
131        ("instance_custom_index", BuiltIn::InstanceCustomIndexKHR),
132        ("ray_geometry_index", BuiltIn::RayGeometryIndexKHR),
133        ("world_ray_origin", BuiltIn::WorldRayOriginKHR),
134        ("world_ray_direction", BuiltIn::WorldRayDirectionKHR),
135        ("object_ray_origin", BuiltIn::ObjectRayOriginKHR),
136        ("object_ray_direction", BuiltIn::ObjectRayDirectionKHR),
137        ("ray_tmin", BuiltIn::RayTminKHR),
138        ("ray_tmax", BuiltIn::RayTmaxKHR),
139        ("object_to_world", BuiltIn::ObjectToWorldKHR),
140        ("world_to_object", BuiltIn::WorldToObjectKHR),
141        (
142            "hit_triangle_vertex_positions",
143            BuiltIn::HitTriangleVertexPositionsKHR,
144        ),
145        ("hit_kind", BuiltIn::HitKindKHR),
146        ("incoming_ray_flags", BuiltIn::IncomingRayFlagsKHR),
147        ("warps_per_sm_nv", WarpsPerSMNV),
148        ("sm_count_nv", SMCountNV),
149        ("warp_id_nv", WarpIDNV),
150        ("SMIDNV", SMIDNV),
151    ]
152};
153
154const STORAGE_CLASSES: &[(&str, StorageClass)] = {
155    use StorageClass::*;
156    &[
157        ("uniform_constant", UniformConstant),
158        ("input", Input),
159        ("uniform", Uniform),
160        ("output", Output),
161        ("workgroup", Workgroup),
162        ("cross_workgroup", CrossWorkgroup),
163        ("private", Private),
164        ("function", Function),
165        ("generic", Generic),
166        ("push_constant", PushConstant),
167        ("atomic_counter", AtomicCounter),
168        ("image", Image),
169        ("storage_buffer", StorageBuffer),
170        ("callable_data", StorageClass::CallableDataKHR),
171        (
172            "incoming_callable_data",
173            StorageClass::IncomingCallableDataKHR,
174        ),
175        ("ray_payload", StorageClass::RayPayloadKHR),
176        ("hit_attribute", StorageClass::HitAttributeKHR),
177        ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
178        ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
179        ("physical_storage_buffer", PhysicalStorageBuffer),
180        ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
181    ]
182};
183
184const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
185    use ExecutionModel::*;
186    &[
187        ("vertex", Vertex),
188        ("tessellation_control", TessellationControl),
189        ("tessellation_evaluation", TessellationEvaluation),
190        ("geometry", Geometry),
191        ("fragment", Fragment),
192        ("compute", GLCompute),
193        ("task_nv", TaskNV),
194        ("mesh_nv", MeshNV),
195        ("task_ext", TaskEXT),
196        ("mesh_ext", MeshEXT),
197        ("ray_generation", ExecutionModel::RayGenerationKHR),
198        ("intersection", ExecutionModel::IntersectionKHR),
199        ("any_hit", ExecutionModel::AnyHitKHR),
200        ("closest_hit", ExecutionModel::ClosestHitKHR),
201        ("miss", ExecutionModel::MissKHR),
202        ("callable", ExecutionModel::CallableKHR),
203    ]
204};
205
/// Records, for each entry of the execution-mode table, what extra operand the attribute carries.
///
/// NOTE(review): the per-variant descriptions below are inferred from how the execution-mode
/// table uses each variant (e.g. `local_size_hint_x` -> `X`, `threads` -> `Tuple`). Confirm them
/// against the attribute parser.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ExecutionModeExtraDim {
    /// No extra operand (e.g. `origin_upper_left`).
    None,
    /// A single value (e.g. `invocations`, `output_vertices`).
    Value,
    /// The X component of a three-component operand (e.g. `local_size_hint_x`).
    X,
    /// The Y component of a three-component operand (e.g. `local_size_hint_y`).
    Y,
    /// The Z component of a three-component operand (e.g. `local_size_hint_z`).
    Z,
    /// All components at once, as a tuple (used by `threads` for `LocalSize`).
    Tuple,
}
215
216const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
217    use ExecutionMode::*;
218    use ExecutionModeExtraDim::*;
219    &[
220        ("invocations", Invocations, Value),
221        ("spacing_equal", SpacingEqual, None),
222        ("spacing_fraction_even", SpacingFractionalEven, None),
223        ("spacing_fraction_odd", SpacingFractionalOdd, None),
224        ("vertex_order_cw", VertexOrderCw, None),
225        ("vertex_order_ccw", VertexOrderCcw, None),
226        ("pixel_center_integer", PixelCenterInteger, None),
227        ("origin_upper_left", OriginUpperLeft, None),
228        ("origin_lower_left", OriginLowerLeft, None),
229        ("early_fragment_tests", EarlyFragmentTests, None),
230        ("point_mode", PointMode, None),
231        ("xfb", Xfb, None),
232        ("depth_replacing", DepthReplacing, None),
233        ("depth_greater", DepthGreater, None),
234        ("depth_less", DepthLess, None),
235        ("depth_unchanged", DepthUnchanged, None),
236        ("threads", LocalSize, Tuple),
237        ("local_size_hint_x", LocalSizeHint, X),
238        ("local_size_hint_y", LocalSizeHint, Y),
239        ("local_size_hint_z", LocalSizeHint, Z),
240        ("input_points", InputPoints, None),
241        ("input_lines", InputLines, None),
242        ("input_lines_adjacency", InputLinesAdjacency, None),
243        ("triangles", Triangles, None),
244        ("input_triangles_adjacency", InputTrianglesAdjacency, None),
245        ("quads", Quads, None),
246        ("isolines", Isolines, None),
247        ("output_vertices", OutputVertices, Value),
248        ("output_points", OutputPoints, None),
249        ("output_line_strip", OutputLineStrip, None),
250        ("output_triangle_strip", OutputTriangleStrip, None),
251        ("vec_type_hint", VecTypeHint, Value),
252        ("contraction_off", ContractionOff, None),
253        ("initializer", Initializer, None),
254        ("finalizer", Finalizer, None),
255        ("subgroup_size", SubgroupSize, Value),
256        ("subgroups_per_workgroup", SubgroupsPerWorkgroup, Value),
257        ("subgroups_per_workgroup_id", SubgroupsPerWorkgroupId, Value),
258        ("local_size_id_x", LocalSizeId, X),
259        ("local_size_id_y", LocalSizeId, Y),
260        ("local_size_id_z", LocalSizeId, Z),
261        ("local_size_hint_id", LocalSizeHintId, Value),
262        ("post_depth_coverage", PostDepthCoverage, None),
263        ("denorm_preserve", DenormPreserve, None),
264        ("denorm_flush_to_zero", DenormFlushToZero, Value),
265        (
266            "signed_zero_inf_nan_preserve",
267            SignedZeroInfNanPreserve,
268            Value,
269        ),
270        ("rounding_mode_rte", RoundingModeRTE, Value),
271        ("rounding_mode_rtz", RoundingModeRTZ, Value),
272        ("stencil_ref_replacing_ext", StencilRefReplacingEXT, None),
273        ("output_lines_nv", OutputLinesNV, None),
274        ("output_primitives_nv", OutputPrimitivesNV, Value),
275        ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
276        ("output_triangles_nv", OutputTrianglesNV, None),
277        ("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
278        (
279            "output_triangles_ext",
280            ExecutionMode::OutputTrianglesEXT,
281            None,
282        ),
283        (
284            "output_primitives_ext",
285            ExecutionMode::OutputPrimitivesEXT,
286            Value,
287        ),
288        (
289            "pixel_interlock_ordered_ext",
290            PixelInterlockOrderedEXT,
291            None,
292        ),
293        (
294            "pixel_interlock_unordered_ext",
295            PixelInterlockUnorderedEXT,
296            None,
297        ),
298        (
299            "sample_interlock_ordered_ext",
300            SampleInterlockOrderedEXT,
301            None,
302        ),
303        (
304            "sample_interlock_unordered_ext",
305            SampleInterlockUnorderedEXT,
306            None,
307        ),
308        (
309            "shading_rate_interlock_ordered_ext",
310            ShadingRateInterlockOrderedEXT,
311            None,
312        ),
313        (
314            "shading_rate_interlock_unordered_ext",
315            ShadingRateInterlockUnorderedEXT,
316            None,
317        ),
318        // Reserved
319        /*("max_workgroup_size_intel_x", MaxWorkgroupSizeINTEL, X),
320        ("max_workgroup_size_intel_y", MaxWorkgroupSizeINTEL, Y),
321        ("max_workgroup_size_intel_z", MaxWorkgroupSizeINTEL, Z),
322        ("max_work_dim_intel", MaxWorkDimINTEL, Value),
323        ("no_global_offset_intel", NoGlobalOffsetINTEL, None),
324        ("num_simd_workitems_intel", NumSIMDWorkitemsINTEL, Value),*/
325    ]
326};
327
328impl Symbols {
329    fn new() -> Self {
330        let builtins = BUILTINS
331            .iter()
332            .map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
333        let storage_classes = STORAGE_CLASSES
334            .iter()
335            .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
336        let execution_models = EXECUTION_MODELS
337            .iter()
338            .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into())));
339        let custom_attributes = [
340            (
341                "sampler",
342                SpirvAttribute::IntrinsicType(IntrinsicType::Sampler),
343            ),
344            (
345                "generic_image_type",
346                SpirvAttribute::IntrinsicType(IntrinsicType::GenericImageType),
347            ),
348            (
349                "acceleration_structure",
350                SpirvAttribute::IntrinsicType(IntrinsicType::AccelerationStructureKhr),
351            ),
352            (
353                "ray_query",
354                SpirvAttribute::IntrinsicType(IntrinsicType::RayQueryKhr),
355            ),
356            ("block", SpirvAttribute::Block),
357            ("flat", SpirvAttribute::Flat),
358            ("invariant", SpirvAttribute::Invariant),
359            ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
360            (
361                "sampled_image",
362                SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
363            ),
364            (
365                "runtime_array",
366                SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
367            ),
368            (
369                "typed_buffer",
370                SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
371            ),
372            (
373                "matrix",
374                SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
375            ),
376            (
377                "vector",
378                SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
379            ),
380            ("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
381            (
382                "buffer_store_intrinsic",
383                SpirvAttribute::BufferStoreIntrinsic,
384            ),
385        ]
386        .iter()
387        .cloned();
388        let attributes_iter = builtins
389            .chain(storage_classes)
390            .chain(execution_models)
391            .chain(custom_attributes)
392            .map(|(a, b)| (Symbol::intern(a), b));
393        let mut attributes = FxHashMap::default();
394        for (a, b) in attributes_iter {
395            let old = attributes.insert(a, b);
396            // `.collect()` into a FxHashMap does not error on duplicates, so manually write out the
397            // loop here to error on duplicates.
398            assert!(old.is_none());
399        }
400        let mut execution_modes = FxHashMap::default();
401        for &(key, mode, dim) in EXECUTION_MODES {
402            let old = execution_modes.insert(Symbol::intern(key), (mode, dim));
403            assert!(old.is_none());
404        }
405
406        let mut libm_intrinsics = FxHashMap::default();
407        for &(a, b) in libm_intrinsics::TABLE {
408            let old = libm_intrinsics.insert(Symbol::intern(a), b);
409            assert!(old.is_none());
410        }
411        Self {
412            discriminant: Symbol::intern("discriminant"),
413            rust_gpu: Symbol::intern("rust_gpu"),
414            spirv_attr_with_version: Symbol::intern(&spirv_attr_with_version()),
415            vector: Symbol::intern("vector"),
416            v1: Symbol::intern("v1"),
417            libm: Symbol::intern("libm"),
418            entry_point_name: Symbol::intern("entry_point_name"),
419            spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
420
421            descriptor_set: Symbol::intern("descriptor_set"),
422            binding: Symbol::intern("binding"),
423            input_attachment_index: Symbol::intern("input_attachment_index"),
424
425            spec_constant: Symbol::intern("spec_constant"),
426            id: Symbol::intern("id"),
427            default: Symbol::intern("default"),
428
429            attributes,
430            execution_modes,
431            libm_intrinsics,
432        }
433    }
434
435    /// Obtain an `Rc` handle to the current thread's `Symbols` instance, which
436    /// will be shared between all `Symbols::get()` calls on the same thread.
437    ///
438    /// While this is relatively cheap, prefer caching it in e.g. `CodegenCx`,
439    /// rather than calling `get()` every time a field of `Symbols` is needed.
440    pub fn get() -> Rc<Self> {
441        thread_local!(static SYMBOLS: Rc<Symbols> = Rc::new(Symbols::new()));
442        SYMBOLS.with(Rc::clone)
443    }
444}