1use crate::attr::{IntrinsicType, SpirvAttribute};
2use crate::builder::libm_intrinsics;
3use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
4use rustc_data_structures::fx::FxHashMap;
5use rustc_span::symbol::Symbol;
6use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
7use std::rc::Rc;
8
/// Pre-interned [`Symbol`]s plus the name lookup tables assembled by
/// `Symbols::new`.
pub struct Symbols {
    // Unless noted otherwise, each `Symbol` field below is interned from the
    // string that matches the field's own name (see `Symbols::new`).
    pub discriminant: Symbol,
    pub rust_gpu: Symbol,
    /// Interned from the string returned by `spirv_attr_with_version()`.
    pub spirv_attr_with_version: Symbol,
    pub vector: Symbol,
    pub v1: Symbol,
    pub libm: Symbol,
    pub entry_point_name: Symbol,
    /// Interned from `"SPV_KHR_vulkan_memory_model"` (note the mixed case).
    pub spv_khr_vulkan_memory_model: Symbol,

    pub descriptor_set: Symbol,
    pub binding: Symbol,
    pub input_attachment_index: Symbol,

    pub spec_constant: Symbol,
    pub id: Symbol,
    pub default: Symbol,

    /// Attribute name -> [`SpirvAttribute`], filled from `BUILTINS`,
    /// `STORAGE_CLASSES`, `EXECUTION_MODELS` and a fixed list of custom
    /// attributes.
    pub attributes: FxHashMap<Symbol, SpirvAttribute>,
    /// Execution-mode name -> (mode, extra-argument tag), filled from
    /// `EXECUTION_MODES`.
    pub execution_modes: FxHashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
    /// Name -> libm intrinsic, filled from `libm_intrinsics::TABLE`.
    pub libm_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
}
36
/// Attribute names paired with the SPIR-V [`BuiltIn`] each one denotes.
///
/// `Symbols::new` wraps every entry in `SpirvAttribute::Builtin` and registers
/// it in `Symbols::attributes`, asserting that no name is registered twice.
const BUILTINS: &[(&str, BuiltIn)] = {
    use BuiltIn::*;
    // NOTE(review): some entries below spell out the `BuiltIn::` prefix even
    // though the glob import is in scope. Presumably those names are not
    // brought in by the glob (e.g. aliases) -- confirm against rspirv before
    // removing any prefix.
    &[
        ("position", Position),
        ("point_size", PointSize),
        ("clip_distance", ClipDistance),
        ("cull_distance", CullDistance),
        ("vertex_id", VertexId),
        ("instance_id", InstanceId),
        ("primitive_id", PrimitiveId),
        ("invocation_id", InvocationId),
        ("layer", Layer),
        ("viewport_index", ViewportIndex),
        ("tess_level_outer", TessLevelOuter),
        ("tess_level_inner", TessLevelInner),
        ("tess_coord", TessCoord),
        ("patch_vertices", PatchVertices),
        ("frag_coord", FragCoord),
        ("point_coord", PointCoord),
        ("front_facing", FrontFacing),
        ("sample_id", SampleId),
        ("sample_position", SamplePosition),
        ("sample_mask", SampleMask),
        ("frag_depth", FragDepth),
        ("helper_invocation", HelperInvocation),
        ("num_workgroups", NumWorkgroups),
        ("workgroup_id", WorkgroupId),
        ("local_invocation_id", LocalInvocationId),
        ("global_invocation_id", GlobalInvocationId),
        ("local_invocation_index", LocalInvocationIndex),
        ("subgroup_size", SubgroupSize),
        ("num_subgroups", NumSubgroups),
        ("subgroup_id", SubgroupId),
        ("subgroup_local_invocation_id", SubgroupLocalInvocationId),
        ("vertex_index", VertexIndex),
        ("instance_index", InstanceIndex),
        ("subgroup_eq_mask", SubgroupEqMask),
        ("subgroup_ge_mask", SubgroupGeMask),
        ("subgroup_gt_mask", SubgroupGtMask),
        ("subgroup_le_mask", SubgroupLeMask),
        ("subgroup_lt_mask", SubgroupLtMask),
        ("base_vertex", BaseVertex),
        ("base_instance", BaseInstance),
        ("draw_index", DrawIndex),
        ("device_index", DeviceIndex),
        ("view_index", ViewIndex),
        ("bary_coord_no_persp_amd", BaryCoordNoPerspAMD),
        (
            "bary_coord_no_persp_centroid_amd",
            BaryCoordNoPerspCentroidAMD,
        ),
        ("bary_coord_no_persp_sample_amd", BaryCoordNoPerspSampleAMD),
        ("bary_coord_smooth_amd", BaryCoordSmoothAMD),
        ("bary_coord_smooth_centroid_amd", BaryCoordSmoothCentroidAMD),
        ("bary_coord_smooth_sample_amd", BaryCoordSmoothSampleAMD),
        ("bary_coord_pull_model_amd", BaryCoordPullModelAMD),
        ("frag_stencil_ref_ext", FragStencilRefEXT),
        ("viewport_mask_nv", ViewportMaskNV),
        ("secondary_position_nv", SecondaryPositionNV),
        ("secondary_viewport_mask_nv", SecondaryViewportMaskNV),
        ("position_per_view_nv", PositionPerViewNV),
        ("viewport_mask_per_view_nv", ViewportMaskPerViewNV),
        ("fully_covered_ext", FullyCoveredEXT),
        ("task_count_nv", TaskCountNV),
        ("primitive_count_nv", PrimitiveCountNV),
        ("primitive_indices_nv", PrimitiveIndicesNV),
        ("clip_distance_per_view_nv", ClipDistancePerViewNV),
        ("cull_distance_per_view_nv", CullDistancePerViewNV),
        ("layer_per_view_nv", LayerPerViewNV),
        ("mesh_view_count_nv", MeshViewCountNV),
        ("mesh_view_indices_nv", MeshViewIndicesNV),
        ("bary_coord_nv", BuiltIn::BaryCoordNV),
        ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
        ("bary_coord", BaryCoordKHR),
        ("bary_coord_no_persp", BaryCoordNoPerspKHR),
        ("primitive_point_indices_ext", PrimitivePointIndicesEXT),
        ("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
        (
            "primitive_triangle_indices_ext",
            PrimitiveTriangleIndicesEXT,
        ),
        ("cull_primitive_ext", CullPrimitiveEXT),
        ("frag_size_ext", FragSizeEXT),
        ("frag_invocation_count_ext", FragInvocationCountEXT),
        // Ray-tracing built-ins: the names carry no vendor suffix while the
        // values use the `KHR` spellings.
        ("launch_id", BuiltIn::LaunchIdKHR),
        ("launch_size", BuiltIn::LaunchSizeKHR),
        ("instance_custom_index", BuiltIn::InstanceCustomIndexKHR),
        ("ray_geometry_index", BuiltIn::RayGeometryIndexKHR),
        ("world_ray_origin", BuiltIn::WorldRayOriginKHR),
        ("world_ray_direction", BuiltIn::WorldRayDirectionKHR),
        ("object_ray_origin", BuiltIn::ObjectRayOriginKHR),
        ("object_ray_direction", BuiltIn::ObjectRayDirectionKHR),
        ("ray_tmin", BuiltIn::RayTminKHR),
        ("ray_tmax", BuiltIn::RayTmaxKHR),
        ("object_to_world", BuiltIn::ObjectToWorldKHR),
        ("world_to_object", BuiltIn::WorldToObjectKHR),
        (
            "hit_triangle_vertex_positions",
            BuiltIn::HitTriangleVertexPositionsKHR,
        ),
        ("hit_kind", BuiltIn::HitKindKHR),
        ("incoming_ray_flags", BuiltIn::IncomingRayFlagsKHR),
        ("warps_per_sm_nv", WarpsPerSMNV),
        ("sm_count_nv", SMCountNV),
        ("warp_id_nv", WarpIDNV),
        // NOTE(review): the only key in this table that is not snake_case.
        // Renaming it would change the attribute name that gets registered,
        // so it is left untouched here.
        ("SMIDNV", SMIDNV),
    ]
};
153
/// Attribute names paired with the SPIR-V [`StorageClass`] each one denotes.
///
/// `Symbols::new` wraps every entry in `SpirvAttribute::StorageClass` and
/// registers it in `Symbols::attributes`.
const STORAGE_CLASSES: &[(&str, StorageClass)] = {
    use StorageClass::*;
    &[
        ("uniform_constant", UniformConstant),
        ("input", Input),
        ("uniform", Uniform),
        ("output", Output),
        ("workgroup", Workgroup),
        ("cross_workgroup", CrossWorkgroup),
        ("private", Private),
        ("function", Function),
        ("generic", Generic),
        ("push_constant", PushConstant),
        ("atomic_counter", AtomicCounter),
        ("image", Image),
        ("storage_buffer", StorageBuffer),
        // Ray-tracing storage classes: unsuffixed names, `KHR` values.
        // NOTE(review): the explicit `StorageClass::` prefix is presumably
        // needed because the glob import does not cover these names --
        // confirm against rspirv before removing it.
        ("callable_data", StorageClass::CallableDataKHR),
        (
            "incoming_callable_data",
            StorageClass::IncomingCallableDataKHR,
        ),
        ("ray_payload", StorageClass::RayPayloadKHR),
        ("hit_attribute", StorageClass::HitAttributeKHR),
        ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
        ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
        ("physical_storage_buffer", PhysicalStorageBuffer),
        ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
    ]
};
183
/// Entry-point attribute names paired with the SPIR-V [`ExecutionModel`] each
/// one selects.
///
/// `Symbols::new` converts every model with `.into()`, wraps it in
/// `SpirvAttribute::Entry` and registers it in `Symbols::attributes`.
const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
    use ExecutionModel::*;
    &[
        ("vertex", Vertex),
        ("tessellation_control", TessellationControl),
        ("tessellation_evaluation", TessellationEvaluation),
        ("geometry", Geometry),
        ("fragment", Fragment),
        // The attribute is spelled `compute`; the model is `GLCompute`.
        ("compute", GLCompute),
        ("task_nv", TaskNV),
        ("mesh_nv", MeshNV),
        ("task_ext", TaskEXT),
        ("mesh_ext", MeshEXT),
        // Ray-tracing stages: unsuffixed names, `KHR` values.
        // NOTE(review): the explicit `ExecutionModel::` prefix is presumably
        // needed because the glob import does not cover these names --
        // confirm against rspirv before removing it.
        ("ray_generation", ExecutionModel::RayGenerationKHR),
        ("intersection", ExecutionModel::IntersectionKHR),
        ("any_hit", ExecutionModel::AnyHitKHR),
        ("closest_hit", ExecutionModel::ClosestHitKHR),
        ("miss", ExecutionModel::MissKHR),
        ("callable", ExecutionModel::CallableKHR),
    ]
};
205
/// Tag stored next to each execution mode in `EXECUTION_MODES` and in
/// `Symbols::execution_modes`.
///
/// NOTE(review): the code that interprets these tags lives outside this view.
/// The per-variant notes only describe how the table in this file uses them;
/// the implied meaning (what extra argument the attribute expects) should be
/// confirmed against the attribute parser.
#[derive(Copy, Clone, Debug)]
pub enum ExecutionModeExtraDim {
    /// Used for names such as `origin_upper_left` or `xfb`.
    None,
    /// Used for names such as `invocations` or `output_vertices`.
    Value,
    /// Used for the `*_x` names (`local_size_hint_x`, `local_size_id_x`).
    X,
    /// Used for the `*_y` names (`local_size_hint_y`, `local_size_id_y`).
    Y,
    /// Used for the `*_z` names (`local_size_hint_z`, `local_size_id_z`).
    Z,
    /// Used only for `threads` (`LocalSize`).
    Tuple,
}
215
216const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
217 use ExecutionMode::*;
218 use ExecutionModeExtraDim::*;
219 &[
220 ("invocations", Invocations, Value),
221 ("spacing_equal", SpacingEqual, None),
222 ("spacing_fraction_even", SpacingFractionalEven, None),
223 ("spacing_fraction_odd", SpacingFractionalOdd, None),
224 ("vertex_order_cw", VertexOrderCw, None),
225 ("vertex_order_ccw", VertexOrderCcw, None),
226 ("pixel_center_integer", PixelCenterInteger, None),
227 ("origin_upper_left", OriginUpperLeft, None),
228 ("origin_lower_left", OriginLowerLeft, None),
229 ("early_fragment_tests", EarlyFragmentTests, None),
230 ("point_mode", PointMode, None),
231 ("xfb", Xfb, None),
232 ("depth_replacing", DepthReplacing, None),
233 ("depth_greater", DepthGreater, None),
234 ("depth_less", DepthLess, None),
235 ("depth_unchanged", DepthUnchanged, None),
236 ("threads", LocalSize, Tuple),
237 ("local_size_hint_x", LocalSizeHint, X),
238 ("local_size_hint_y", LocalSizeHint, Y),
239 ("local_size_hint_z", LocalSizeHint, Z),
240 ("input_points", InputPoints, None),
241 ("input_lines", InputLines, None),
242 ("input_lines_adjacency", InputLinesAdjacency, None),
243 ("triangles", Triangles, None),
244 ("input_triangles_adjacency", InputTrianglesAdjacency, None),
245 ("quads", Quads, None),
246 ("isolines", Isolines, None),
247 ("output_vertices", OutputVertices, Value),
248 ("output_points", OutputPoints, None),
249 ("output_line_strip", OutputLineStrip, None),
250 ("output_triangle_strip", OutputTriangleStrip, None),
251 ("vec_type_hint", VecTypeHint, Value),
252 ("contraction_off", ContractionOff, None),
253 ("initializer", Initializer, None),
254 ("finalizer", Finalizer, None),
255 ("subgroup_size", SubgroupSize, Value),
256 ("subgroups_per_workgroup", SubgroupsPerWorkgroup, Value),
257 ("subgroups_per_workgroup_id", SubgroupsPerWorkgroupId, Value),
258 ("local_size_id_x", LocalSizeId, X),
259 ("local_size_id_y", LocalSizeId, Y),
260 ("local_size_id_z", LocalSizeId, Z),
261 ("local_size_hint_id", LocalSizeHintId, Value),
262 ("post_depth_coverage", PostDepthCoverage, None),
263 ("denorm_preserve", DenormPreserve, None),
264 ("denorm_flush_to_zero", DenormFlushToZero, Value),
265 (
266 "signed_zero_inf_nan_preserve",
267 SignedZeroInfNanPreserve,
268 Value,
269 ),
270 ("rounding_mode_rte", RoundingModeRTE, Value),
271 ("rounding_mode_rtz", RoundingModeRTZ, Value),
272 ("stencil_ref_replacing_ext", StencilRefReplacingEXT, None),
273 ("output_lines_nv", OutputLinesNV, None),
274 ("output_primitives_nv", OutputPrimitivesNV, Value),
275 ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
276 ("output_triangles_nv", OutputTrianglesNV, None),
277 ("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
278 (
279 "output_triangles_ext",
280 ExecutionMode::OutputTrianglesEXT,
281 None,
282 ),
283 (
284 "output_primitives_ext",
285 ExecutionMode::OutputPrimitivesEXT,
286 Value,
287 ),
288 (
289 "pixel_interlock_ordered_ext",
290 PixelInterlockOrderedEXT,
291 None,
292 ),
293 (
294 "pixel_interlock_unordered_ext",
295 PixelInterlockUnorderedEXT,
296 None,
297 ),
298 (
299 "sample_interlock_ordered_ext",
300 SampleInterlockOrderedEXT,
301 None,
302 ),
303 (
304 "sample_interlock_unordered_ext",
305 SampleInterlockUnorderedEXT,
306 None,
307 ),
308 (
309 "shading_rate_interlock_ordered_ext",
310 ShadingRateInterlockOrderedEXT,
311 None,
312 ),
313 (
314 "shading_rate_interlock_unordered_ext",
315 ShadingRateInterlockUnorderedEXT,
316 None,
317 ),
318 ]
326};
327
328impl Symbols {
329 fn new() -> Self {
330 let builtins = BUILTINS
331 .iter()
332 .map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
333 let storage_classes = STORAGE_CLASSES
334 .iter()
335 .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
336 let execution_models = EXECUTION_MODELS
337 .iter()
338 .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into())));
339 let custom_attributes = [
340 (
341 "sampler",
342 SpirvAttribute::IntrinsicType(IntrinsicType::Sampler),
343 ),
344 (
345 "generic_image_type",
346 SpirvAttribute::IntrinsicType(IntrinsicType::GenericImageType),
347 ),
348 (
349 "acceleration_structure",
350 SpirvAttribute::IntrinsicType(IntrinsicType::AccelerationStructureKhr),
351 ),
352 (
353 "ray_query",
354 SpirvAttribute::IntrinsicType(IntrinsicType::RayQueryKhr),
355 ),
356 ("block", SpirvAttribute::Block),
357 ("flat", SpirvAttribute::Flat),
358 ("invariant", SpirvAttribute::Invariant),
359 ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
360 (
361 "sampled_image",
362 SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
363 ),
364 (
365 "runtime_array",
366 SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
367 ),
368 (
369 "typed_buffer",
370 SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
371 ),
372 (
373 "matrix",
374 SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
375 ),
376 (
377 "vector",
378 SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
379 ),
380 ("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
381 (
382 "buffer_store_intrinsic",
383 SpirvAttribute::BufferStoreIntrinsic,
384 ),
385 ]
386 .iter()
387 .cloned();
388 let attributes_iter = builtins
389 .chain(storage_classes)
390 .chain(execution_models)
391 .chain(custom_attributes)
392 .map(|(a, b)| (Symbol::intern(a), b));
393 let mut attributes = FxHashMap::default();
394 for (a, b) in attributes_iter {
395 let old = attributes.insert(a, b);
396 assert!(old.is_none());
399 }
400 let mut execution_modes = FxHashMap::default();
401 for &(key, mode, dim) in EXECUTION_MODES {
402 let old = execution_modes.insert(Symbol::intern(key), (mode, dim));
403 assert!(old.is_none());
404 }
405
406 let mut libm_intrinsics = FxHashMap::default();
407 for &(a, b) in libm_intrinsics::TABLE {
408 let old = libm_intrinsics.insert(Symbol::intern(a), b);
409 assert!(old.is_none());
410 }
411 Self {
412 discriminant: Symbol::intern("discriminant"),
413 rust_gpu: Symbol::intern("rust_gpu"),
414 spirv_attr_with_version: Symbol::intern(&spirv_attr_with_version()),
415 vector: Symbol::intern("vector"),
416 v1: Symbol::intern("v1"),
417 libm: Symbol::intern("libm"),
418 entry_point_name: Symbol::intern("entry_point_name"),
419 spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
420
421 descriptor_set: Symbol::intern("descriptor_set"),
422 binding: Symbol::intern("binding"),
423 input_attachment_index: Symbol::intern("input_attachment_index"),
424
425 spec_constant: Symbol::intern("spec_constant"),
426 id: Symbol::intern("id"),
427 default: Symbol::intern("default"),
428
429 attributes,
430 execution_modes,
431 libm_intrinsics,
432 }
433 }
434
435 pub fn get() -> Rc<Self> {
441 thread_local!(static SYMBOLS: Rc<Symbols> = Rc::new(Symbols::new()));
442 SYMBOLS.with(Rc::clone)
443 }
444}