diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 58827c5404b..e27f98b03ce 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -3680,6 +3680,24 @@ find_sampler_and_texture_index(struct ntv_context *ctx, struct spriv_tex_src *te } } +static SpvId +get_texture_load(struct ntv_context *ctx, SpvId sampler_id, nir_tex_instr *tex, + SpvId image_type, SpvId sampled_type) +{ + if (ctx->stage == MESA_SHADER_KERNEL) { + SpvId image_load = spirv_builder_emit_load(&ctx->builder, image_type, sampler_id); + if (nir_tex_instr_need_sampler(tex)) { + SpvId sampler_load = spirv_builder_emit_load(&ctx->builder, spirv_builder_type_sampler(&ctx->builder), + ctx->cl_samplers[tex->sampler_index]); + return spirv_builder_emit_sampled_image(&ctx->builder, sampled_type, image_load, sampler_load); + } else { + return image_load; + } + } else { + return spirv_builder_emit_load(&ctx->builder, sampled_type, sampler_id); + } +} + static void emit_tex(struct ntv_context *ctx, nir_tex_instr *tex) { @@ -3720,18 +3738,8 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex) SpvId ptr = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassUniformConstant, sampled_type); sampler_id = spirv_builder_emit_access_chain(&ctx->builder, ptr, sampler_id, &tex_src.tex_offset, 1); } - SpvId load; - if (ctx->stage == MESA_SHADER_KERNEL) { - SpvId image_load = spirv_builder_emit_load(&ctx->builder, image_type, sampler_id); - if (nir_tex_instr_need_sampler(tex)) { - SpvId sampler_load = spirv_builder_emit_load(&ctx->builder, spirv_builder_type_sampler(&ctx->builder), ctx->cl_samplers[tex->sampler_index]); - load = spirv_builder_emit_sampled_image(&ctx->builder, sampled_type, image_load, sampler_load); - } else { - load = image_load; - } - } else { - load = spirv_builder_emit_load(&ctx->builder, sampled_type, sampler_id); - } + + SpvId load = get_texture_load(ctx, sampler_id, tex, image_type, sampled_type); if (tex->is_sparse) tex->def.num_components--;