1use crate::attr::{IntrinsicType, SpirvAttribute};
2use crate::builder::libm_intrinsics;
3use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
4use rustc_data_structures::fx::FxHashMap;
5use rustc_span::symbol::Symbol;
6use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
7use std::rc::Rc;
8
9pub struct Symbols {
15 pub discriminant: Symbol,
16 pub rust_gpu: Symbol,
17 pub spirv_attr_with_version: Symbol,
18 pub vector: Symbol,
19 pub v1: Symbol,
20 pub libm: Symbol,
21 pub num_traits: Symbol,
22 pub entry_point_name: Symbol,
23 pub spv_khr_vulkan_memory_model: Symbol,
24
25 pub descriptor_set: Symbol,
26 pub binding: Symbol,
27 pub location: Symbol,
28 pub input_attachment_index: Symbol,
29
30 pub spec_constant: Symbol,
31 pub id: Symbol,
32 pub default: Symbol,
33
34 pub attributes: FxHashMap<Symbol, SpirvAttribute>,
35 pub execution_modes: FxHashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
36 pub libm_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
37 pub num_traits_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
38}
39
40const BUILTINS: &[(&str, BuiltIn)] = {
41 use BuiltIn::*;
42 &[
43 ("position", Position),
44 ("point_size", PointSize),
45 ("clip_distance", ClipDistance),
46 ("cull_distance", CullDistance),
47 ("vertex_id", VertexId),
48 ("instance_id", InstanceId),
49 ("primitive_id", PrimitiveId),
50 ("invocation_id", InvocationId),
51 ("layer", Layer),
52 ("viewport_index", ViewportIndex),
53 ("tess_level_outer", TessLevelOuter),
54 ("tess_level_inner", TessLevelInner),
55 ("tess_coord", TessCoord),
56 ("patch_vertices", PatchVertices),
57 ("frag_coord", FragCoord),
58 ("point_coord", PointCoord),
59 ("front_facing", FrontFacing),
60 ("sample_id", SampleId),
61 ("sample_position", SamplePosition),
62 ("sample_mask", SampleMask),
63 ("frag_depth", FragDepth),
64 ("helper_invocation", HelperInvocation),
65 ("num_workgroups", NumWorkgroups),
66 ("workgroup_id", WorkgroupId),
68 ("local_invocation_id", LocalInvocationId),
69 ("global_invocation_id", GlobalInvocationId),
70 ("local_invocation_index", LocalInvocationIndex),
71 ("subgroup_size", SubgroupSize),
77 ("num_subgroups", NumSubgroups),
79 ("subgroup_id", SubgroupId),
81 ("subgroup_local_invocation_id", SubgroupLocalInvocationId),
82 ("vertex_index", VertexIndex),
83 ("instance_index", InstanceIndex),
84 ("subgroup_eq_mask", SubgroupEqMask),
85 ("subgroup_ge_mask", SubgroupGeMask),
86 ("subgroup_gt_mask", SubgroupGtMask),
87 ("subgroup_le_mask", SubgroupLeMask),
88 ("subgroup_lt_mask", SubgroupLtMask),
89 ("base_vertex", BaseVertex),
90 ("base_instance", BaseInstance),
91 ("draw_index", DrawIndex),
92 ("device_index", DeviceIndex),
93 ("view_index", ViewIndex),
94 ("bary_coord_no_persp_amd", BaryCoordNoPerspAMD),
95 (
96 "bary_coord_no_persp_centroid_amd",
97 BaryCoordNoPerspCentroidAMD,
98 ),
99 ("bary_coord_no_persp_sample_amd", BaryCoordNoPerspSampleAMD),
100 ("bary_coord_smooth_amd", BaryCoordSmoothAMD),
101 ("bary_coord_smooth_centroid_amd", BaryCoordSmoothCentroidAMD),
102 ("bary_coord_smooth_sample_amd", BaryCoordSmoothSampleAMD),
103 ("bary_coord_pull_model_amd", BaryCoordPullModelAMD),
104 ("frag_stencil_ref_ext", FragStencilRefEXT),
105 ("viewport_mask_nv", ViewportMaskNV),
106 ("secondary_position_nv", SecondaryPositionNV),
107 ("secondary_viewport_mask_nv", SecondaryViewportMaskNV),
108 ("position_per_view_nv", PositionPerViewNV),
109 ("viewport_mask_per_view_nv", ViewportMaskPerViewNV),
110 ("fully_covered_ext", FullyCoveredEXT),
111 ("task_count_nv", TaskCountNV),
112 ("primitive_count_nv", PrimitiveCountNV),
113 ("primitive_indices_nv", PrimitiveIndicesNV),
114 ("clip_distance_per_view_nv", ClipDistancePerViewNV),
115 ("cull_distance_per_view_nv", CullDistancePerViewNV),
116 ("layer_per_view_nv", LayerPerViewNV),
117 ("mesh_view_count_nv", MeshViewCountNV),
118 ("mesh_view_indices_nv", MeshViewIndicesNV),
119 ("bary_coord_nv", BuiltIn::BaryCoordNV),
120 ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
121 ("bary_coord", BaryCoordKHR),
122 ("bary_coord_no_persp", BaryCoordNoPerspKHR),
123 ("primitive_point_indices_ext", PrimitivePointIndicesEXT),
124 ("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
125 (
126 "primitive_triangle_indices_ext",
127 PrimitiveTriangleIndicesEXT,
128 ),
129 ("cull_primitive_ext", CullPrimitiveEXT),
130 ("frag_size_ext", FragSizeEXT),
131 ("frag_invocation_count_ext", FragInvocationCountEXT),
132 ("launch_id", BuiltIn::LaunchIdKHR),
133 ("launch_size", BuiltIn::LaunchSizeKHR),
134 ("instance_custom_index", BuiltIn::InstanceCustomIndexKHR),
135 ("ray_geometry_index", BuiltIn::RayGeometryIndexKHR),
136 ("world_ray_origin", BuiltIn::WorldRayOriginKHR),
137 ("world_ray_direction", BuiltIn::WorldRayDirectionKHR),
138 ("object_ray_origin", BuiltIn::ObjectRayOriginKHR),
139 ("object_ray_direction", BuiltIn::ObjectRayDirectionKHR),
140 ("ray_tmin", BuiltIn::RayTminKHR),
141 ("ray_tmax", BuiltIn::RayTmaxKHR),
142 ("object_to_world", BuiltIn::ObjectToWorldKHR),
143 ("world_to_object", BuiltIn::WorldToObjectKHR),
144 (
145 "hit_triangle_vertex_positions",
146 BuiltIn::HitTriangleVertexPositionsKHR,
147 ),
148 ("hit_kind", BuiltIn::HitKindKHR),
149 ("incoming_ray_flags", BuiltIn::IncomingRayFlagsKHR),
150 ("warps_per_sm_nv", WarpsPerSMNV),
151 ("sm_count_nv", SMCountNV),
152 ("warp_id_nv", WarpIDNV),
153 ("SMIDNV", SMIDNV),
154 ]
155};
156
157const STORAGE_CLASSES: &[(&str, StorageClass)] = {
158 use StorageClass::*;
159 &[
160 ("uniform_constant", UniformConstant),
161 ("input", Input),
162 ("uniform", Uniform),
163 ("output", Output),
164 ("workgroup", Workgroup),
165 ("cross_workgroup", CrossWorkgroup),
166 ("private", Private),
167 ("function", Function),
168 ("generic", Generic),
169 ("push_constant", PushConstant),
170 ("atomic_counter", AtomicCounter),
171 ("image", Image),
172 ("storage_buffer", StorageBuffer),
173 ("callable_data", StorageClass::CallableDataKHR),
174 (
175 "incoming_callable_data",
176 StorageClass::IncomingCallableDataKHR,
177 ),
178 ("ray_payload", StorageClass::RayPayloadKHR),
179 ("hit_attribute", StorageClass::HitAttributeKHR),
180 ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
181 ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
182 ("physical_storage_buffer", PhysicalStorageBuffer),
183 ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
184 ]
185};
186
187const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
188 use ExecutionModel::*;
189 &[
190 ("vertex", Vertex),
191 ("tessellation_control", TessellationControl),
192 ("tessellation_evaluation", TessellationEvaluation),
193 ("geometry", Geometry),
194 ("fragment", Fragment),
195 ("compute", GLCompute),
196 ("task_nv", TaskNV),
197 ("mesh_nv", MeshNV),
198 ("task_ext", TaskEXT),
199 ("mesh_ext", MeshEXT),
200 ("ray_generation", ExecutionModel::RayGenerationKHR),
201 ("intersection", ExecutionModel::IntersectionKHR),
202 ("any_hit", ExecutionModel::AnyHitKHR),
203 ("closest_hit", ExecutionModel::ClosestHitKHR),
204 ("miss", ExecutionModel::MissKHR),
205 ("callable", ExecutionModel::CallableKHR),
206 ]
207};
208
/// Describes the extra operand carried by an execution-mode attribute in
/// `EXECUTION_MODES`.
///
/// NOTE(review): variant order is deliberately left as-is, because reordering would change
/// the implicit discriminants if anything casts this enum with `as`.
#[derive(Copy, Clone, Debug)]
pub enum ExecutionModeExtraDim {
    /// A flag-style mode with no extra operand (e.g. `origin_upper_left`).
    None,
    /// A mode carrying a single value (e.g. `invocations`, `output_vertices`).
    Value,
    /// The first component of a three-component mode (`local_size_hint_x`,
    /// `local_size_id_x`).
    X,
    /// The second component of a three-component mode (`local_size_hint_y`,
    /// `local_size_id_y`).
    Y,
    /// The third component of a three-component mode (`local_size_hint_z`,
    /// `local_size_id_z`).
    Z,
    /// A mode whose operands are given together as one tuple (used by `threads`, which
    /// maps to `LocalSize`).
    Tuple,
}
218
219const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
220 use ExecutionMode::*;
221 use ExecutionModeExtraDim::*;
222 &[
223 ("invocations", Invocations, Value),
224 ("spacing_equal", SpacingEqual, None),
225 ("spacing_fraction_even", SpacingFractionalEven, None),
226 ("spacing_fraction_odd", SpacingFractionalOdd, None),
227 ("vertex_order_cw", VertexOrderCw, None),
228 ("vertex_order_ccw", VertexOrderCcw, None),
229 ("pixel_center_integer", PixelCenterInteger, None),
230 ("origin_upper_left", OriginUpperLeft, None),
231 ("origin_lower_left", OriginLowerLeft, None),
232 ("early_fragment_tests", EarlyFragmentTests, None),
233 ("point_mode", PointMode, None),
234 ("xfb", Xfb, None),
235 ("depth_replacing", DepthReplacing, None),
236 ("depth_greater", DepthGreater, None),
237 ("depth_less", DepthLess, None),
238 ("depth_unchanged", DepthUnchanged, None),
239 ("threads", LocalSize, Tuple),
240 ("local_size_hint_x", LocalSizeHint, X),
241 ("local_size_hint_y", LocalSizeHint, Y),
242 ("local_size_hint_z", LocalSizeHint, Z),
243 ("input_points", InputPoints, None),
244 ("input_lines", InputLines, None),
245 ("input_lines_adjacency", InputLinesAdjacency, None),
246 ("triangles", Triangles, None),
247 ("input_triangles_adjacency", InputTrianglesAdjacency, None),
248 ("quads", Quads, None),
249 ("isolines", Isolines, None),
250 ("output_vertices", OutputVertices, Value),
251 ("output_points", OutputPoints, None),
252 ("output_line_strip", OutputLineStrip, None),
253 ("output_triangle_strip", OutputTriangleStrip, None),
254 ("vec_type_hint", VecTypeHint, Value),
255 ("contraction_off", ContractionOff, None),
256 ("initializer", Initializer, None),
257 ("finalizer", Finalizer, None),
258 ("subgroup_size", SubgroupSize, Value),
259 ("subgroups_per_workgroup", SubgroupsPerWorkgroup, Value),
260 ("subgroups_per_workgroup_id", SubgroupsPerWorkgroupId, Value),
261 ("local_size_id_x", LocalSizeId, X),
262 ("local_size_id_y", LocalSizeId, Y),
263 ("local_size_id_z", LocalSizeId, Z),
264 ("local_size_hint_id", LocalSizeHintId, Value),
265 ("post_depth_coverage", PostDepthCoverage, None),
266 ("denorm_preserve", DenormPreserve, None),
267 ("denorm_flush_to_zero", DenormFlushToZero, Value),
268 (
269 "signed_zero_inf_nan_preserve",
270 SignedZeroInfNanPreserve,
271 Value,
272 ),
273 ("rounding_mode_rte", RoundingModeRTE, Value),
274 ("rounding_mode_rtz", RoundingModeRTZ, Value),
275 ("stencil_ref_replacing_ext", StencilRefReplacingEXT, None),
276 ("output_lines_nv", OutputLinesNV, None),
277 ("output_primitives_nv", OutputPrimitivesNV, Value),
278 ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
279 ("output_triangles_nv", OutputTrianglesNV, None),
280 ("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
281 (
282 "output_triangles_ext",
283 ExecutionMode::OutputTrianglesEXT,
284 None,
285 ),
286 (
287 "output_primitives_ext",
288 ExecutionMode::OutputPrimitivesEXT,
289 Value,
290 ),
291 (
292 "pixel_interlock_ordered_ext",
293 PixelInterlockOrderedEXT,
294 None,
295 ),
296 (
297 "pixel_interlock_unordered_ext",
298 PixelInterlockUnorderedEXT,
299 None,
300 ),
301 (
302 "sample_interlock_ordered_ext",
303 SampleInterlockOrderedEXT,
304 None,
305 ),
306 (
307 "sample_interlock_unordered_ext",
308 SampleInterlockUnorderedEXT,
309 None,
310 ),
311 (
312 "shading_rate_interlock_ordered_ext",
313 ShadingRateInterlockOrderedEXT,
314 None,
315 ),
316 (
317 "shading_rate_interlock_unordered_ext",
318 ShadingRateInterlockUnorderedEXT,
319 None,
320 ),
321 ]
329};
330
331impl Symbols {
332 fn new() -> Self {
333 let builtins = BUILTINS
334 .iter()
335 .map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
336 let storage_classes = STORAGE_CLASSES
337 .iter()
338 .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
339 let execution_models = EXECUTION_MODELS
340 .iter()
341 .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into())));
342 let custom_attributes = [
343 (
344 "sampler",
345 SpirvAttribute::IntrinsicType(IntrinsicType::Sampler),
346 ),
347 (
348 "generic_image_type",
349 SpirvAttribute::IntrinsicType(IntrinsicType::GenericImageType),
350 ),
351 (
352 "acceleration_structure",
353 SpirvAttribute::IntrinsicType(IntrinsicType::AccelerationStructureKhr),
354 ),
355 (
356 "ray_query",
357 SpirvAttribute::IntrinsicType(IntrinsicType::RayQueryKhr),
358 ),
359 ("block", SpirvAttribute::Block),
360 ("flat", SpirvAttribute::Flat),
361 ("invariant", SpirvAttribute::Invariant),
362 ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
363 (
364 "sampled_image",
365 SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
366 ),
367 (
368 "runtime_array",
369 SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
370 ),
371 (
372 "typed_buffer",
373 SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
374 ),
375 (
376 "matrix",
377 SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
378 ),
379 (
380 "vector",
381 SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
382 ),
383 ("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
384 (
385 "buffer_store_intrinsic",
386 SpirvAttribute::BufferStoreIntrinsic,
387 ),
388 ]
389 .iter()
390 .cloned();
391 let attributes_iter = builtins
392 .chain(storage_classes)
393 .chain(execution_models)
394 .chain(custom_attributes)
395 .map(|(a, b)| (Symbol::intern(a), b));
396 let mut attributes = FxHashMap::default();
397 for (a, b) in attributes_iter {
398 let old = attributes.insert(a, b);
399 assert!(old.is_none());
402 }
403 let mut execution_modes = FxHashMap::default();
404 for &(key, mode, dim) in EXECUTION_MODES {
405 let old = execution_modes.insert(Symbol::intern(key), (mode, dim));
406 assert!(old.is_none());
407 }
408
409 let mut libm_intrinsics = FxHashMap::default();
410 for &(a, b) in libm_intrinsics::LIBM_TABLE {
411 let old = libm_intrinsics.insert(Symbol::intern(a), b);
412 assert!(old.is_none());
413 }
414
415 let mut num_traits_intrinsics = FxHashMap::default();
416 for &(a, b) in libm_intrinsics::NUM_TRAITS_TABLE {
417 let old = num_traits_intrinsics.insert(Symbol::intern(a), b);
418 assert!(old.is_none());
419 }
420
421 Self {
422 discriminant: Symbol::intern("discriminant"),
423 rust_gpu: Symbol::intern("rust_gpu"),
424 spirv_attr_with_version: Symbol::intern(&spirv_attr_with_version()),
425 vector: Symbol::intern("vector"),
426 v1: Symbol::intern("v1"),
427 libm: Symbol::intern("libm"),
428 num_traits: Symbol::intern("num_traits"),
429 entry_point_name: Symbol::intern("entry_point_name"),
430 spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
431
432 descriptor_set: Symbol::intern("descriptor_set"),
433 binding: Symbol::intern("binding"),
434 location: Symbol::intern("location"),
435 input_attachment_index: Symbol::intern("input_attachment_index"),
436
437 spec_constant: Symbol::intern("spec_constant"),
438 id: Symbol::intern("id"),
439 default: Symbol::intern("default"),
440
441 attributes,
442 execution_modes,
443 libm_intrinsics,
444 num_traits_intrinsics,
445 }
446 }
447
448 pub fn get() -> Rc<Self> {
454 thread_local!(static SYMBOLS: Rc<Symbols> = Rc::new(Symbols::new()));
455 SYMBOLS.with(Rc::clone)
456 }
457}