diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c index 95409d49f19..db252c5d803 100644 --- a/src/amd/vulkan/radv_pipeline.c +++ b/src/amd/vulkan/radv_pipeline.c @@ -2566,6 +2566,43 @@ radv_export_implicit_primitive_id(nir_shader *nir) return true; } +static void +radv_remove_point_size(const struct radv_pipeline_key *pipeline_key, + nir_shader *producer, nir_shader *consumer) +{ + if ((consumer->info.inputs_read & VARYING_BIT_PSIZ) || + !(producer->info.outputs_written & VARYING_BIT_PSIZ)) + return; + + /* Do not remove PSIZ if the shader uses XFB because it might be stored. */ + if (producer->xfb_info) + return; + + /* Do not remove PSIZ if the rasterization primitive uses points. */ + if (consumer->info.stage == MESA_SHADER_FRAGMENT && + ((producer->info.stage == MESA_SHADER_VERTEX && + pipeline_key->vs.topology == V_008958_DI_PT_POINTLIST) || + (producer->info.stage == MESA_SHADER_TESS_EVAL && producer->info.tess.point_mode) || + (producer->info.stage == MESA_SHADER_GEOMETRY && + producer->info.gs.output_primitive == SHADER_PRIM_POINTS) || + (producer->info.stage == MESA_SHADER_MESH && + producer->info.mesh.primitive_type == SHADER_PRIM_POINTS))) + return; + + nir_variable *var = + nir_find_variable_with_location(producer, nir_var_shader_out, VARYING_SLOT_PSIZ); + assert(var); + + /* Change PSIZ to a global variable which allows it to be DCE'd. */ + var->data.location = 0; + var->data.mode = nir_var_shader_temp; + + producer->info.outputs_written &= ~VARYING_BIT_PSIZ; + NIR_PASS_V(producer, nir_fixup_deref_modes); + NIR_PASS(_, producer, nir_remove_dead_variables, nir_var_shader_temp, NULL); + NIR_PASS(_, producer, nir_opt_dce); +} + static void radv_link_shaders(struct radv_pipeline *pipeline, const struct radv_pipeline_key *pipeline_key, @@ -2686,9 +2723,6 @@ radv_link_shaders(struct radv_pipeline *pipeline, } if (!pipeline_key->optimisations_disabled) { - bool uses_xfb = last_vgt_api_stage != -1 && - stages[last_vgt_api_stage].nir->xfb_info; - for (unsigned i = 0; i < shader_count; ++i) { shader_info *info = &ordered_shaders[i]->info; @@ -2726,32 +2760,8 @@ radv_link_shaders(struct radv_pipeline *pipeline, /* Remove PSIZ from shaders when it's not needed. * This is typically produced by translation layers like Zink or D9VK. */ - if (uses_xfb || !(info->outputs_written & VARYING_BIT_PSIZ)) - continue; - - bool next_stage_needs_psiz = - i != 0 && /* ordered_shaders is backwards, so next stage is: i - 1 */ - ordered_shaders[i - 1]->info.inputs_read & VARYING_BIT_PSIZ; - bool topology_uses_psiz = - info->stage == last_vgt_api_stage && - ((info->stage == MESA_SHADER_VERTEX && pipeline_key->vs.topology == V_008958_DI_PT_POINTLIST) || - (info->stage == MESA_SHADER_TESS_EVAL && info->tess.point_mode) || - (info->stage == MESA_SHADER_GEOMETRY && info->gs.output_primitive == SHADER_PRIM_POINTS) || - (info->stage == MESA_SHADER_MESH && info->mesh.primitive_type == SHADER_PRIM_POINTS)); - - nir_variable *psiz_var = - nir_find_variable_with_location(ordered_shaders[i], nir_var_shader_out, VARYING_SLOT_PSIZ); - - if (!next_stage_needs_psiz && !topology_uses_psiz && psiz_var) { - /* Change PSIZ to a global variable which allows it to be DCE'd. */ - psiz_var->data.location = 0; - psiz_var->data.mode = nir_var_shader_temp; - - info->outputs_written &= ~VARYING_BIT_PSIZ; - NIR_PASS_V(ordered_shaders[i], nir_fixup_deref_modes); - NIR_PASS(_, ordered_shaders[i], nir_remove_dead_variables, nir_var_shader_temp, NULL); - NIR_PASS(_, ordered_shaders[i], nir_opt_dce); - } + if (i != 0) + radv_remove_point_size(pipeline_key, ordered_shaders[i], ordered_shaders[i - 1]); } }