diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c index efc94cbec7f..598e7edf42d 100644 --- a/src/amd/vulkan/radv_pipeline.c +++ b/src/amd/vulkan/radv_pipeline.c @@ -3765,8 +3765,7 @@ radv_pipeline_emit_hw_ngg(const struct radv_device *device, struct radeon_cmdbuf const struct radv_physical_device *pdevice = device->physical_device; uint64_t va = radv_shader_get_va(shader); gl_shader_stage es_type = - radv_pipeline_has_stage(pipeline, MESA_SHADER_MESH) ? MESA_SHADER_MESH : - radv_pipeline_has_stage(pipeline, MESA_SHADER_TESS_CTRL) ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX; + shader->info.stage == MESA_SHADER_GEOMETRY ? shader->info.gs.es_type : shader->info.stage; struct radv_shader *es = pipeline->base.shaders[es_type]; const struct gfx10_ngg_info *ngg_state = &shader->info.ngg_info; @@ -3789,7 +3788,7 @@ radv_pipeline_emit_hw_ngg(const struct radv_device *device, struct radeon_cmdbuf unsigned ge_cntl; if (es_type == MESA_SHADER_TESS_EVAL) { - struct radv_shader *gs = pipeline->base.shaders[MESA_SHADER_GEOMETRY]; + const struct radv_shader *gs = shader->info.stage == MESA_SHADER_GEOMETRY ? shader : NULL; if (es_enable_prim_id || (gs && gs->info.uses_prim_id)) break_wave_at_eoi = true; @@ -3840,8 +3839,8 @@ radv_pipeline_emit_hw_ngg(const struct radv_device *device, struct radeon_cmdbuf S_028A84_NGG_DISABLE_PROVOK_REUSE(outinfo->export_prim_id)); /* NGG specific registers. */ - struct radv_shader *gs = pipeline->base.shaders[MESA_SHADER_GEOMETRY]; - uint32_t gs_num_invocations = gs ? gs->info.gs.invocations : 1; + uint32_t gs_num_invocations = + shader->info.stage == MESA_SHADER_GEOMETRY ? shader->info.gs.invocations : 1; if (pdevice->rad_info.gfx_level < GFX11) { radeon_set_context_reg(