From 9728a9075c9a4b4c67dd7d63df439c1be4140f5d Mon Sep 17 00:00:00 2001 From: Karol Herbst Date: Thu, 14 Nov 2024 13:14:10 +0100 Subject: [PATCH] vtn: handle struct kernel arguments passed by value Due to LLVM ABI reasons the SPIRV-LLVM-Translator always uses pointers to private memory for struct function parameters. This includes kernel entry points. However technically it's also legal to pass those parameters by value according to the OpenCL SPIR-V Env spec. One compiler making use of this is e.g. artic based on Thorin. Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/12149 Cc: mesa-stable Reviewed-by: Alyssa Rosenzweig (cherry picked from commit d0560f59cedf7ca88efe25bfbee72ff0819bcd15) Part-of: --- .pick_status.json | 2 +- src/compiler/spirv/spirv_to_nir.c | 43 ++++++++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/.pick_status.json b/.pick_status.json index 673acbd9d27..dfea0f0a320 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -134,7 +134,7 @@ "description": "vtn: handle struct kernel arguments passed by value", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 3b0f74c5c8e..b4dd47ad31d 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -6767,6 +6767,27 @@ vtn_create_builder(const uint32_t *words, size_t word_count, return NULL; } +/* See glsl_type_add_to_function_params and vtn_ssa_value_add_to_call_params */ +static void +vtn_emit_kernel_entry_point_wrapper_struct_param(struct nir_builder *b, + nir_deref_instr *deref, + nir_call_instr *call, + unsigned *idx) +{ + if (glsl_type_is_vector_or_scalar(deref->type)) { + call->params[(*idx)++] = nir_src_for_ssa(nir_load_deref(b, deref)); + } else { + unsigned elems = glsl_get_length(deref->type); + for (unsigned i = 0; i < elems; i++) { + nir_deref_instr *child_deref = glsl_type_is_struct(deref->type) + ? nir_build_deref_struct(b, deref, i) + : nir_build_deref_array_imm(b, deref, i); + vtn_emit_kernel_entry_point_wrapper_struct_param(b, child_deref, call, + idx); + } + } +} + static nir_function * vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b, nir_function *entry_point) @@ -6785,7 +6806,8 @@ vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b, nir_call_instr *call = nir_call_instr_create(b->nb.shader, entry_point); - for (unsigned i = 0; i < entry_point->num_params; ++i) { + unsigned call_idx = 0; + for (unsigned i = 0; i < b->entry_point->func->type->length; ++i) { struct vtn_type *param_type = b->entry_point->func->type->params[i]; b->shader->info.cs.has_variable_shared_mem |= @@ -6826,17 +6848,30 @@ vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b, nir_variable *copy_var = nir_local_variable_create(impl, in_var->type, "copy_in"); nir_copy_var(&b->nb, copy_var, in_var); - call->params[i] = + call->params[call_idx++] = nir_src_for_ssa(&nir_build_deref_var(&b->nb, copy_var)->def); } else if (param_type->base_type == vtn_base_type_image || param_type->base_type == vtn_base_type_sampler) { /* Don't load the var, just pass a deref of it */ - call->params[i] = nir_src_for_ssa(&nir_build_deref_var(&b->nb, in_var)->def); + call->params[call_idx++] = + nir_src_for_ssa(&nir_build_deref_var(&b->nb, in_var)->def); + } else if (param_type->base_type == vtn_base_type_struct) { + /* We decompose struct and array parameters in vtn, so we'll need to + * handle it here explicitly. + * We have to keep the arguments on the actual entry point intact, + * because the runtimes rely on it to match the SPIR-V. + */ + nir_deref_instr *deref = nir_build_deref_var(&b->nb, in_var); + vtn_emit_kernel_entry_point_wrapper_struct_param(&b->nb, deref, call, + &call_idx); } else { - call->params[i] = nir_src_for_ssa(nir_load_var(&b->nb, in_var)); + call->params[call_idx++] = + nir_src_for_ssa(nir_load_var(&b->nb, in_var)); } } + assert(call_idx == entry_point->num_params); + nir_builder_instr_insert(&b->nb, &call->instr); return main_entry_point;