1use crate::attr::{Entry, ExecutionModeExtra, IntrinsicType, SpecConstant, SpirvAttribute};
2use crate::builder::libm_intrinsics;
3use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
4use rustc_ast::ast::{LitIntType, LitKind, MetaItemInner, MetaItemLit};
5use rustc_data_structures::fx::FxHashMap;
6use rustc_hir::Attribute;
7use rustc_span::Span;
8use rustc_span::symbol::{Ident, Symbol};
9use std::rc::Rc;
10
11pub struct Symbols {
17 pub discriminant: Symbol,
18 pub rust_gpu: Symbol,
19 pub spirv: Symbol,
20 pub libm: Symbol,
21 pub entry_point_name: Symbol,
22 pub spv_khr_vulkan_memory_model: Symbol,
23
24 descriptor_set: Symbol,
25 binding: Symbol,
26 input_attachment_index: Symbol,
27
28 spec_constant: Symbol,
29 id: Symbol,
30 default: Symbol,
31
32 attributes: FxHashMap<Symbol, SpirvAttribute>,
33 execution_modes: FxHashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
34 pub libm_intrinsics: FxHashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
35}
36
37const BUILTINS: &[(&str, BuiltIn)] = {
38 use BuiltIn::*;
39 &[
40 ("position", Position),
41 ("point_size", PointSize),
42 ("clip_distance", ClipDistance),
43 ("cull_distance", CullDistance),
44 ("vertex_id", VertexId),
45 ("instance_id", InstanceId),
46 ("primitive_id", PrimitiveId),
47 ("invocation_id", InvocationId),
48 ("layer", Layer),
49 ("viewport_index", ViewportIndex),
50 ("tess_level_outer", TessLevelOuter),
51 ("tess_level_inner", TessLevelInner),
52 ("tess_coord", TessCoord),
53 ("patch_vertices", PatchVertices),
54 ("frag_coord", FragCoord),
55 ("point_coord", PointCoord),
56 ("front_facing", FrontFacing),
57 ("sample_id", SampleId),
58 ("sample_position", SamplePosition),
59 ("sample_mask", SampleMask),
60 ("frag_depth", FragDepth),
61 ("helper_invocation", HelperInvocation),
62 ("num_workgroups", NumWorkgroups),
63 ("workgroup_id", WorkgroupId),
65 ("local_invocation_id", LocalInvocationId),
66 ("global_invocation_id", GlobalInvocationId),
67 ("local_invocation_index", LocalInvocationIndex),
68 ("subgroup_size", SubgroupSize),
74 ("num_subgroups", NumSubgroups),
76 ("subgroup_id", SubgroupId),
78 ("subgroup_local_invocation_id", SubgroupLocalInvocationId),
79 ("vertex_index", VertexIndex),
80 ("instance_index", InstanceIndex),
81 ("subgroup_eq_mask", SubgroupEqMask),
82 ("subgroup_ge_mask", SubgroupGeMask),
83 ("subgroup_gt_mask", SubgroupGtMask),
84 ("subgroup_le_mask", SubgroupLeMask),
85 ("subgroup_lt_mask", SubgroupLtMask),
86 ("base_vertex", BaseVertex),
87 ("base_instance", BaseInstance),
88 ("draw_index", DrawIndex),
89 ("device_index", DeviceIndex),
90 ("view_index", ViewIndex),
91 ("bary_coord_no_persp_amd", BaryCoordNoPerspAMD),
92 (
93 "bary_coord_no_persp_centroid_amd",
94 BaryCoordNoPerspCentroidAMD,
95 ),
96 ("bary_coord_no_persp_sample_amd", BaryCoordNoPerspSampleAMD),
97 ("bary_coord_smooth_amd", BaryCoordSmoothAMD),
98 ("bary_coord_smooth_centroid_amd", BaryCoordSmoothCentroidAMD),
99 ("bary_coord_smooth_sample_amd", BaryCoordSmoothSampleAMD),
100 ("bary_coord_pull_model_amd", BaryCoordPullModelAMD),
101 ("frag_stencil_ref_ext", FragStencilRefEXT),
102 ("viewport_mask_nv", ViewportMaskNV),
103 ("secondary_position_nv", SecondaryPositionNV),
104 ("secondary_viewport_mask_nv", SecondaryViewportMaskNV),
105 ("position_per_view_nv", PositionPerViewNV),
106 ("viewport_mask_per_view_nv", ViewportMaskPerViewNV),
107 ("fully_covered_ext", FullyCoveredEXT),
108 ("task_count_nv", TaskCountNV),
109 ("primitive_count_nv", PrimitiveCountNV),
110 ("primitive_indices_nv", PrimitiveIndicesNV),
111 ("clip_distance_per_view_nv", ClipDistancePerViewNV),
112 ("cull_distance_per_view_nv", CullDistancePerViewNV),
113 ("layer_per_view_nv", LayerPerViewNV),
114 ("mesh_view_count_nv", MeshViewCountNV),
115 ("mesh_view_indices_nv", MeshViewIndicesNV),
116 ("bary_coord_nv", BuiltIn::BaryCoordNV),
117 ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV),
118 ("bary_coord", BaryCoordKHR),
119 ("bary_coord_no_persp", BaryCoordNoPerspKHR),
120 ("primitive_point_indices_ext", PrimitivePointIndicesEXT),
121 ("primitive_line_indices_ext", PrimitiveLineIndicesEXT),
122 (
123 "primitive_triangle_indices_ext",
124 PrimitiveTriangleIndicesEXT,
125 ),
126 ("cull_primitive_ext", CullPrimitiveEXT),
127 ("frag_size_ext", FragSizeEXT),
128 ("frag_invocation_count_ext", FragInvocationCountEXT),
129 ("launch_id", BuiltIn::LaunchIdKHR),
130 ("launch_size", BuiltIn::LaunchSizeKHR),
131 ("instance_custom_index", BuiltIn::InstanceCustomIndexKHR),
132 ("ray_geometry_index", BuiltIn::RayGeometryIndexKHR),
133 ("world_ray_origin", BuiltIn::WorldRayOriginKHR),
134 ("world_ray_direction", BuiltIn::WorldRayDirectionKHR),
135 ("object_ray_origin", BuiltIn::ObjectRayOriginKHR),
136 ("object_ray_direction", BuiltIn::ObjectRayDirectionKHR),
137 ("ray_tmin", BuiltIn::RayTminKHR),
138 ("ray_tmax", BuiltIn::RayTmaxKHR),
139 ("object_to_world", BuiltIn::ObjectToWorldKHR),
140 ("world_to_object", BuiltIn::WorldToObjectKHR),
141 ("hit_kind", BuiltIn::HitKindKHR),
142 ("incoming_ray_flags", BuiltIn::IncomingRayFlagsKHR),
143 ("warps_per_sm_nv", WarpsPerSMNV),
144 ("sm_count_nv", SMCountNV),
145 ("warp_id_nv", WarpIDNV),
146 ("SMIDNV", SMIDNV),
147 ]
148};
149
150const STORAGE_CLASSES: &[(&str, StorageClass)] = {
151 use StorageClass::*;
152 &[
153 ("uniform_constant", UniformConstant),
154 ("input", Input),
155 ("uniform", Uniform),
156 ("output", Output),
157 ("workgroup", Workgroup),
158 ("cross_workgroup", CrossWorkgroup),
159 ("private", Private),
160 ("function", Function),
161 ("generic", Generic),
162 ("push_constant", PushConstant),
163 ("atomic_counter", AtomicCounter),
164 ("image", Image),
165 ("storage_buffer", StorageBuffer),
166 ("callable_data", StorageClass::CallableDataKHR),
167 (
168 "incoming_callable_data",
169 StorageClass::IncomingCallableDataKHR,
170 ),
171 ("ray_payload", StorageClass::RayPayloadKHR),
172 ("hit_attribute", StorageClass::HitAttributeKHR),
173 ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR),
174 ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR),
175 ("physical_storage_buffer", PhysicalStorageBuffer),
176 ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT),
177 ]
178};
179
180const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
181 use ExecutionModel::*;
182 &[
183 ("vertex", Vertex),
184 ("tessellation_control", TessellationControl),
185 ("tessellation_evaluation", TessellationEvaluation),
186 ("geometry", Geometry),
187 ("fragment", Fragment),
188 ("compute", GLCompute),
189 ("task_nv", TaskNV),
190 ("mesh_nv", MeshNV),
191 ("task_ext", TaskEXT),
192 ("mesh_ext", MeshEXT),
193 ("ray_generation", ExecutionModel::RayGenerationKHR),
194 ("intersection", ExecutionModel::IntersectionKHR),
195 ("any_hit", ExecutionModel::AnyHitKHR),
196 ("closest_hit", ExecutionModel::ClosestHitKHR),
197 ("miss", ExecutionModel::MissKHR),
198 ("callable", ExecutionModel::CallableKHR),
199 ]
200};
201
/// Describes how an execution-mode argument inside an entry-point attribute supplies
/// its operand(s), e.g. `origin_lower_left` vs. `output_vertices = 4` vs. `threads(8, 8)`.
#[derive(Clone, Copy, Debug)]
enum ExecutionModeExtraDim {
    /// Bare identifier with no operand, e.g. `origin_lower_left`.
    None,
    /// A single integer written as `name = N`.
    Value,
    /// The first component of a three-component mode, written as `name = N`.
    X,
    /// The second component of a three-component mode, written as `name = N`.
    Y,
    /// The third component of a three-component mode, written as `name = N`.
    Z,
    /// A parenthesized list of one to three integers, e.g. `threads(8, 8)`.
    Tuple,
}
211
212const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
213 use ExecutionMode::*;
214 use ExecutionModeExtraDim::*;
215 &[
216 ("invocations", Invocations, Value),
217 ("spacing_equal", SpacingEqual, None),
218 ("spacing_fraction_even", SpacingFractionalEven, None),
219 ("spacing_fraction_odd", SpacingFractionalOdd, None),
220 ("vertex_order_cw", VertexOrderCw, None),
221 ("vertex_order_ccw", VertexOrderCcw, None),
222 ("pixel_center_integer", PixelCenterInteger, None),
223 ("origin_upper_left", OriginUpperLeft, None),
224 ("origin_lower_left", OriginLowerLeft, None),
225 ("early_fragment_tests", EarlyFragmentTests, None),
226 ("point_mode", PointMode, None),
227 ("xfb", Xfb, None),
228 ("depth_replacing", DepthReplacing, None),
229 ("depth_greater", DepthGreater, None),
230 ("depth_less", DepthLess, None),
231 ("depth_unchanged", DepthUnchanged, None),
232 ("threads", LocalSize, Tuple),
233 ("local_size_hint_x", LocalSizeHint, X),
234 ("local_size_hint_y", LocalSizeHint, Y),
235 ("local_size_hint_z", LocalSizeHint, Z),
236 ("input_points", InputPoints, None),
237 ("input_lines", InputLines, None),
238 ("input_lines_adjacency", InputLinesAdjacency, None),
239 ("triangles", Triangles, None),
240 ("input_triangles_adjacency", InputTrianglesAdjacency, None),
241 ("quads", Quads, None),
242 ("isolines", Isolines, None),
243 ("output_vertices", OutputVertices, Value),
244 ("output_points", OutputPoints, None),
245 ("output_line_strip", OutputLineStrip, None),
246 ("output_triangle_strip", OutputTriangleStrip, None),
247 ("vec_type_hint", VecTypeHint, Value),
248 ("contraction_off", ContractionOff, None),
249 ("initializer", Initializer, None),
250 ("finalizer", Finalizer, None),
251 ("subgroup_size", SubgroupSize, Value),
252 ("subgroups_per_workgroup", SubgroupsPerWorkgroup, Value),
253 ("subgroups_per_workgroup_id", SubgroupsPerWorkgroupId, Value),
254 ("local_size_id_x", LocalSizeId, X),
255 ("local_size_id_y", LocalSizeId, Y),
256 ("local_size_id_z", LocalSizeId, Z),
257 ("local_size_hint_id", LocalSizeHintId, Value),
258 ("post_depth_coverage", PostDepthCoverage, None),
259 ("denorm_preserve", DenormPreserve, None),
260 ("denorm_flush_to_zero", DenormFlushToZero, Value),
261 (
262 "signed_zero_inf_nan_preserve",
263 SignedZeroInfNanPreserve,
264 Value,
265 ),
266 ("rounding_mode_rte", RoundingModeRTE, Value),
267 ("rounding_mode_rtz", RoundingModeRTZ, Value),
268 ("stencil_ref_replacing_ext", StencilRefReplacingEXT, None),
269 ("output_lines_nv", OutputLinesNV, None),
270 ("output_primitives_nv", OutputPrimitivesNV, Value),
271 ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None),
272 ("output_triangles_nv", OutputTrianglesNV, None),
273 ("output_lines_ext", ExecutionMode::OutputLinesEXT, None),
274 (
275 "output_triangles_ext",
276 ExecutionMode::OutputTrianglesEXT,
277 None,
278 ),
279 (
280 "output_primitives_ext",
281 ExecutionMode::OutputPrimitivesEXT,
282 Value,
283 ),
284 (
285 "pixel_interlock_ordered_ext",
286 PixelInterlockOrderedEXT,
287 None,
288 ),
289 (
290 "pixel_interlock_unordered_ext",
291 PixelInterlockUnorderedEXT,
292 None,
293 ),
294 (
295 "sample_interlock_ordered_ext",
296 SampleInterlockOrderedEXT,
297 None,
298 ),
299 (
300 "sample_interlock_unordered_ext",
301 SampleInterlockUnorderedEXT,
302 None,
303 ),
304 (
305 "shading_rate_interlock_ordered_ext",
306 ShadingRateInterlockOrderedEXT,
307 None,
308 ),
309 (
310 "shading_rate_interlock_unordered_ext",
311 ShadingRateInterlockUnorderedEXT,
312 None,
313 ),
314 ]
322};
323
324impl Symbols {
325 fn new() -> Self {
326 let builtins = BUILTINS
327 .iter()
328 .map(|&(a, b)| (a, SpirvAttribute::Builtin(b)));
329 let storage_classes = STORAGE_CLASSES
330 .iter()
331 .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b)));
332 let execution_models = EXECUTION_MODELS
333 .iter()
334 .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into())));
335 let custom_attributes = [
336 (
337 "sampler",
338 SpirvAttribute::IntrinsicType(IntrinsicType::Sampler),
339 ),
340 (
341 "generic_image_type",
342 SpirvAttribute::IntrinsicType(IntrinsicType::GenericImageType),
343 ),
344 (
345 "acceleration_structure",
346 SpirvAttribute::IntrinsicType(IntrinsicType::AccelerationStructureKhr),
347 ),
348 (
349 "ray_query",
350 SpirvAttribute::IntrinsicType(IntrinsicType::RayQueryKhr),
351 ),
352 ("block", SpirvAttribute::Block),
353 ("flat", SpirvAttribute::Flat),
354 ("invariant", SpirvAttribute::Invariant),
355 ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
356 (
357 "sampled_image",
358 SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
359 ),
360 (
361 "runtime_array",
362 SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
363 ),
364 (
365 "typed_buffer",
366 SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
367 ),
368 (
369 "matrix",
370 SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
371 ),
372 ("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
373 (
374 "buffer_store_intrinsic",
375 SpirvAttribute::BufferStoreIntrinsic,
376 ),
377 ]
378 .iter()
379 .cloned();
380 let attributes_iter = builtins
381 .chain(storage_classes)
382 .chain(execution_models)
383 .chain(custom_attributes)
384 .map(|(a, b)| (Symbol::intern(a), b));
385 let mut attributes = FxHashMap::default();
386 for (a, b) in attributes_iter {
387 let old = attributes.insert(a, b);
388 assert!(old.is_none());
391 }
392 let mut execution_modes = FxHashMap::default();
393 for &(key, mode, dim) in EXECUTION_MODES {
394 let old = execution_modes.insert(Symbol::intern(key), (mode, dim));
395 assert!(old.is_none());
396 }
397
398 let mut libm_intrinsics = FxHashMap::default();
399 for &(a, b) in libm_intrinsics::TABLE {
400 let old = libm_intrinsics.insert(Symbol::intern(a), b);
401 assert!(old.is_none());
402 }
403 Self {
404 discriminant: Symbol::intern("discriminant"),
405 rust_gpu: Symbol::intern("rust_gpu"),
406 spirv: Symbol::intern("spirv"),
407 libm: Symbol::intern("libm"),
408 entry_point_name: Symbol::intern("entry_point_name"),
409 spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
410
411 descriptor_set: Symbol::intern("descriptor_set"),
412 binding: Symbol::intern("binding"),
413 input_attachment_index: Symbol::intern("input_attachment_index"),
414
415 spec_constant: Symbol::intern("spec_constant"),
416 id: Symbol::intern("id"),
417 default: Symbol::intern("default"),
418
419 attributes,
420 execution_modes,
421 libm_intrinsics,
422 }
423 }
424
425 pub fn get() -> Rc<Self> {
431 thread_local!(static SYMBOLS: Rc<Symbols> = Rc::new(Symbols::new()));
432 SYMBOLS.with(Rc::clone)
433 }
434}
435
436type ParseAttrError = (Span, String);
438
439pub(crate) fn parse_attrs_for_checking<'a>(
441 sym: &'a Symbols,
442 attrs: &'a [Attribute],
443) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
444 attrs.iter().flat_map(move |attr| {
445 let (whole_attr_error, args) = match attr {
446 Attribute::Unparsed(item) => {
447 let s = &item.path.segments;
449 if s.len() > 1 && s[0].name == sym.rust_gpu {
450 if s.len() != 2 || s[1].name != sym.spirv {
452 (
454 Some(Err((
455 attr.span(),
456 "unknown `rust_gpu` attribute, expected `rust_gpu::spirv`"
457 .to_string(),
458 ))),
459 Default::default(),
460 )
461 } else if let Some(args) = attr.meta_item_list() {
462 (None, args)
464 } else {
465 (
467 Some(Err((
468 attr.span(),
469 "#[rust_gpu::spirv(..)] attribute must have at least one argument"
470 .to_string(),
471 ))),
472 Default::default(),
473 )
474 }
475 } else {
476 (None, Default::default())
478 }
479 }
480 Attribute::Parsed(_) => (None, Default::default()),
481 };
482
483 whole_attr_error
484 .into_iter()
485 .chain(args.into_iter().map(move |ref arg| {
486 let span = arg.span();
487 let parsed_attr = if arg.has_name(sym.descriptor_set) {
488 SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
489 } else if arg.has_name(sym.binding) {
490 SpirvAttribute::Binding(parse_attr_int_value(arg)?)
491 } else if arg.has_name(sym.input_attachment_index) {
492 SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
493 } else if arg.has_name(sym.spec_constant) {
494 SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
495 } else {
496 let name = match arg.ident() {
497 Some(i) => i,
498 None => {
499 return Err((
500 span,
501 "#[spirv(..)] attribute argument must be single identifier"
502 .to_string(),
503 ));
504 }
505 };
506 sym.attributes.get(&name.name).map_or_else(
507 || Err((name.span, "unknown argument to spirv attribute".to_string())),
508 |a| {
509 Ok(match a {
510 SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
511 parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
512 ),
513 _ => a.clone(),
514 })
515 },
516 )?
517 };
518 Ok((span, parsed_attr))
519 }))
520 })
521}
522
523fn parse_spec_constant_attr(
524 sym: &Symbols,
525 arg: &MetaItemInner,
526) -> Result<SpecConstant, ParseAttrError> {
527 let mut id = None;
528 let mut default = None;
529
530 if let Some(attrs) = arg.meta_item_list() {
531 for attr in attrs {
532 if attr.has_name(sym.id) {
533 if id.is_none() {
534 id = Some(parse_attr_int_value(attr)?);
535 } else {
536 return Err((attr.span(), "`id` may only be specified once".into()));
537 }
538 } else if attr.has_name(sym.default) {
539 if default.is_none() {
540 default = Some(parse_attr_int_value(attr)?);
541 } else {
542 return Err((attr.span(), "`default` may only be specified once".into()));
543 }
544 } else {
545 return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
546 }
547 }
548 }
549 Ok(SpecConstant {
550 id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
551 default,
552 })
553}
554
555fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
556 let arg = match arg.meta_item() {
557 Some(arg) => arg,
558 None => return Err((arg.span(), "attribute must have value".to_string())),
559 };
560 match arg.name_value_literal() {
561 Some(&MetaItemLit {
562 kind: LitKind::Int(x, LitIntType::Unsuffixed),
563 ..
564 }) if x <= u32::MAX as u128 => Ok(x.get() as u32),
565 _ => Err((arg.span, "attribute value must be integer".to_string())),
566 }
567}
568
569fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
570 let arg = match arg.meta_item() {
571 Some(arg) => arg,
572 None => return Err((arg.span(), "attribute must have value".to_string())),
573 };
574 match arg.meta_item_list() {
575 Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
576 let mut local_size = [1; 3];
577 for (idx, lit) in tuple.iter().enumerate() {
578 match lit {
579 MetaItemInner::Lit(MetaItemLit {
580 kind: LitKind::Int(x, LitIntType::Unsuffixed),
581 ..
582 }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
583 _ => return Err((lit.span(), "must be a u32 literal".to_string())),
584 }
585 }
586 Ok(local_size)
587 }
588 Some([]) => Err((
589 arg.span,
590 "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
591 )),
592 Some(tuple) if tuple.len() > 3 => Err((
593 arg.span,
594 "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
595 )),
596 _ => Err((
597 arg.span,
598 "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
599 )),
600 }
601}
602
603fn parse_entry_attrs(
608 sym: &Symbols,
609 arg: &MetaItemInner,
610 name: &Ident,
611 execution_model: ExecutionModel,
612) -> Result<Entry, ParseAttrError> {
613 use ExecutionMode::*;
614 use ExecutionModel::*;
615 let mut entry = Entry::from(execution_model);
616 let mut origin_mode: Option<ExecutionMode> = None;
617 let mut local_size: Option<[u32; 3]> = None;
618 let mut local_size_hint: Option<[u32; 3]> = None;
619 if let Some(attrs) = arg.meta_item_list() {
622 for attr in attrs {
623 if let Some(attr_name) = attr.ident() {
624 if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
625 {
626 use ExecutionModeExtraDim::*;
627 let val = match extra_dim {
628 None | Tuple => Option::None,
629 _ => Some(parse_attr_int_value(attr)?),
630 };
631 match execution_mode {
632 OriginUpperLeft | OriginLowerLeft => {
633 origin_mode.replace(*execution_mode);
634 }
635 LocalSize => {
636 if local_size.is_none() {
637 local_size.replace(parse_local_size_attr(attr)?);
638 } else {
639 return Err((
640 attr_name.span,
641 String::from(
642 "`#[spirv(compute(threads))]` may only be specified once",
643 ),
644 ));
645 }
646 }
647 LocalSizeHint => {
648 let val = val.unwrap();
649 if local_size_hint.is_none() {
650 local_size_hint.replace([1, 1, 1]);
651 }
652 let local_size_hint = local_size_hint.as_mut().unwrap();
653 match extra_dim {
654 X => {
655 local_size_hint[0] = val;
656 }
657 Y => {
658 local_size_hint[1] = val;
659 }
660 Z => {
661 local_size_hint[2] = val;
662 }
663 _ => unreachable!(),
664 }
665 }
666 _ => {
688 if let Some(val) = val {
689 entry
690 .execution_modes
691 .push((*execution_mode, ExecutionModeExtra::new([val])));
692 } else {
693 entry
694 .execution_modes
695 .push((*execution_mode, ExecutionModeExtra::new([])));
696 }
697 }
698 }
699 } else if attr_name.name == sym.entry_point_name {
700 match attr.value_str() {
701 Some(sym) => {
702 entry.name = Some(sym);
703 }
704 None => {
705 return Err((
706 attr_name.span,
707 format!(
708 "#[spirv({name}(..))] unknown attribute argument {attr_name}"
709 ),
710 ));
711 }
712 }
713 } else {
714 return Err((
715 attr_name.span,
716 format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
717 ));
718 }
719 } else {
720 return Err((
721 arg.span(),
722 format!("#[spirv({name}(..))] attribute argument must be single identifier"),
723 ));
724 }
725 }
726 }
727 match entry.execution_model {
728 Fragment => {
729 let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
730 entry
731 .execution_modes
732 .push((origin_mode, ExecutionModeExtra::new([])));
733 }
734 GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
735 if let Some(local_size) = local_size {
736 entry
737 .execution_modes
738 .push((LocalSize, ExecutionModeExtra::new(local_size)));
739 } else {
740 return Err((
741 arg.span(),
742 String::from(
743 "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
744 ),
745 ));
746 }
747 }
748 _ => {}
750 }
751 Ok(entry)
752}