rustc_codegen_spirv/
symbols.rs

1use crate::attr::{IntrinsicType, SpirvAttribute};
2use crate::builder::libm_intrinsics;
3use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
4use rustc_data_structures::fx::FxHashMap;
5use rustc_span::symbol::Symbol;
6use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
7use std::rc::Rc;
8
9/// Various places in the codebase (mostly attribute parsing) need to compare rustc Symbols to particular keywords.
10/// Symbols are interned, as in, they don't actually store the string itself inside them, but rather an index into a
11/// global table of strings. Then, whenever a new Symbol is created, the global table is checked to see if the string
12/// already exists, deduplicating it if so. This makes things like comparison and cloning really cheap. So, this struct
13/// is to allocate all our keywords up front and intern them all, so we can do comparisons really easily and fast.
14pub struct Symbols {
15    pub discriminant: Symbol,
16    pub rust_gpu: Symbol,
17    pub spirv_attr_with_version: Symbol,
18    pub vector: Symbol,
19    pub v1: Symbol,
20    pub libm: Symbol,
21    pub num_traits: Symbol,
22    pub entry_point_name: Symbol,
23    pub spv_khr_vulkan_memory_model: Symbol,
24
25    pub descriptor_set: Symbol,
26    pub binding: Symbol,
27    pub location: Symbol,
28    pub input_attachment_index: Symbol,
29
30    pub spec_constant: Symbol,
31    pub id: Symbol,
32    pub default: Symbol,
33
34    pub attributes: FxHashMap<Symbol, SpirvAttribute>,
35    pub execution_modes: FxHashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
36    pub libm_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
37    pub num_traits_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
38}
39
40const BUILTINS: &[(&str, BuiltIn)] = {
41    use BuiltIn::*;
42    &[
43        ("position", Position),
44        ("point_size", PointSize),
45        ("clip_distance", ClipDistance),
46        ("cull_distance", CullDistance),
47        ("vertex_id", VertexId),
48        ("instance_id", InstanceId),
49        ("primitive_id", PrimitiveId),
50        ("invocation_id", InvocationId),
51        ("layer", Layer),
52        ("viewport_index", ViewportIndex),
53        ("tess_level_outer", TessLevelOuter),
54        ("tess_level_inner", TessLevelInner),
55        ("tess_coord", TessCoord),
56        ("patch_vertices", PatchVertices),
57        ("frag_coord", FragCoord),
58        ("point_coord", PointCoord),
59        ("front_facing", FrontFacing),
60        ("sample_id", SampleId),
61        ("sample_position", SamplePosition),
62        ("sample_mask", SampleMask),
63        ("frag_depth", FragDepth),
64        ("helper_invocation", HelperInvocation),
65        ("num_workgroups", NumWorkgroups),
66        // ("workgroup_size", WorkgroupSize), -- constant
67        ("workgroup_id", WorkgroupId),
68        ("local_invocation_id", LocalInvocationId),
69        ("global_invocation_id", GlobalInvocationId),
70        ("local_invocation_index", LocalInvocationIndex),
71        // ("work_dim", WorkDim), -- Kernel-only
72        // ("global_size", GlobalSize), -- Kernel-only
73        // ("enqueued_workgroup_size", EnqueuedWorkgroupSize), -- Kernel-only
74        // ("global_offset", GlobalOffset), -- Kernel-only
75        // ("global_linear_id", GlobalLinearId), -- Kernel-only
76        ("subgroup_size", SubgroupSize),
77        // ("subgroup_max_size", SubgroupMaxSize), -- Kernel-only
78        ("num_subgroups", NumSubgroups),
79        // ("num_enqueued_subgroups", NumEnqueuedSubgroups), -- Kernel-only
80        ("subgroup_id", SubgroupId),
81        ("subgroup_local_invocation_id", SubgroupLocalInvocationId),
82        ("vertex_index", VertexIndex),
83        ("instance_index", InstanceIndex),
84        ("subgroup_eq_mask", SubgroupEqMask),
85        ("subgroup_ge_mask", SubgroupGeMask),
86        ("subgroup_gt_mask", SubgroupGtMask),
87        ("subgroup_le_mask", SubgroupLeMask),
88        ("subgroup_lt_mask", SubgroupLtMask),
89        ("base_vertex", BaseVertex),
90        ("base_instance", BaseInstance),
91        ("draw_index", DrawIndex),
92        ("device_index", DeviceIndex),
93        ("view_index", ViewIndex),
94        ("bary_coord_no_persp_amd", BaryCoordNoPerspAMD),
95        (
96            "bary_coord_no_persp_centroid_amd",
97            BaryCoordNoPerspCentroidAMD,
98        ),
99        ("bary_coord_no_persp_sample_amd", BaryCoordNoPerspSampleAMD),
100        ("bary_coord_smooth_amd", BaryCoordSmoothAMD),
101        ("bary_coord_smooth_centroid_amd", BaryCoordSmoothCentroidAMD),
102        ("bary_coord_smooth_sample_amd", BaryCoordSmoothSampleAMD),
103        ("bary_coord_pull_model_amd", BaryCoordPullModelAMD),
104        ("frag_stencil_ref_ext", FragStencilRefEXT),
105        ("viewport_mask_nv", ViewportMaskNV),
106        ("secondary_position_nv", SecondaryPositionNV),
107        ("secondary_viewport_mask_nv", SecondaryViewportMaskNV),
108        ("position_per_view_nv", PositionPerViewNV),
109        ("viewport_mask_per_view_nv", ViewportMaskPerViewNV),
110        ("fully_covered_ext", FullyCoveredEXT),
111        ("task_count_nv", TaskCountNV),
112        ("primitive_count_nv", PrimitiveCountNV),
113        ("primitive_indices_nv", PrimitiveIndicesNV),
114        ("clip_distance_per_view_nv", ClipDistancePerViewNV),
115        ("cull_distance_per_view_nv", CullDistancePerViewNV),
116        ("layer_per_view_nv", LayerPerViewNV),
117        ("mesh_view_count_nv", MeshViewCountNV),
118        ("mesh_view_indices_nv", MeshViewIndicesNV),
119        ("bary_coord_nv", BuiltIn::BaryCoordNV),
120        ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
121        ("bary_coord", BaryCoordKHR),
122        ("bary_coord_no_persp", BaryCoordNoPerspKHR),
123        ("primitive_point_indices_ext", PrimitivePointIndicesEXT),
124        ("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
125        (
126            "primitive_triangle_indices_ext",
127            PrimitiveTriangleIndicesEXT,
128        ),
129        ("cull_primitive_ext", CullPrimitiveEXT),
130        ("frag_size_ext", FragSizeEXT),
131        ("frag_invocation_count_ext", FragInvocationCountEXT),
132        ("launch_id", BuiltIn::LaunchIdKHR),
133        ("launch_size", BuiltIn::LaunchSizeKHR),
134        ("instance_custom_index", BuiltIn::InstanceCustomIndexKHR),
135        ("ray_geometry_index", BuiltIn::RayGeometryIndexKHR),
136        ("world_ray_origin", BuiltIn::WorldRayOriginKHR),
137        ("world_ray_direction", BuiltIn::WorldRayDirectionKHR),
138        ("object_ray_origin", BuiltIn::ObjectRayOriginKHR),
139        ("object_ray_direction", BuiltIn::ObjectRayDirectionKHR),
140        ("ray_tmin", BuiltIn::RayTminKHR),
141        ("ray_tmax", BuiltIn::RayTmaxKHR),
142        ("object_to_world", BuiltIn::ObjectToWorldKHR),
143        ("world_to_object", BuiltIn::WorldToObjectKHR),
144        (
145            "hit_triangle_vertex_positions",
146            BuiltIn::HitTriangleVertexPositionsKHR,
147        ),
148        ("hit_kind", BuiltIn::HitKindKHR),
149        ("incoming_ray_flags", BuiltIn::IncomingRayFlagsKHR),
150        ("warps_per_sm_nv", WarpsPerSMNV),
151        ("sm_count_nv", SMCountNV),
152        ("warp_id_nv", WarpIDNV),
153        ("SMIDNV", SMIDNV),
154    ]
155};
156
157const STORAGE_CLASSES: &[(&str, StorageClass)] = {
158    use StorageClass::*;
159    &[
160        ("uniform_constant", UniformConstant),
161        ("input", Input),
162        ("uniform", Uniform),
163        ("output", Output),
164        ("workgroup", Workgroup),
165        ("cross_workgroup", CrossWorkgroup),
166        ("private", Private),
167        ("function", Function),
168        ("generic", Generic),
169        ("push_constant", PushConstant),
170        ("atomic_counter", AtomicCounter),
171        ("image", Image),
172        ("storage_buffer", StorageBuffer),
173        ("callable_data", StorageClass::CallableDataKHR),
174        (
175            "incoming_callable_data",
176            StorageClass::IncomingCallableDataKHR,
177        ),
178        ("ray_payload", StorageClass::RayPayloadKHR),
179        ("hit_attribute", StorageClass::HitAttributeKHR),
180        ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
181        ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
182        ("physical_storage_buffer", PhysicalStorageBuffer),
183        ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
184    ]
185};
186
187const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
188    use ExecutionModel::*;
189    &[
190        ("vertex", Vertex),
191        ("tessellation_control", TessellationControl),
192        ("tessellation_evaluation", TessellationEvaluation),
193        ("geometry", Geometry),
194        ("fragment", Fragment),
195        ("compute", GLCompute),
196        ("task_nv", TaskNV),
197        ("mesh_nv", MeshNV),
198        ("task_ext", TaskEXT),
199        ("mesh_ext", MeshEXT),
200        ("ray_generation", ExecutionModel::RayGenerationKHR),
201        ("intersection", ExecutionModel::IntersectionKHR),
202        ("any_hit", ExecutionModel::AnyHitKHR),
203        ("closest_hit", ExecutionModel::ClosestHitKHR),
204        ("miss", ExecutionModel::MissKHR),
205        ("callable", ExecutionModel::CallableKHR),
206    ]
207};
208
/// How an execution-mode attribute supplies the operand(s) of its SPIR-V execution mode.
///
/// Every entry of `EXECUTION_MODES` is tagged with one of these variants.
#[derive(Clone, Copy, Debug)]
pub enum ExecutionModeExtraDim {
    /// No operand is written, as for `origin_upper_left` or `early_fragment_tests`.
    None,
    /// A single operand, as for `invocations` or `output_vertices`.
    Value,
    /// The first component of a multi-component mode, given as its own attribute
    /// (e.g. `local_size_hint_x`).
    X,
    /// The second component (e.g. `local_size_hint_y`).
    Y,
    /// The third component (e.g. `local_size_hint_z`).
    Z,
    /// All components in one attribute; used by `threads` (`LocalSize`).
    Tuple,
}
218
219const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
220    use ExecutionMode::*;
221    use ExecutionModeExtraDim::*;
222    &[
223        ("invocations", Invocations, Value),
224        ("spacing_equal", SpacingEqual, None),
225        ("spacing_fraction_even", SpacingFractionalEven, None),
226        ("spacing_fraction_odd", SpacingFractionalOdd, None),
227        ("vertex_order_cw", VertexOrderCw, None),
228        ("vertex_order_ccw", VertexOrderCcw, None),
229        ("pixel_center_integer", PixelCenterInteger, None),
230        ("origin_upper_left", OriginUpperLeft, None),
231        ("origin_lower_left", OriginLowerLeft, None),
232        ("early_fragment_tests", EarlyFragmentTests, None),
233        ("point_mode", PointMode, None),
234        ("xfb", Xfb, None),
235        ("depth_replacing", DepthReplacing, None),
236        ("depth_greater", DepthGreater, None),
237        ("depth_less", DepthLess, None),
238        ("depth_unchanged", DepthUnchanged, None),
239        ("threads", LocalSize, Tuple),
240        ("local_size_hint_x", LocalSizeHint, X),
241        ("local_size_hint_y", LocalSizeHint, Y),
242        ("local_size_hint_z", LocalSizeHint, Z),
243        ("input_points", InputPoints, None),
244        ("input_lines", InputLines, None),
245        ("input_lines_adjacency", InputLinesAdjacency, None),
246        ("triangles", Triangles, None),
247        ("input_triangles_adjacency", InputTrianglesAdjacency, None),
248        ("quads", Quads, None),
249        ("isolines", Isolines, None),
250        ("output_vertices", OutputVertices, Value),
251        ("output_points", OutputPoints, None),
252        ("output_line_strip", OutputLineStrip, None),
253        ("output_triangle_strip", OutputTriangleStrip, None),
254        ("vec_type_hint", VecTypeHint, Value),
255        ("contraction_off", ContractionOff, None),
256        ("initializer", Initializer, None),
257        ("finalizer", Finalizer, None),
258        ("subgroup_size", SubgroupSize, Value),
259        ("subgroups_per_workgroup", SubgroupsPerWorkgroup, Value),
260        ("subgroups_per_workgroup_id", SubgroupsPerWorkgroupId, Value),
261        ("local_size_id_x", LocalSizeId, X),
262        ("local_size_id_y", LocalSizeId, Y),
263        ("local_size_id_z", LocalSizeId, Z),
264        ("local_size_hint_id", LocalSizeHintId, Value),
265        ("post_depth_coverage", PostDepthCoverage, None),
266        ("denorm_preserve", DenormPreserve, None),
267        ("denorm_flush_to_zero", DenormFlushToZero, Value),
268        (
269            "signed_zero_inf_nan_preserve",
270            SignedZeroInfNanPreserve,
271            Value,
272        ),
273        ("rounding_mode_rte", RoundingModeRTE, Value),
274        ("rounding_mode_rtz", RoundingModeRTZ, Value),
275        ("stencil_ref_replacing_ext", StencilRefReplacingEXT, None),
276        ("output_lines_nv", OutputLinesNV, None),
277        ("output_primitives_nv", OutputPrimitivesNV, Value),
278        ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
279        ("output_triangles_nv", OutputTrianglesNV, None),
280        ("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
281        (
282            "output_triangles_ext",
283            ExecutionMode::OutputTrianglesEXT,
284            None,
285        ),
286        (
287            "output_primitives_ext",
288            ExecutionMode::OutputPrimitivesEXT,
289            Value,
290        ),
291        (
292            "pixel_interlock_ordered_ext",
293            PixelInterlockOrderedEXT,
294            None,
295        ),
296        (
297            "pixel_interlock_unordered_ext",
298            PixelInterlockUnorderedEXT,
299            None,
300        ),
301        (
302            "sample_interlock_ordered_ext",
303            SampleInterlockOrderedEXT,
304            None,
305        ),
306        (
307            "sample_interlock_unordered_ext",
308            SampleInterlockUnorderedEXT,
309            None,
310        ),
311        (
312            "shading_rate_interlock_ordered_ext",
313            ShadingRateInterlockOrderedEXT,
314            None,
315        ),
316        (
317            "shading_rate_interlock_unordered_ext",
318            ShadingRateInterlockUnorderedEXT,
319            None,
320        ),
321        // Reserved
322        /*("max_workgroup_size_intel_x", MaxWorkgroupSizeINTEL, X),
323        ("max_workgroup_size_intel_y", MaxWorkgroupSizeINTEL, Y),
324        ("max_workgroup_size_intel_z", MaxWorkgroupSizeINTEL, Z),
325        ("max_work_dim_intel", MaxWorkDimINTEL, Value),
326        ("no_global_offset_intel", NoGlobalOffsetINTEL, None),
327        ("num_simd_workitems_intel", NumSIMDWorkitemsINTEL, Value),*/
328    ]
329};
330
331impl Symbols {
332    fn new() -> Self {
333        let builtins = BUILTINS
334            .iter()
335            .map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
336        let storage_classes = STORAGE_CLASSES
337            .iter()
338            .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
339        let execution_models = EXECUTION_MODELS
340            .iter()
341            .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into())));
342        let custom_attributes = [
343            (
344                "sampler",
345                SpirvAttribute::IntrinsicType(IntrinsicType::Sampler),
346            ),
347            (
348                "generic_image_type",
349                SpirvAttribute::IntrinsicType(IntrinsicType::GenericImageType),
350            ),
351            (
352                "acceleration_structure",
353                SpirvAttribute::IntrinsicType(IntrinsicType::AccelerationStructureKhr),
354            ),
355            (
356                "ray_query",
357                SpirvAttribute::IntrinsicType(IntrinsicType::RayQueryKhr),
358            ),
359            ("block", SpirvAttribute::Block),
360            ("flat", SpirvAttribute::Flat),
361            ("invariant", SpirvAttribute::Invariant),
362            ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
363            (
364                "sampled_image",
365                SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
366            ),
367            (
368                "runtime_array",
369                SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
370            ),
371            (
372                "typed_buffer",
373                SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
374            ),
375            (
376                "matrix",
377                SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
378            ),
379            (
380                "vector",
381                SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
382            ),
383            ("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
384            (
385                "buffer_store_intrinsic",
386                SpirvAttribute::BufferStoreIntrinsic,
387            ),
388        ]
389        .iter()
390        .cloned();
391        let attributes_iter = builtins
392            .chain(storage_classes)
393            .chain(execution_models)
394            .chain(custom_attributes)
395            .map(|(a, b)| (Symbol::intern(a), b));
396        let mut attributes = FxHashMap::default();
397        for (a, b) in attributes_iter {
398            let old = attributes.insert(a, b);
399            // `.collect()` into a FxHashMap does not error on duplicates, so manually write out the
400            // loop here to error on duplicates.
401            assert!(old.is_none());
402        }
403        let mut execution_modes = FxHashMap::default();
404        for &(key, mode, dim) in EXECUTION_MODES {
405            let old = execution_modes.insert(Symbol::intern(key), (mode, dim));
406            assert!(old.is_none());
407        }
408
409        let mut libm_intrinsics = FxHashMap::default();
410        for &(a, b) in libm_intrinsics::LIBM_TABLE {
411            let old = libm_intrinsics.insert(Symbol::intern(a), b);
412            assert!(old.is_none());
413        }
414
415        let mut num_traits_intrinsics = FxHashMap::default();
416        for &(a, b) in libm_intrinsics::NUM_TRAITS_TABLE {
417            let old = num_traits_intrinsics.insert(Symbol::intern(a), b);
418            assert!(old.is_none());
419        }
420
421        Self {
422            discriminant: Symbol::intern("discriminant"),
423            rust_gpu: Symbol::intern("rust_gpu"),
424            spirv_attr_with_version: Symbol::intern(&spirv_attr_with_version()),
425            vector: Symbol::intern("vector"),
426            v1: Symbol::intern("v1"),
427            libm: Symbol::intern("libm"),
428            num_traits: Symbol::intern("num_traits"),
429            entry_point_name: Symbol::intern("entry_point_name"),
430            spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
431
432            descriptor_set: Symbol::intern("descriptor_set"),
433            binding: Symbol::intern("binding"),
434            location: Symbol::intern("location"),
435            input_attachment_index: Symbol::intern("input_attachment_index"),
436
437            spec_constant: Symbol::intern("spec_constant"),
438            id: Symbol::intern("id"),
439            default: Symbol::intern("default"),
440
441            attributes,
442            execution_modes,
443            libm_intrinsics,
444            num_traits_intrinsics,
445        }
446    }
447
448    /// Obtain an `Rc` handle to the current thread's `Symbols` instance, which
449    /// will be shared between all `Symbols::get()` calls on the same thread.
450    ///
451    /// While this is relatively cheap, prefer caching it in e.g. `CodegenCx`,
452    /// rather than calling `get()` every time a field of `Symbols` is needed.
453    pub fn get() -> Rc<Self> {
454        thread_local!(static SYMBOLS: Rc<Symbols> = Rc::new(Symbols::new()));
455        SYMBOLS.with(Rc::clone)
456    }
457}